@@ -48,52 +48,17 @@ struct CastOpFunctor {
48
48
}
49
49
};
50
50
51
- template <typename DeviceContext, typename InT, typename OutT>
52
- static void CastFunction (const framework::ExecutionContext& context) {
53
- auto * in = context.Input <framework::Tensor>(" X" );
54
- auto * out = context.Output <framework::Tensor>(" Out" );
55
-
56
- auto in_t = framework::EigenVector<InT>::Flatten (*in);
57
- out->mutable_data <OutT>(context.GetPlace ());
58
- auto out_t = framework::EigenVector<OutT>::Flatten (*out);
59
- auto & place =
60
- *context.template device_context <DeviceContext>().eigen_device ();
61
- out_t .device (place) = in_t .template cast <OutT>();
62
- }
63
-
64
51
template <typename DeviceContext, typename InT>
65
52
class CastOpKernel : public framework ::OpKernel<InT> {
66
53
public:
67
54
void Compute (const framework::ExecutionContext& context) const override {
68
- auto out_type = static_cast <framework::proto::VarType::Type>(
69
- context.Attr <int >(" out_dtype" ));
70
-
71
- if (out_type == paddle::framework::proto::VarType::FP64) {
72
- CastFunction<DeviceContext, InT, double >(context);
73
- } else if (out_type == paddle::framework::proto::VarType::FP32) {
74
- CastFunction<DeviceContext, InT, float >(context);
75
- } else if (out_type == paddle::framework::proto::VarType::FP16) {
76
- CastFunction<DeviceContext, InT, paddle::platform::float16>(context);
77
- } else if (out_type == paddle::framework::proto::VarType::INT64) {
78
- CastFunction<DeviceContext, InT, int64_t >(context);
79
- } else if (out_type == paddle::framework::proto::VarType::INT32) {
80
- CastFunction<DeviceContext, InT, int >(context);
81
- } else if (out_type == paddle::framework::proto::VarType::UINT8) {
82
- CastFunction<DeviceContext, InT, uint8_t >(context);
83
- } else if (out_type == paddle::framework::proto::VarType::BOOL) {
84
- CastFunction<DeviceContext, InT, bool >(context);
85
- } else if (out_type == paddle::framework::proto::VarType::COMPLEX64) {
86
- CastFunction<DeviceContext, InT, paddle::platform::complex64>(context);
87
- } else if (out_type == paddle::framework::proto::VarType::COMPLEX128) {
88
- CastFunction<DeviceContext, InT, paddle::platform::complex128>(context);
89
- } else {
90
- // NOTE(chenweihang): if else branch do nothing, the output var will
91
- // be non-initialized in dygraph, which will throw error if the
92
- // non-initialized var is used as the next op's input
93
- PADDLE_THROW (platform::errors::Unimplemented (
94
- " Now does not support casting Tensor to `%s` data type." ,
95
- framework::DataTypeToString (out_type)));
96
- }
55
+ auto * in = context.Input <framework::Tensor>(" X" );
56
+ auto * out = context.Output <framework::Tensor>(" Out" );
57
+ framework::VisitDataType (
58
+ static_cast <framework::proto::VarType::Type>(
59
+ context.Attr <int >(" out_dtype" )),
60
+ CastOpFunctor<DeviceContext, InT>(
61
+ in, out, context.template device_context <DeviceContext>()));
97
62
}
98
63
};
99
64
0 commit comments