
Commit 3f5705c

Merge pull request #9148 from kexinzhao/cast_op_fp16
Add float16 support for cast op
2 parents: 02b3cfb + 6ef4f1f

3 files changed: 38 additions, 3 deletions

paddle/fluid/operators/cast_op.cc (3 additions, 1 deletion)

@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/cast_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -88,4 +89,5 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
                        ops::CastOpKernel<CPU, double>,
                        ops::CastOpKernel<CPU, int>,
                        ops::CastOpKernel<CPU, int64_t>,
-                       ops::CastOpKernel<CPU, bool>);
+                       ops::CastOpKernel<CPU, bool>,
+                       ops::CastOpKernel<CPU, paddle::platform::float16>);
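For context on what the new kernel registration means in practice: the cast kernel converts elementwise, and float16's 10-bit mantissa bounds the accuracy of a float32 -> float16 -> float32 round trip. A minimal numpy sketch of the same semantics (illustration only, not Paddle code) also shows where the atol=1e-3 tolerance in the new tests below comes from:

import numpy as np

x = np.random.random(size=[10, 10]).astype(np.float32)
# Elementwise cast -- the same semantics the registered kernels implement.
y = x.astype(np.float16)
# float16 carries ~10 mantissa bits, so values in [0, 1) round-trip to
# within roughly 1e-3; the tests below use atol=1e-3 for this reason.
assert np.allclose(y.astype(np.float32), x, atol=1e-3)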

paddle/fluid/operators/cast_op.cu (3 additions, 1 deletion)

@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/cast_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 template <typename T>
 using CastOpKernel =
     paddle::operators::CastOpKernel<paddle::platform::CUDADeviceContext, T>;
 
 REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
                         CastOpKernel<int>, CastOpKernel<int64_t>,
-                        CastOpKernel<bool>);
+                        CastOpKernel<bool>,
+                        CastOpKernel<paddle::platform::float16>);
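With both the CPU and CUDA kernels registered, a float16 cast can be requested from the Python side. A minimal sketch, assuming the fluid.layers API of this era (the variable names here are illustrative, not part of the diff):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[10, 10], dtype='float32')
# Inserts a cast op into the program; at run time it dispatches to the
# float16 kernels registered above on CPU or CUDA.
y = fluid.layers.cast(x=x, dtype='float16')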

python/paddle/fluid/tests/unittests/test_cast_op.py (32 additions, 1 deletion)

@@ -18,7 +18,7 @@
 import paddle.fluid.core as core
 
 
-class TestCastOp(op_test.OpTest):
+class TestCastOp1(op_test.OpTest):
     def setUp(self):
         ipt = np.random.random(size=[10, 10])
         self.inputs = {'X': ipt.astype('float32')}
@@ -36,5 +36,36 @@ def test_grad(self):
         self.check_grad(['X'], ['Out'])
 
 
+class TestCastOp2(op_test.OpTest):
+    def setUp(self):
+        ipt = np.random.random(size=[10, 10])
+        # numpy float16 is bound to fluid float16 via uint16
+        self.inputs = {'X': ipt.astype('float16').view(np.uint16)}
+        self.outputs = {'Out': ipt.astype('float32')}
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.FP16),
+            'out_dtype': int(core.VarDesc.VarType.FP32)
+        }
+        self.op_type = 'cast'
+
+    def test_check_output(self):
+        self.check_output(atol=1e-3)
+
+
+class TestCastOp3(op_test.OpTest):
+    def setUp(self):
+        ipt = np.random.random(size=[10, 10])
+        self.inputs = {'X': ipt.astype('float32')}
+        self.outputs = {'Out': ipt.astype('float16')}
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.FP32),
+            'out_dtype': int(core.VarDesc.VarType.FP16)
+        }
+        self.op_type = 'cast'
+
+    def test_check_output(self):
+        self.check_output(atol=1e-3)
+
+
 if __name__ == '__main__':
     unittest.main()
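The uint16 view in TestCastOp2 exists because numpy's float16 and fluid's float16 are distinct types that share the IEEE binary16 bit pattern, so the test hands the op the raw 16-bit words. A standalone numpy sketch of that reinterpretation:

import numpy as np

a = np.random.random(size=[10, 10]).astype(np.float16)
bits = a.view(np.uint16)      # reinterpret the bits; no numeric conversion
back = bits.view(np.float16)  # viewing back recovers identical values
assert np.array_equal(a, back)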
