Skip to content

Commit 5828101

Browse files
authored
make uint8 support in data_type transform and memory optimize (#10715)
* "a piece of job." * "fix typeo" * "fix ci"
1 parent ebefdbe commit 5828101

File tree

4 files changed

+11
-1
lines changed

4 files changed

+11
-1
lines changed

paddle/fluid/framework/data_type_transform.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
9191
case proto::VarType::BOOL:
9292
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
9393
break;
94+
case proto::VarType::INT16:
95+
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
96+
break;
97+
case proto::VarType::UINT8:
98+
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
99+
break;
94100
default:
95101
PADDLE_THROW("Not support type %d", src_type);
96102
}

paddle/fluid/pybind/protobuf.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ void BindVarDsec(pybind11::module *m) {
238238

239239
pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "")
240240
.value("BOOL", pd::proto::VarType::BOOL)
241+
.value("UINT8", pd::proto::VarType::UINT8)
241242
.value("INT16", pd::proto::VarType::INT16)
242243
.value("INT32", pd::proto::VarType::INT32)
243244
.value("INT64", pd::proto::VarType::INT64)

python/paddle/fluid/framework.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def convert_np_dtype_to_dtype_(np_dtype):
7272
return core.VarDesc.VarType.INT64
7373
elif dtype == np.bool:
7474
return core.VarDesc.VarType.BOOL
75+
elif dtype == np.uint8:
76+
return core.VarDesc.VarType.UINT8
7577
else:
7678
raise ValueError("Not supported numpy dtype " + str(dtype))
7779

python/paddle/fluid/transpiler/memory_optimization_transpiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
core.VarDesc.VarType.INT16: 2,
2525
core.VarDesc.VarType.INT32: 4,
2626
core.VarDesc.VarType.INT64: 8,
27-
core.VarDesc.VarType.BOOL: 1
27+
core.VarDesc.VarType.BOOL: 1,
28+
core.VarDesc.VarType.UINT8: 1,
2829
}
2930

3031
SUB_BLOCK_OPS = [

0 commit comments

Comments
 (0)