Skip to content

Commit fd1c961

Browse files
authored
[NPU] gather support bf16 (#1482)
1 parent bcd77bc commit fd1c961

File tree

2 files changed

+45
-4
lines changed

2 files changed

+45
-4
lines changed

backends/npu/kernels/gather_kernel.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ PD_REGISTER_PLUGIN_KERNEL(gather,
176176
double,
177177
int32_t,
178178
int64_t,
179-
phi::dtype::float16) {}
179+
phi::dtype::float16,
180+
phi::dtype::bfloat16) {}
180181

181182
PD_REGISTER_PLUGIN_KERNEL(gather_grad,
182183
npu,
@@ -187,4 +188,5 @@ PD_REGISTER_PLUGIN_KERNEL(gather_grad,
187188
double,
188189
int32_t,
189190
int64_t,
190-
phi::dtype::float16) {}
191+
phi::dtype::float16,
192+
phi::dtype::bfloat16) {}

backends/npu/tests/unittests/test_gather_op_npu.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import numpy as np
2020
import paddle
2121
import paddle.base as base
22-
from tests.op_test import OpTest
23-
from npu_utils import check_run_big_shape_test
22+
from tests.op_test import OpTest, convert_float_to_uint16
23+
from npu_utils import check_run_big_shape_test, check_soc_version
2424

2525
# Put Paddle into static-graph mode at import time, before any test class below runs.
paddle.enable_static()
2626
# Module-level seed constant; where it is consumed is not visible in this excerpt.
SEED = 2021
@@ -132,6 +132,45 @@ def config(self):
132132
self.index_type = "int32"
133133

134134

135+
# bfloat16 variant of the gather op test on the "npu" custom device.
# Overrides setUp/config of TestGatherOp so that X and the expected Out are
# passed through convert_float_to_uint16 before being handed to the op test.
class TestCase6(TestGatherOp):
136+
# Build the op inputs and the expected output for this case.
def setUp(self):
137+
self.set_npu()
138+
# Run on device 0 of the "npu" custom place.
self.place = paddle.CustomPlace("npu", 0)
139+
self.op_type = "gather"
140+
# config() supplies x_shape, index and index_type used below.
self.config()
141+
# Reference data is generated as float32; it is converted for the op input
# and the float32 array is kept for computing the expected result.
xnp = np.random.random(self.x_shape).astype(np.float32)
142+
self.inputs = {
143+
# NOTE(review): convert_float_to_uint16 presumably yields the bfloat16 bit
# pattern stored as uint16 -- confirm against tests.op_test.
"X": convert_float_to_uint16(xnp),
144+
"Index": np.array(self.index).astype(self.index_type),
145+
}
146+
# Expected output: index the float32 reference along its first axis with
# Index, then apply the same conversion as for X.
self.outputs = {"Out": convert_float_to_uint16(xnp[self.inputs["Index"]])}
147+
148+
# Mark the test class as targeting a custom device (sets a class-level flag).
def set_npu(self):
149+
self.__class__.use_custom_device = True
150+
151+
# NOTE(review): check_soc_version comes from npu_utils; presumably it gates the
# test on the NPU SoC version -- its exact behavior is not visible here.
@check_soc_version
152+
# Check the forward output on self.place.
def test_check_output(self):
153+
self.check_output_with_place(self.place)
154+
155+
@check_soc_version
156+
# Check the gradient of "Out" with respect to input "X" on self.place.
def test_check_grad(self):
157+
self.check_grad_with_place(
158+
self.place,
159+
["X"],
160+
"Out",
161+
# Relative-error tolerance passed to the gradient check for this case.
max_relative_error=0.006,
162+
)
163+
164+
# Test parameters: input shape, declared dtype name, gather indices, index dtype.
def config(self):
165+
"""
166+
For multi-dimension input
167+
"""
168+
self.x_shape = (10, 20)
169+
# x_type is recorded here but is not read by this class's setUp, which
# generates float32 data and converts it explicitly.
self.x_type = "bfloat16"
170+
self.index = [1, 3, 5]
171+
self.index_type = "int32"
172+
173+
135174
class API_TestGather(unittest.TestCase):
136175
def test_out1(self):
137176
with base.program_guard(base.Program(), base.Program()):

0 commit comments

Comments
 (0)