Skip to content

Commit b341bac

Browse files
authored
Refine cast op (#8923)
* fix mac build error * override GetExpectedKernelType for cast op * fix typo * add cuda unittest
1 parent 8468037 commit b341bac

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

paddle/fluid/operators/cast_op.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,27 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
6363
}
6464
};
6565

66+
class CastOp : public framework::OperatorWithKernel {
67+
public:
68+
using framework::OperatorWithKernel::OperatorWithKernel;
69+
70+
protected:
71+
framework::OpKernelType GetExpectedKernelType(
72+
const framework::ExecutionContext &ctx) const override {
73+
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
74+
// CastOp kernel's device type is decided by input tensor place
75+
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
76+
return kt;
77+
}
78+
};
79+
6680
} // namespace operators
6781
} // namespace paddle
6882

6983
namespace ops = paddle::operators;
7084
using CPU = paddle::platform::CPUDeviceContext;
71-
REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape,
72-
ops::CastOpProtoMaker);
85+
REGISTER_OPERATOR(cast, ops::CastOp, ops::CastOpGradMaker,
86+
ops::CastOpInferShape, ops::CastOpProtoMaker);
7387
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
7488
ops::CastOpKernel<CPU, double>,
7589
ops::CastOpKernel<CPU, int>,

python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import paddle.fluid as fluid
2020
import paddle.fluid.layers as layers
2121
import paddle.fluid.framework as framework
22+
import paddle.fluid.core as core
2223

2324

2425
def exponential_decay(learning_rate,
@@ -81,6 +82,16 @@ def piecewise_decay(global_step, boundaries, values):
8182

8283
class TestLearningRateDecay(unittest.TestCase):
8384
def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
85+
places = [fluid.CPUPlace()]
86+
if core.is_compiled_with_cuda():
87+
places.append(fluid.CUDAPlace(0))
88+
for place in places:
89+
self.check_decay_with_place(place, python_decay_fn, fluid_decay_fn,
90+
kwargs)
91+
92+
def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn,
93+
kwargs):
94+
8495
decayed_lr = fluid_decay_fn(**kwargs)
8596

8697
place = fluid.CPUPlace()

0 commit comments

Comments
 (0)