Commit f3436af
[cherry-pick] Sum kernel for CPU supporting BF16 and SelectedRows (#32631) (#32755)
1 parent 2144852 commit f3436af

File tree: 5 files changed, +115 −26

  paddle/fluid/operators/math/blas_impl.h
  paddle/fluid/operators/math/selected_rows_functor.cc
  paddle/fluid/operators/sum_op.cc
  python/paddle/fluid/tests/unittests/test_sgd_op_bf16.py
  python/paddle/fluid/tests/unittests/test_sum_op.py

paddle/fluid/operators/math/blas_impl.h (+19, -0)

@@ -15,6 +15,7 @@
 #ifdef PADDLE_WITH_MKLML
 #include <mkl.h>
 #endif
+
 #include <algorithm>
 #include <cmath>
 #include <limits>
@@ -28,6 +29,19 @@
 namespace paddle {
 namespace operators {
 namespace math {
+namespace detail {
+
+template <typename T>
+static void axpy(int n, const T alpha, const T *x, const int incx, T *y,
+                 const int incy) {
+  // Y = Y + alpha * X
+  while (n-- > 0) {
+    *y += alpha * *x;
+    y = y + incy;
+    x = x + incx;
+  }
+}
+}  // namespace detail
 
 template <typename T>
 struct CBlas;
@@ -43,6 +57,11 @@ struct CBlas<int8_t> {
 
 template <>
 struct CBlas<platform::bfloat16> {
+  template <typename... ARGS>
+  static void AXPY(ARGS... args) {
+    detail::axpy(args...);
+  }
+
   template <typename... ARGS>
   static void VCOPY(ARGS... args) {
     PADDLE_THROW(platform::errors::Unimplemented(
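
MKL has no bfloat16 AXPY, so the patch adds a naive strided loop in detail::axpy and routes CBlas<platform::bfloat16>::AXPY through it, rather than throwing Unimplemented like the other CBlas<bfloat16> methods (e.g. VCOPY above). A minimal standalone sketch of the same routine, with float standing in for the Paddle-specific platform::bfloat16:

#include <cassert>
#include <iostream>

// Y = Y + alpha * X over n elements with strides incx/incy,
// mirroring detail::axpy from the patch.
template <typename T>
static void axpy(int n, const T alpha, const T *x, const int incx, T *y,
                 const int incy) {
  while (n-- > 0) {
    *y += alpha * *x;
    y = y + incy;
    x = x + incx;
  }
}

int main() {
  float x[3] = {1.f, 2.f, 3.f};
  float y[3] = {10.f, 20.f, 30.f};
  axpy(3, 2.f, x, 1, y, 1);  // y becomes {12, 24, 36}
  assert(y[0] == 12.f && y[1] == 24.f && y[2] == 36.f);
  std::cout << y[0] << " " << y[1] << " " << y[2] << "\n";
  return 0;
}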

paddle/fluid/operators/math/selected_rows_functor.cc (+20, -20)

@@ -285,6 +285,8 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, float>;
 template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>;
 template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>;
 template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
+template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
+                                        platform::bfloat16>;
 
 // This is a separated namespace for manipulate SelectedRows typed
 // data. Like merge duplicated rows, adding two SelectedRows etc.
@@ -294,21 +296,17 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
 // add or mul.
 namespace scatter {
 
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_floating_point<T>::value &&
-    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
-                   size_t data_len, const T* in, T* out) {
-  blas->AXPY(data_len, 1., in, out);
+template <typename T>
+typename std::enable_if<std::is_floating_point<T>::value>::type
+elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
+                   const T* in, T* out) {
+  blas->AXPY(data_len, T(1.f), in, out);
 }
 
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    !std::is_floating_point<T>::value &&
-    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
-                   size_t data_len, const T* in, T* out) {
+template <typename T>
+typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
+    BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len, const T* in,
+    T* out) {
   for (size_t i = 0; i < data_len; i++) {
     out[i] += in[i];
   }
@@ -412,7 +410,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
     out.set_rows(merge_rows);
 
     math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), static_cast<T>(0.f));
 
     std::unordered_map<int64_t, size_t> rows_to_id;
     for (size_t i = 0; i < merge_rows.size(); ++i) {
@@ -429,9 +427,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
 
       for (size_t i = 0; i < input_rows.size(); i++) {
         size_t out_i = rows_to_id[input_rows[i]];
-        elementwise_add_to<platform::CPUDeviceContext, T>(
-            context, &blas, static_cast<size_t>(input_width),
-            &input_data[i * input_width], &out_data[out_i * input_width]);
+        elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
+                              &input_data[i * input_width],
+                              &out_data[out_i * input_width]);
       }
     }
   }
@@ -524,9 +522,9 @@ struct MergeAverage<platform::CPUDeviceContext, T> {
 
       for (size_t i = 0; i < input_rows.size(); i++) {
         size_t out_i = rows_to_id[input_rows[i]];
-        elementwise_add_to<platform::CPUDeviceContext, T>(
-            context, &blas, static_cast<size_t>(input_width),
-            &input_data[i * input_width], &out_data[out_i * input_width]);
+        elementwise_add_to<T>(&blas, static_cast<size_t>(input_width),
+                              &input_data[i * input_width],
+                              &out_data[out_i * input_width]);
       }
     }
     size_t input_width_cast = static_cast<size_t>(input_width);
@@ -547,6 +545,8 @@ template struct MergeAdd<platform::CPUDeviceContext,
                          paddle::platform::complex64>;
 template struct MergeAdd<platform::CPUDeviceContext,
                          paddle::platform::complex128>;
+template struct MergeAdd<platform::CPUDeviceContext,
+                         paddle::platform::bfloat16>;
 
 template struct MergeAverage<platform::CPUDeviceContext, int>;
 template struct MergeAverage<platform::CPUDeviceContext, int64_t>;
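
The refactor drops the DeviceContext template parameter and the unused ctx argument from elementwise_add_to, and narrows the second overload from !is_floating_point to is_integral. Floating-point element types take the blas->AXPY path (for this to cover bfloat16, Paddle's bfloat16 type presumably satisfies std::is_floating_point via a trait specialization), while integral types use a scalar loop; since at most one condition holds, overload resolution is unambiguous. A self-contained sketch of the same enable_if overload selection, with simplified stand-in bodies:

#include <cstddef>
#include <iostream>
#include <type_traits>

// Floating-point overload; in the real code this forwards to blas->AXPY.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type
elementwise_add_to(size_t data_len, const T *in, T *out) {
  for (size_t i = 0; i < data_len; i++) out[i] += in[i];  // stand-in for AXPY
}

// Integral overload: the plain scalar loop, as in the patch.
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type
elementwise_add_to(size_t data_len, const T *in, T *out) {
  for (size_t i = 0; i < data_len; i++) out[i] += in[i];
}

int main() {
  double in_f[2] = {1.5, 2.5}, out_f[2] = {0.5, 0.5};
  int in_i[2] = {1, 2}, out_i[2] = {3, 4};
  elementwise_add_to(2, in_f, out_f);  // picks the floating-point overload
  elementwise_add_to(2, in_i, out_i);  // picks the integral overload
  std::cout << out_f[0] << " " << out_i[0] << "\n";  // prints: 2 4
  return 0;
}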

paddle/fluid/operators/sum_op.cc (+2, -0)

@@ -326,4 +326,6 @@ REGISTER_OP_CPU_KERNEL(
     sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SumKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SumKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SumKernel<paddle::platform::CPUDeviceContext,
+                   paddle::platform::bfloat16>,
     ops::SumKernel<paddle::platform::CPUDeviceContext, int64_t>);
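
REGISTER_OP_CPU_KERNEL registers one SumKernel instantiation per listed element type, and the kernel is later selected by the input's dtype; adding the bfloat16 entry is what exposes the CPU sum kernel to BF16 tensors. A loose sketch of that kind of dtype-keyed dispatch; the registry here is purely illustrative, not Paddle's actual registration machinery:

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>

// Hypothetical dtype-keyed kernel registry, standing in for the tables
// that the REGISTER_OP_CPU_KERNEL macro populates.
std::map<std::string, std::function<void()>> kernel_registry;

template <typename T>
void SumKernelFor() {
  std::cout << "sum kernel for element size " << sizeof(T) << "\n";
}

int main() {
  // One entry per instantiation, mirroring the REGISTER_OP_CPU_KERNEL list.
  kernel_registry["float"] = SumKernelFor<float>;
  kernel_registry["double"] = SumKernelFor<double>;
  kernel_registry["int"] = SumKernelFor<int>;
  kernel_registry["bfloat16"] = SumKernelFor<uint16_t>;  // bf16 is 16 bits wide

  kernel_registry.at("bfloat16")();  // selected by the input's dtype
  return 0;
}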

python/paddle/fluid/tests/unittests/test_sgd_op_bf16.py (+3, -6)

@@ -76,8 +76,7 @@ def create_sparse_grad_var(self, scope, place, height, rows, row_numel):
         grad_selected_rows = scope.var('Grad').get_selected_rows()
         grad_selected_rows.set_height(height)
         grad_selected_rows.set_rows(rows)
-        # grad_array = np.random.random((len(rows), row_numel)).astype('float32')
-        grad_array = np.full((len(rows), row_numel), 2, np.float32)
+        grad_array = np.random.random((len(rows), row_numel)).astype('float32')
         np_array_bf16 = convert_float_to_uint16(grad_array)
 
         grad_tensor = grad_selected_rows.get_tensor()
@@ -87,8 +86,7 @@ def create_sparse_grad_var(self, scope, place, height, rows, row_numel):
 
     def create_dense_param_var(self, scope, place, height, width):
         param_tensor = scope.var('Param').get_tensor()
-        # param_array = np.random.random((height, width)).astype('float32')
-        param_array = np.full((height, width), 5, np.float32)
+        param_array = np.random.random((height, width)).astype('float32')
         param_array_bf16 = convert_float_to_uint16(param_array)
         param_tensor.set(param_array_bf16, place)
 
@@ -109,8 +107,7 @@ def create_sparse_param_var(self, scope, place, height, rows, row_numel):
 
     def create_dense_lr_var(self, scope, place):
         lr_tensor = scope.var('LearningRate').get_tensor()
-        # lr_value = np.random.uniform()
-        lr_value = 2
+        lr_value = np.random.uniform()
         lr_array = np.full((1), lr_value, np.float32)
         lr_array_bf16 = convert_float_to_uint16(lr_array)
         lr_tensor.set(lr_array_bf16, place)
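
This change reverts hard-coded debug values (np.full(..., 2), lr_value = 2) back to random float32 inputs, which are then converted with convert_float_to_uint16. That helper stores each float32 as the upper 16 bits of its bit pattern, which is the bfloat16 encoding; a sketch of that conversion, under the assumption of plain truncation with no rounding:

#include <cstdint>
#include <cstring>
#include <iostream>

// bfloat16 keeps the sign bit, all 8 exponent bits, and the top 7 mantissa
// bits of a float32, i.e. the upper 16 bits of its bit pattern.
uint16_t float_to_bf16_truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

float bf16_to_float(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float x = 0.3f;
  float r = bf16_to_float(float_to_bf16_truncate(x));
  // With only 7 explicit mantissa bits the round trip loses precision:
  std::cout << x << " -> " << r << "\n";  // 0.3 -> 0.298828
  return 0;
}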

python/paddle/fluid/tests/unittests/test_sum_op.py (+71, -0)

@@ -18,9 +18,12 @@
 import numpy as np
 from op_test import OpTest
 import paddle
+from paddle import enable_static
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
+from paddle.fluid.tests.unittests.op_test import (
+    OpTest, convert_float_to_uint16, convert_uint16_to_float)
 
 
 class TestSumOp(OpTest):
@@ -141,6 +144,73 @@ def test_w_is_selected_rows(self):
         self.check_with_place(place, inplace)
 
 
+class TestSelectedRowsSumOpInt(TestSelectedRowsSumOp):
+    def init_kernel_type(self):
+        self.dtype = np.int32
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 'place does not support BF16 evaluation')
+class TestSelectedRowsSumBF16Op(TestSelectedRowsSumOp):
+    def setUp(self):
+        self.height = 10
+        self.row_numel = 12
+        self.rows = [0, 1, 2, 3, 4, 5, 6]
+        self.dtype = np.uint16
+        self.init_kernel_type()
+        np.random.seed(12345)
+        self.data = np.random.random((len(self.rows),
+                                      self.row_numel)).astype(np.float32)
+
+    def _get_array(self, rows, row_numel):
+        if len(rows) > 0:
+            return convert_float_to_uint16(self.data)
+        else:
+            return np.ndarray((0, row_numel), dtype=self.dtype)
+
+    def check_input_and_optput(self,
+                               scope,
+                               place,
+                               inplace,
+                               w1_has_data=False,
+                               w2_has_data=False,
+                               w3_has_data=False):
+
+        self.create_selected_rows(scope, place, "W1", w1_has_data)
+        self.create_selected_rows(scope, place, "W2", w2_has_data)
+        self.create_selected_rows(scope, place, "W3", w3_has_data)
+
+        # create Out Variable
+        if inplace:
+            out_var_name = "W1"
+        else:
+            out_var_name = "Out"
+        out = scope.var(out_var_name).get_selected_rows()
+
+        # create and run sum operator
+        sum_op = Operator("sum", X=["W1", "W2", "W3"], Out=out_var_name)
+        sum_op.run(scope, place)
+
+        has_data_w_num = 0
+        for has_data in [w1_has_data, w2_has_data, w3_has_data]:
+            if has_data:
+                has_data_w_num += 1
+
+        if has_data_w_num > 0:
+            self.assertEqual(len(out.rows()), 7)
+            out_bf16 = np.array(out.get_tensor())
+            out_fp32 = convert_uint16_to_float(out_bf16)
+            ref_fp32 = convert_uint16_to_float(
+                self._get_array(self.rows, self.row_numel)) * has_data_w_num
+            np.testing.assert_allclose(out_fp32, ref_fp32, atol=0, rtol=0.95e-2)
+        else:
+            self.assertEqual(len(out.rows()), 0)
+
+    def test_w_is_selected_rows(self):
+        for inplace in [True, False]:
+            self.check_with_place(core.CPUPlace(), inplace)
+
+
 class TestLoDTensorAndSelectedRowsOp(TestSelectedRowsSumOp):
     def setUp(self):
         self.height = 10
@@ -324,4 +394,5 @@ def test_list_of_none_input():
 create_test_sum_fp16_class(TestLoDTensorAndSelectedRowsOp)
 
 if __name__ == "__main__":
+    enable_static()
     unittest.main()
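
The BF16 test converts both the operator output and the reference back to float32 with convert_uint16_to_float and compares with rtol=0.95e-2. That tolerance follows from the format: bfloat16 keeps 7 explicit mantissa bits, so truncating a float32 loses at most about 2^-7 (roughly 0.78%) relative, safely inside the 0.95% bound. A quick numeric check of that bound, repeating the truncation helpers so the block stands alone:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

uint16_t float_to_bf16_truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

float bf16_to_float(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  // Empirically bound the relative error of bf16 truncation on (0, 1),
  // the range the test's np.random.random inputs are drawn from.
  float max_rel = 0.f;
  for (int i = 1; i < 1000000; ++i) {
    float x = i / 1000000.0f;
    float r = bf16_to_float(float_to_bf16_truncate(x));
    max_rel = std::max(max_rel, std::fabs(x - r) / x);
  }
  std::cout << "max relative error: " << max_rel << "\n";  // just under 2^-7
  std::cout << "test rtol:          " << 0.95e-2 << "\n";  // comfortably larger
  return 0;
}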
