
Commit e26f112

Add fp16 mul op support and bind paddle fp16 to numpy fp16 (#9017)
* add fp16 mul op support
* small fix
* fix bug
* small fix
* fix PADDLE_WITH_CUDA compiling issue
* reorg code
* test for pybind
* treat float16 as uint16_t in pybind
* bind np.float16 to paddle float16
* small fix
* clean code
* remove redundancy
* fix mul_op test
* address comments
* small fix
* add is_float16_supported func
1 parent 7140071 commit e26f112

File tree: 6 files changed, +117 −19 lines


paddle/fluid/operators/mul_op.cc

Lines changed: 7 additions & 7 deletions
@@ -17,11 +17,14 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using framework::OpKernelType;
 using framework::Tensor;
 
-class MulOpShapeInference : public framework::InferShapeBase {
+class MulOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext* ctx) const override {
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -122,7 +125,7 @@ or not. But the output only shares the LoD information with input $X$.
   }
 };
 
-class MulOpGrad : public framework::OperatorWithKernel {
+class MulGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
@@ -156,10 +159,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(mul, paddle::framework::OperatorWithKernel, ops::MulOpMaker,
-                  ops::MulOpShapeInference,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(mul_grad, ops::MulOpGrad);
+REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulGradOp);
 REGISTER_OP_CPU_KERNEL(
     mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(

paddle/fluid/operators/mul_op.cu.cc

Lines changed: 6 additions & 4 deletions
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/mul_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    mul, ops::MulKernel<paddle::platform::CUDADeviceContext, float>);
-REGISTER_OP_CUDA_KERNEL(
-    mul_grad, ops::MulGradKernel<paddle::platform::CUDADeviceContext, float>);
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
+                        ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(mul_grad,
+                        ops::MulGradKernel<plat::CUDADeviceContext, float>);

paddle/fluid/operators/mul_op.h

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ class MulKernel : public framework::OpKernel<T> {
     }
     math::matmul<DeviceContext, T>(
         context.template device_context<DeviceContext>(), x_matrix, false,
-        y_matrix, false, 1, z, 0);
+        y_matrix, false, static_cast<T>(1), z, static_cast<T>(0));
     if (z_dim.size() != 2) {
       z->Resize(z_dim);
     }

paddle/fluid/pybind/pybind.cc

Lines changed: 9 additions & 1 deletion
@@ -31,6 +31,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/cond_op.h"
 #include "paddle/fluid/operators/net_op.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
 #include "paddle/fluid/pybind/const_value.h"
@@ -103,12 +104,14 @@ PYBIND11_PLUGIN(core) {
       .def("set", PyCPUTensorSetFromArray<double>)
       .def("set", PyCPUTensorSetFromArray<int64_t>)
       .def("set", PyCPUTensorSetFromArray<bool>)
+      .def("set", PyCPUTensorSetFromArray<uint16_t>)
 #ifdef PADDLE_WITH_CUDA
       .def("set", PyCUDATensorSetFromArray<float>)
       .def("set", PyCUDATensorSetFromArray<int>)
       .def("set", PyCUDATensorSetFromArray<double>)
       .def("set", PyCUDATensorSetFromArray<int64_t>)
       .def("set", PyCUDATensorSetFromArray<bool>)
+      .def("set", PyCUDATensorSetFromArray<uint16_t>)
 #endif
       .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
       .def("set_float_element", TensorSetElement<float>)
@@ -315,7 +318,6 @@ All parameter, weight, gradient are variables in Paddle.
 #endif
   });
   // clang-format on
-
 #ifdef PADDLE_WITH_CUDA
   py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
 #endif
@@ -423,6 +425,12 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("init_devices", &framework::InitDevices);
 
   m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
+#ifdef PADDLE_WITH_CUDA
+  m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
+    // Only GPUs with Compute Capability >= 53 support float16
+    return platform::GetCUDAComputeCapability(place.device) >= 53;
+  });
+#endif
 
   m.def("set_feed_variable", framework::SetFeedVariable);
   m.def("get_fetch_variable", framework::GetFetchVariable);

paddle/fluid/pybind/tensor_py.h

Lines changed: 56 additions & 6 deletions
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/float16.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 
@@ -77,21 +78,32 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
       } else if (paddle::platform::is_cpu_place(tensor.place())) {
         dst_tensor = tensor;
       }
-      return py::buffer_info(dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
-                             py::format_descriptor<CUR_TYPE>::format(),
-                             (size_t)framework::arity(dst_tensor.dims()),
-                             dims_outside, strides);
+
+      if (std::type_index(typeid(CUR_TYPE)) ==
+          std::type_index(typeid(platform::float16))) {
+        return py::buffer_info(dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
+                               "e", /* np.dtype('e') == np.float16 */
+                               (size_t)framework::arity(dst_tensor.dims()),
+                               dims_outside, strides);
+      } else {
+        return py::buffer_info(dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
+                               py::format_descriptor<CUR_TYPE>::format(),
+                               (size_t)framework::arity(dst_tensor.dims()),
+                               dims_outside, strides);
+      }
     } else {
       constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
       return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
     }
   }
 };
+
 }  // namespace details
+
 inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
   auto buffer_info =
-      details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool>()(
-          tensor);
+      details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool,
+                                  platform::float16>()(tensor);
   return buffer_info;
 }
 
@@ -136,6 +148,22 @@ void PyCPUTensorSetFromArray(
   std::memcpy(dst, array.data(), sizeof(T) * array.size());
 }
 
+template <>
+void PyCPUTensorSetFromArray(
+    framework::Tensor &self,
+    py::array_t<uint16_t, py::array::c_style | py::array::forcecast> array,
+    paddle::platform::CPUPlace &place) {
+  std::vector<int64_t> dims;
+  dims.reserve(array.ndim());
+  for (size_t i = 0; i < array.ndim(); ++i) {
+    dims.push_back((int)array.shape()[i]);
+  }
+
+  self.Resize(framework::make_ddim(dims));
+  auto *dst = self.mutable_data<platform::float16>(place);
+  std::memcpy(dst, array.data(), sizeof(uint16_t) * array.size());
+}
+
 #ifdef PADDLE_WITH_CUDA
 template <typename T>
 void PyCUDATensorSetFromArray(
@@ -157,6 +185,28 @@ void PyCUDATensorSetFromArray(
   paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
                                    cudaMemcpyHostToDevice, dev_ctx->stream());
 }
+
+template <>
+void PyCUDATensorSetFromArray(
+    framework::Tensor &self,
+    py::array_t<uint16_t, py::array::c_style | py::array::forcecast> array,
+    paddle::platform::CUDAPlace &place) {
+  std::vector<int64_t> dims;
+  dims.reserve(array.ndim());
+  for (size_t i = 0; i < array.ndim(); ++i) {
+    dims.push_back((int)array.shape()[i]);
+  }
+
+  self.Resize(framework::make_ddim(dims));
+  auto *dst = self.mutable_data<platform::float16>(place);
+
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto dev_ctx =
+      static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
+  paddle::platform::GpuMemcpyAsync(dst, array.data(),
+                                   sizeof(uint16_t) * array.size(),
+                                   cudaMemcpyHostToDevice, dev_ctx->stream());
+}
 #endif
 
 }  // namespace pybind
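
Background note (not part of the diff): the uint16_t specializations above rest on NumPy's float16 (buffer format character 'e') sharing its 16-bit layout with uint16, so a float16 array viewed as uint16 can be memcpy'd straight into a platform::float16 buffer. A small NumPy sketch checking that assumption:

import numpy as np

assert np.dtype('e') == np.float16                  # 'e' is NumPy's float16 format char
x = np.random.random((2, 3)).astype("float16")
bits = x.view(np.uint16)                            # same memory, reinterpreted as uint16
assert bits.nbytes == x.nbytes
assert np.array_equal(bits.view(np.float16), x)     # bit-exact round trip

This is why the new tests below feed x.view(np.uint16) as the op inputs.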

python/paddle/fluid/tests/unittests/test_mul_op.py

Lines changed: 38 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import unittest
 import numpy as np
+import paddle.fluid.core as core
 from op_test import OpTest
 
 
@@ -69,5 +70,42 @@ def test_check_grad_ignore_y(self):
             ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
 
 
+class TestFP16MulOp1(OpTest):
+    def setUp(self):
+        self.op_type = "mul"
+        x = np.random.random((32, 84)).astype("float16")
+        y = np.random.random((84, 100)).astype("float16")
+        self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)}
+        self.outputs = {'Out': np.dot(x, y)}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=1e-1)
+
+
+class TestFP16MulOp2(OpTest):
+    def setUp(self):
+        self.op_type = "mul"
+        x = np.random.random((15, 4, 12, 10)).astype("float16")
+        y = np.random.random((4, 30, 8, 2, 9)).astype("float16")
+        self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)}
+        self.attrs = {
+            'x_num_col_dims': 2,
+            'y_num_col_dims': 2,
+        }
+        result = np.dot(
+            x.reshape(15 * 4, 12 * 10), y.reshape(4 * 30, 8 * 2 * 9))
+        result = result.reshape(15, 4, 8, 2, 9)
+        self.outputs = {'Out': result}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-1)
+
+
 if __name__ == "__main__":
     unittest.main()
