Commit 5d83e87

[NPU] fix ut question: histogram. (#1358)
1 parent 0981d33 commit 5d83e87

File tree: 3 files changed, +120 -71 lines


backends/npu/kernels/histogram_kernel.cc

Lines changed: 23 additions & 21 deletions
@@ -14,34 +14,38 @@
 
 #include "kernels/funcs/npu_funcs.h"
 #include "kernels/funcs/npu_op_runner.h"
+
 namespace custom_kernel {
 template <typename T, typename Context>
 void HistogramKernel(const Context& dev_ctx,
                      const phi::DenseTensor& input,
+                     const paddle::optional<phi::DenseTensor>& weight,
                      int64_t bins,
                      int min,
                      int max,
+                     bool density,
                      phi::DenseTensor* output) {
-  int mbins = static_cast<int>(bins);
-  dev_ctx.template Alloc<int>(output);
-  phi::DenseTensor ranges;
-  phi::DenseTensor nbins;
-  T mmin = static_cast<T>(min);
-  T mmax = static_cast<T>(max);
-  std::vector<T> mrange{mmax, mmin};
-  TensorFromVector<T>(dev_ctx, mrange, dev_ctx, &ranges);
-  std::vector<T> ss = {bins};
-  TensorFromVector<T>(dev_ctx, ss, dev_ctx, &nbins);
-  auto output_dim = output->dims();
-  output->Resize({-1});
-  NpuOpRunner histogram_runner;
-  histogram_runner.SetType("HistogramFixedWidth")
+  PADDLE_ENFORCE_EQ(
+      weight || density,
+      false,
+      phi::errors::InvalidArgument("PaddlePaddle does not support parameters "
+                                   "weight and density on the NPU."));
+
+  dev_ctx.template Alloc<T>(output);
+  EXEC_NPU_CMD(aclnnInplaceZero, dev_ctx, *output);
+
+  int bins_trans = bins;
+  float min_trans = min;
+  float max_trans = max;
+
+  NpuOpRunner runner;
+  runner.SetType("Histogram")
       .AddInput(input)
-      .AddInput(ranges)
-      .AddInput(nbins)
      .AddOutput(*output)
-      .Run(dev_ctx.stream());
-  output->Resize(output_dim);
+      .AddAttr("bins", bins_trans)
+      .AddAttr("min", min_trans)
+      .AddAttr("max", max_trans);
+  runner.Run(dev_ctx.stream());
 }
 }; // namespace custom_kernel
 
@@ -51,6 +55,4 @@ PD_REGISTER_PLUGIN_KERNEL(histogram,
                           custom_kernel::HistogramKernel,
                           float,
                           int,
-                          int64_t) {
-  kernel->OutputAt(0).SetDataType(paddle::DataType::INT64);
-}
+                          int64_t) {}
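With this change the NPU kernel lowers paddle.histogram directly to the CANN "Histogram" operator and explicitly rejects the new weight and density arguments. A minimal usage sketch of that contract (assuming a PaddleCustomDevice build where this NPU plugin is installed and registered under the "npu" device string; the sample values are illustrative only):

import numpy as np
import paddle

paddle.set_device("npu")  # assumption: the custom NPU plugin is registered as "npu"

x = paddle.to_tensor(np.random.uniform(0.0, 20.0, (10, 12)).astype("float32"))

# Supported path: fixed-width histogram over [min, max], handled by the "Histogram" op.
out = paddle.histogram(x, bins=5, min=1, max=5)

# Unsupported path: weight/density trip the PADDLE_ENFORCE_EQ guard added above.
try:
    paddle.histogram(x, bins=5, min=1, max=5, density=True)
except Exception as err:
    print(err)  # "PaddlePaddle does not support parameters weight and density on the NPU."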

backends/npu/tests/unittests/test_histogram_op_npu.py

Lines changed: 97 additions & 49 deletions
@@ -15,78 +15,126 @@
 import unittest
 
 import numpy as np
-from tests.op_test import OpTest
 
 import paddle
 from paddle import base
+from paddle.pir_utils import test_with_pir_api
 
 
-class TestHistogramOpError(unittest.TestCase):
-    """Test histogram op error."""
+class TestHistogram(unittest.TestCase):
+    """Test histogram api."""
 
-    def run_network(self, net_func):
-        main_program = base.Program()
-        startup_program = base.Program()
-        with base.program_guard(main_program, startup_program):
-            net_func()
-        exe = base.Executor()
-        exe.run(main_program)
-
-    def test_bins_error(self):
-        """Test bins should be greater than or equal to 1."""
-
-        def net_func():
-            input_value = paddle.tensor.fill_constant(
-                shape=[3, 4], dtype="float32", value=3.0
-            )
-            paddle.histogram(input=input_value, bins=-1, min=1, max=5)
-
-        with self.assertRaises(ValueError):
-            self.run_network(net_func)
-
-    def test_min_max_error(self):
-        """Test max must be larger or equal to min."""
-
-        def net_func():
-            input_value = paddle.tensor.fill_constant(
-                shape=[3, 4], dtype="float32", value=3.0
-            )
-            paddle.histogram(input=input_value, bins=1, min=5, max=1)
-
-        with self.assertRaises(ValueError):
-            self.run_network(net_func)
-
-
-class TestHistogramOp(OpTest):
     def setUp(self):
-        self.op_type = "histogram"
         self.init_test_case()
-        np_input = np.random.uniform(low=0.0, high=20.0, size=self.in_shape)
-        self.python_api = paddle.histogram
-        self.inputs = {"X": np_input}
-        self.init_attrs()
-        Out, _ = np.histogram(np_input, bins=self.bins, range=(self.min, self.max))
-        self.outputs = {"Out": Out.astype(np.int64)}
+        self.input_np = np.random.uniform(
+            low=0.0, high=20.0, size=self.in_shape
+        ).astype(np.float32)
+        self.weight_np = np.random.uniform(
+            low=0.0, high=1.0, size=self.in_shape
+        ).astype(np.float32)
 
     def init_test_case(self):
         self.in_shape = (10, 12)
         self.bins = 5
         self.min = 1
         self.max = 5
+        self.density = False
+        self.is_weight = False
+
+    @test_with_pir_api
+    def test_static_graph(self):
+        startup_program = paddle.static.Program()
+        train_program = paddle.static.Program()
+        with paddle.static.program_guard(train_program, startup_program):
+            inputs = paddle.static.data(
+                name="input", dtype="float32", shape=self.in_shape
+            )
+            if self.is_weight:
+                weight = paddle.static.data(
+                    name="weight", dtype="float32", shape=self.in_shape
+                )
+                output = paddle.histogram(
+                    inputs,
+                    bins=self.bins,
+                    min=self.min,
+                    max=self.max,
+                    weight=weight,
+                    density=self.density,
+                )
+            else:
+                output = paddle.histogram(
+                    inputs,
+                    bins=self.bins,
+                    min=self.min,
+                    max=self.max,
+                    density=self.density,
+                )
+            place = base.CPUPlace()
+            if base.core.is_compiled_with_cuda():
+                place = base.CUDAPlace(0)
+            exe = base.Executor(place)
+            if self.is_weight:
+                res = exe.run(
+                    feed={
+                        "input": self.input_np,
+                        "weight": self.weight_np,
+                    },
+                    fetch_list=[output],
+                )
+            else:
+                res = exe.run(feed={"input": self.input_np}, fetch_list=[output])
+
+            actual = np.array(res[0])
+            Out, _ = np.histogram(
+                self.input_np,
+                bins=self.bins,
+                range=(self.min, self.max),
+                density=self.density,
+                weights=self.weight_np if self.is_weight else None,
+            )
+            np.testing.assert_allclose(actual, Out, rtol=1e-58, atol=1e-5)
+
+    def test_dygraph(self):
+        with base.dygraph.guard():
+            inputs_np = np.random.uniform(
+                low=0.0, high=20.0, size=self.in_shape
+            ).astype(np.float32)
+
+            self.inputs = paddle.to_tensor(inputs_np)
 
-    def init_attrs(self):
-        self.attrs = {"bins": self.bins, "min": self.min, "max": self.max}
+            weight_np = np.random.uniform(low=0.0, high=1.0, size=self.in_shape).astype(
+                np.float32
+            )
+            weight = paddle.to_tensor(weight_np)
+
+            actual = paddle.histogram(
+                self.inputs,
+                bins=5,
+                min=1,
+                max=5,
+                weight=weight if self.is_weight else None,
+                density=self.density,
+            )
+
+            Out, _ = np.histogram(
+                inputs_np,
+                bins=5,
+                range=(1, 5),
+                weights=weight_np if self.is_weight else None,
+                density=self.density,
+            )
 
-    def test_check_output(self):
-        self.check_output()
+            np.testing.assert_allclose(actual.numpy(), Out, rtol=1e-58, atol=1e-5)
 
 
-class TestHistogramOp_ZeroDim(TestHistogramOp):
+class TestHistogramOp_ZeroDim(TestHistogram):
     def init_test_case(self):
         self.in_shape = []
         self.bins = 5
         self.min = 1
         self.max = 5
+        self.density = False
+        self.is_weight = False
 
 
 if __name__ == "__main__":
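The new tests validate the NPU result against np.histogram with an explicit range. As a tiny self-contained illustration of those reference semantics (numpy only; the sample values below are made up): values outside [min, max] are ignored and the right-most bin is closed.

import numpy as np

data = np.array([0.5, 1.0, 2.3, 4.9, 5.0, 7.2], dtype=np.float32)

# Same reference call the unit test uses, just with hand-picked inputs.
counts, edges = np.histogram(data, bins=5, range=(1, 5))

print(counts)  # [1 1 0 0 2] -> 0.5 and 7.2 fall outside [1, 5] and are dropped
print(edges)   # [1.  1.8 2.6 3.4 4.2 5. ]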

backends/npu/tools/disable_ut_npu_910b

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@ test_conv3d_op_npu
 test_contiguous_op_npu
 test_einsum_op_npu
 test_fused_matmul_bias_op_npu
-test_histogram_op_npu
 test_kldiv_loss_op_npu
 test_zero_dim_tensor_npu
 test_matmulv2_op_npu
