Skip to content

Commit fd2b4b4

Browse files
committed
Make tensor support uint8
1 parent 9707aa6 commit fd2b4b4

File tree

5 files changed

+25
-6
lines changed

5 files changed

+25
-6
lines changed

paddle/fluid/framework/data_type.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ static DataTypeMap* InitDataTypeMap() {
5858
RegType(bool, proto::VarType::BOOL);
5959
RegType(size_t, proto::VarType::SIZE_T);
6060
RegType(int16_t, proto::VarType::INT16);
61+
RegType(uint8_t, proto::VarType::UINT8);
6162

6263
#undef RegType
6364
return retv;

paddle/fluid/framework/data_type.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,14 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
4747
case proto::VarType::BOOL:
4848
visitor.template operator()<bool>();
4949
break;
50+
case proto::VarType::UINT8:
51+
visitor.template operator()<uint8_t>();
52+
break;
53+
case proto::VarType::INT16:
54+
visitor.template operator()<int16_t>();
55+
break;
5056
default:
51-
PADDLE_THROW("Not supported");
57+
PADDLE_THROW("Not supported %d", type);
5258
}
5359
}
5460

paddle/fluid/framework/framework.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ message VarType {
103103
FP64 = 6;
104104
// Tensor<size_t> is used in C++.
105105
SIZE_T = 19;
106+
UINT8 = 20;
106107

107108
// Other types that may need additional descriptions
108109
LOD_TENSOR = 7;

paddle/fluid/framework/lod_tensor_test.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,12 @@ TEST(LoD, CheckAbsLoD) {
228228
ASSERT_FALSE(CheckAbsLoD(abs_lod0));
229229
}
230230

231-
TEST(LoDTensor, RecordIO) {
231+
template <typename T>
232+
static void TestRecordIO() {
232233
LoDTensor tensor;
233-
int* tmp = tensor.mutable_data<int>(make_ddim({4, 5}), platform::CPUPlace());
234+
T* tmp = tensor.mutable_data<T>(make_ddim({4, 5}), platform::CPUPlace());
234235
for (int i = 0; i < 20; ++i) {
235-
tmp[i] = i;
236+
tmp[i] = static_cast<T>(i);
236237
}
237238

238239
std::stringstream* stream = new std::stringstream();
@@ -247,7 +248,7 @@ TEST(LoDTensor, RecordIO) {
247248

248249
auto assert_tensor_ok = [](const LoDTensor& tensor) {
249250
for (int i = 0; i < 20; ++i) {
250-
ASSERT_EQ(tensor.data<int>()[i], i);
251+
ASSERT_EQ(tensor.data<T>()[i], static_cast<T>(i));
251252
}
252253
};
253254

@@ -265,5 +266,13 @@ TEST(LoDTensor, RecordIO) {
265266
}
266267
}
267268

269+
TEST(LoDTensor, RecordIO) {
270+
TestRecordIO<int>();
271+
TestRecordIO<int16_t>();
272+
TestRecordIO<uint8_t>();
273+
TestRecordIO<float>();
274+
TestRecordIO<double>();
275+
}
276+
268277
} // namespace framework
269278
} // namespace paddle

paddle/fluid/operators/math/math_function.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ template struct SetConstant<platform::CPUDeviceContext, bool>;
3838
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
3939
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
4040
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
41-
template struct Transpose<platform::CPUDeviceContext, bool, RANK>;
41+
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
42+
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
43+
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;
4244

4345
DEFINE_CPU_TRANS(1);
4446
DEFINE_CPU_TRANS(2);

0 commit comments

Comments
 (0)