Skip to content

Commit 14cf420

Browse files
revert cast eigen kernel (#29445)
1 parent d77566b commit 14cf420

File tree

2 files changed

+7
-54
lines changed

2 files changed

+7
-54
lines changed

paddle/fluid/operators/cast_op.h

Lines changed: 7 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -48,52 +48,17 @@ struct CastOpFunctor {
4848
}
4949
};
5050

51-
template <typename DeviceContext, typename InT, typename OutT>
52-
static void CastFunction(const framework::ExecutionContext& context) {
53-
auto* in = context.Input<framework::Tensor>("X");
54-
auto* out = context.Output<framework::Tensor>("Out");
55-
56-
auto in_t = framework::EigenVector<InT>::Flatten(*in);
57-
out->mutable_data<OutT>(context.GetPlace());
58-
auto out_t = framework::EigenVector<OutT>::Flatten(*out);
59-
auto& place =
60-
*context.template device_context<DeviceContext>().eigen_device();
61-
out_t.device(place) = in_t.template cast<OutT>();
62-
}
63-
6451
template <typename DeviceContext, typename InT>
6552
class CastOpKernel : public framework::OpKernel<InT> {
6653
public:
6754
void Compute(const framework::ExecutionContext& context) const override {
68-
auto out_type = static_cast<framework::proto::VarType::Type>(
69-
context.Attr<int>("out_dtype"));
70-
71-
if (out_type == paddle::framework::proto::VarType::FP64) {
72-
CastFunction<DeviceContext, InT, double>(context);
73-
} else if (out_type == paddle::framework::proto::VarType::FP32) {
74-
CastFunction<DeviceContext, InT, float>(context);
75-
} else if (out_type == paddle::framework::proto::VarType::FP16) {
76-
CastFunction<DeviceContext, InT, paddle::platform::float16>(context);
77-
} else if (out_type == paddle::framework::proto::VarType::INT64) {
78-
CastFunction<DeviceContext, InT, int64_t>(context);
79-
} else if (out_type == paddle::framework::proto::VarType::INT32) {
80-
CastFunction<DeviceContext, InT, int>(context);
81-
} else if (out_type == paddle::framework::proto::VarType::UINT8) {
82-
CastFunction<DeviceContext, InT, uint8_t>(context);
83-
} else if (out_type == paddle::framework::proto::VarType::BOOL) {
84-
CastFunction<DeviceContext, InT, bool>(context);
85-
} else if (out_type == paddle::framework::proto::VarType::COMPLEX64) {
86-
CastFunction<DeviceContext, InT, paddle::platform::complex64>(context);
87-
} else if (out_type == paddle::framework::proto::VarType::COMPLEX128) {
88-
CastFunction<DeviceContext, InT, paddle::platform::complex128>(context);
89-
} else {
90-
// NOTE(chenweihang): if else branch do nothing, the output var will
91-
// be non-initialized in dygraph, which will throw error if the
92-
// non-initialized var is used as the next op's input
93-
PADDLE_THROW(platform::errors::Unimplemented(
94-
"Now does not support casting Tensor to `%s` data type.",
95-
framework::DataTypeToString(out_type)));
96-
}
55+
auto* in = context.Input<framework::Tensor>("X");
56+
auto* out = context.Output<framework::Tensor>("Out");
57+
framework::VisitDataType(
58+
static_cast<framework::proto::VarType::Type>(
59+
context.Attr<int>("out_dtype")),
60+
CastOpFunctor<DeviceContext, InT>(
61+
in, out, context.template device_context<DeviceContext>()));
9762
}
9863
};
9964

python/paddle/fluid/tests/unittests/test_cast_op.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,6 @@ def test_dtype_type():
9090
self.assertRaises(TypeError, test_dtype_type)
9191

9292

93-
class TestCastOpErrorInDygraph(unittest.TestCase):
94-
def test_non_support_out_dtype(self):
95-
paddle.disable_static()
96-
97-
with self.assertRaises(NotImplementedError):
98-
tensor = paddle.randn([10, 10], 'float32')
99-
core.ops.cast(tensor, 'in_dtype', core.VarDesc.VarType.FP32,
100-
'out_dtype', core.VarDesc.VarType.INT16)
101-
102-
paddle.enable_static()
103-
104-
10593
if __name__ == '__main__':
10694
paddle.enable_static()
10795
unittest.main()

0 commit comments

Comments
 (0)