Skip to content

Commit c95cd47

Browse files
authored
Merge pull request #10975 from JiayiFeng/fix_bug_in_uint8_support
Correct uint8 support
2 parents 3a29821 + 4785c00 commit c95cd47

File tree

4 files changed

+8
-2
lines changed

4 files changed

+8
-2
lines changed

paddle/fluid/operators/cast_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,5 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
8989
ops::CastOpKernel<CPU, int>,
9090
ops::CastOpKernel<CPU, int64_t>,
9191
ops::CastOpKernel<CPU, bool>,
92+
ops::CastOpKernel<CPU, uint8_t>,
9293
ops::CastOpKernel<CPU, paddle::platform::float16>);

paddle/fluid/operators/cast_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ using CastOpKernel =
2121

2222
REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
2323
CastOpKernel<int>, CastOpKernel<int64_t>,
24-
CastOpKernel<bool>,
24+
CastOpKernel<bool>, CastOpKernel<uint8_t>,
2525
CastOpKernel<paddle::platform::float16>);

paddle/fluid/pybind/pybind.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,22 @@ PYBIND11_PLUGIN(core) {
117117
.def("set", PyCPUTensorSetFromArray<int64_t>)
118118
.def("set", PyCPUTensorSetFromArray<bool>)
119119
.def("set", PyCPUTensorSetFromArray<uint16_t>)
120+
.def("set", PyCPUTensorSetFromArray<uint8_t>)
120121
#ifdef PADDLE_WITH_CUDA
121122
.def("set", PyCUDATensorSetFromArray<float>)
122123
.def("set", PyCUDATensorSetFromArray<int>)
123124
.def("set", PyCUDATensorSetFromArray<double>)
124125
.def("set", PyCUDATensorSetFromArray<int64_t>)
125126
.def("set", PyCUDATensorSetFromArray<bool>)
126127
.def("set", PyCUDATensorSetFromArray<uint16_t>)
128+
.def("set", PyCUDATensorSetFromArray<uint8_t>)
127129
.def("set", PyCUDAPinnedTensorSetFromArray<float>)
128130
.def("set", PyCUDAPinnedTensorSetFromArray<int>)
129131
.def("set", PyCUDAPinnedTensorSetFromArray<double>)
130132
.def("set", PyCUDAPinnedTensorSetFromArray<int64_t>)
131133
.def("set", PyCUDAPinnedTensorSetFromArray<bool>)
132134
.def("set", PyCUDAPinnedTensorSetFromArray<uint16_t>)
135+
.def("set", PyCUDAPinnedTensorSetFromArray<uint8_t>)
133136
#endif
134137
.def("shape", [](Tensor &self) { return vectorize(self.dims()); })
135138
.def("set_float_element", TensorSetElement<float>)

python/paddle/fluid/data_feeder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ def __init__(self, place, lod_level, shape, dtype):
3636
self.dtype = 'float64'
3737
elif dtype == core.VarDesc.VarType.INT32:
3838
self.dtype = 'int32'
39+
elif dtype == core.VarDesc.VarType.UINT8:
40+
self.dtype = 'uint8'
3941
else:
4042
raise ValueError("dtype must be any of [int32, float32, int64, "
41-
"float64]")
43+
"float64, uint8]")
4244

4345
self.data = []
4446
self.lod = []

0 commit comments

Comments
 (0)