Skip to content

Commit feaccbb

Browse files
c_alias_bucketize (#76152)
1 parent 5e72177 commit feaccbb

File tree

3 files changed

+291
-2
lines changed

3 files changed

+291
-2
lines changed

python/paddle/tensor/search.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,23 +1168,29 @@ def topk(
11681168
return values, indices
11691169

11701170

1171+
@param_two_alias(["x", "input"], ["sorted_sequence", "boundaries"])
11711172
def bucketize(
11721173
x: Tensor,
11731174
sorted_sequence: Tensor,
11741175
out_int32: bool = False,
11751176
right: bool = False,
11761177
name: str | None = None,
1178+
*,
1179+
out: Tensor | None = None,
11771180
) -> Tensor:
11781181
"""
11791182
This API is used to find the index of the corresponding 1D tensor `sorted_sequence` in the innermost dimension based on the given `x`.
11801183
11811184
Args:
11821185
x (Tensor): An input N-D tensor value with type int32, int64, float32, float64.
1186+
alias: ``input``.
11831187
sorted_sequence (Tensor): An input 1-D tensor with type int32, int64, float32, float64. The value of the tensor monotonically increases in the innermost dimension.
1188+
alias: ``boundaries``.
11841189
out_int32 (bool, optional): Data type of the output tensor which can be int32, int64. The default value is False, and it indicates that the output data type is int64.
11851190
right (bool, optional): Find the upper or lower bounds of the sorted_sequence range in the innermost dimension based on the given `x`. If the value of the sorted_sequence is nan or inf, return the size of the innermost dimension.
11861191
The default value is False and it shows the lower bounds.
11871192
name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
1193+
out (Tensor|None, optional): The output tensor. Default: None.
11881194
11891195
Returns:
11901196
Tensor (the same sizes of the `x`), return the tensor of int32 if set :attr:`out_int32` is True, otherwise return the tensor of int64.
@@ -1229,26 +1235,35 @@ def bucketize(
12291235
raise ValueError(
12301236
f"sorted_sequence tensor must be 1 dimension, but got dim {sorted_sequence.dim()}"
12311237
)
1232-
return searchsorted(sorted_sequence, x, out_int32, right, name)
1238+
return searchsorted(sorted_sequence, x, out_int32, right, name, out=out)
12331239

12341240

1241+
@param_one_alias(["values", "input"])
12351242
def searchsorted(
12361243
sorted_sequence: Tensor,
12371244
values: Tensor,
12381245
out_int32: bool = False,
12391246
right: bool = False,
12401247
name: str | None = None,
1248+
*,
1249+
side: str | None = None,
1250+
out: Tensor | None = None,
1251+
sorter: Tensor | None = None,
12411252
) -> Tensor:
12421253
"""
12431254
Find the index of the corresponding `sorted_sequence` in the innermost dimension based on the given `values`.
12441255
12451256
Args:
12461257
sorted_sequence (Tensor): An input N-D or 1-D tensor with type int32, int64, float16, float32, float64, bfloat16. The value of the tensor monotonically increases in the innermost dimension.
12471258
values (Tensor): An input N-D tensor value with type int32, int64, float16, float32, float64, bfloat16.
1259+
alias: ``input``.
12481260
out_int32 (bool, optional): Data type of the output tensor which can be int32, int64. The default value is False, and it indicates that the output data type is int64.
12491261
right (bool, optional): Find the upper or lower bounds of the sorted_sequence range in the innermost dimension based on the given `values`. If the value of the sorted_sequence is nan or inf, return the size of the innermost dimension.
12501262
The default value is False and it shows the lower bounds.
12511263
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
1264+
side (str|None, optional): The same as right but preferred. `left` corresponds to False for right and `right` corresponds to True for right. It will error if this is set to `left` while right is True. Default value is None.
1265+
sorter (Tensor|None, optional): if provided, a tensor matching the shape of the unsorted `sorted_sequence` containing a sequence of indices that sort it in the ascending order on the innermost dimension
1266+
out (Tensor|None, optional): The output tensor. Default: None.
12521267
12531268
Returns:
12541269
Tensor (the same sizes of the `values`), return the tensor of int32 if set :attr:`out_int32` is True, otherwise return the tensor of int64.
@@ -1280,8 +1295,18 @@ def searchsorted(
12801295
[1, 3, 4, 5]])
12811296
12821297
"""
1298+
# If side is present, override the value of right if needed.
1299+
if side is not None and side == "right":
1300+
right = True
1301+
12831302
if in_dynamic_or_pir_mode():
1284-
return _C_ops.searchsorted(sorted_sequence, values, out_int32, right)
1303+
if sorter is not None:
1304+
sorted_sequence = sorted_sequence.take_along_axis(
1305+
axis=-1, indices=sorter
1306+
)
1307+
return _C_ops.searchsorted(
1308+
sorted_sequence, values, out_int32, right, out=out
1309+
)
12851310
else:
12861311
check_variable_and_dtype(
12871312
sorted_sequence,

test/legacy_test/test_bucketize_api.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,5 +110,149 @@ def test_empty_input_error(self):
110110
self.assertRaises(AttributeError, paddle.bucketize, x, None)
111111

112112

113+
class TestBucketizeAPI_Extended(unittest.TestCase):
114+
def setUp(self):
115+
self.sorted_sequence = np.array([2, 4, 8, 16]).astype("float64")
116+
self.x2d = np.array([[0, 8, 4, 16], [-1, 2, 8, 4]]).astype("float64")
117+
self.x1d = np.array([0, 8, 4, 16]).astype("float64")
118+
self.sorted_dup = np.array([1, 2, 2, 2, 3]).astype("float64")
119+
self.x_dup = np.array([2, 2, 1, 3]).astype("float64")
120+
self.place = get_places()
121+
122+
def test_dygraph_out_and_out_int32_and_name(self):
123+
# Dynamic diagram: Testing the out parameter (inplace write) and out_int32
124+
paddle.disable_static()
125+
for place in self.place:
126+
with paddle.base.dygraph.guard():
127+
seq = paddle.to_tensor(self.sorted_sequence)
128+
x = paddle.to_tensor(self.x2d)
129+
130+
res32 = paddle.bucketize(
131+
x, seq, out_int32=True, name="test_name"
132+
)
133+
self.assertEqual(res32.dtype, paddle.int32)
134+
ref32 = np.searchsorted(self.sorted_sequence, self.x2d)
135+
np.testing.assert_allclose(
136+
ref32, res32.numpy().astype("int64"), rtol=1e-05
137+
)
138+
139+
# out parameter: supply existing tensor, should be written and returned
140+
out_tensor = paddle.empty(shape=self.x2d.shape, dtype="int64")
141+
ret = paddle.bucketize(x, seq, out=out_tensor)
142+
ref = np.searchsorted(self.sorted_sequence, self.x2d)
143+
np.testing.assert_allclose(ref, out_tensor.numpy(), rtol=1e-05)
144+
paddle.enable_static()
145+
146+
def test_static_out_int32_and_right(self):
147+
# Static image: Testing out_int32 and right=True/False
148+
paddle.enable_static()
149+
for place in self.place:
150+
with paddle.static.program_guard(paddle.static.Program()):
151+
seq = paddle.static.data(
152+
name="seq",
153+
shape=self.sorted_sequence.shape,
154+
dtype="float64",
155+
)
156+
x = paddle.static.data(
157+
name="x", shape=self.x2d.shape, dtype="float64"
158+
)
159+
160+
out_left = paddle.bucketize(
161+
x, seq, right=False, out_int32=False
162+
)
163+
out_right = paddle.bucketize(x, seq, right=True, out_int32=True)
164+
165+
exe = paddle.static.Executor(place)
166+
res_left, res_right = exe.run(
167+
feed={"seq": self.sorted_sequence, "x": self.x2d},
168+
fetch_list=[out_left, out_right],
169+
)
170+
ref_left = np.searchsorted(
171+
self.sorted_sequence, self.x2d, side="left"
172+
)
173+
ref_right = np.searchsorted(
174+
self.sorted_sequence, self.x2d, side="right"
175+
)
176+
np.testing.assert_allclose(ref_left, res_left, rtol=1e-05)
177+
# out_int32 True -> numpy result must be cast-compatible to int32
178+
self.assertEqual(res_right.dtype, np.int32)
179+
np.testing.assert_allclose(
180+
ref_right, res_right.astype("int64"), rtol=1e-05
181+
)
182+
paddle.disable_static()
183+
184+
def test_dygraph_1d_input(self):
185+
# Dynamic image: 1D x test
186+
paddle.disable_static()
187+
for place in self.place:
188+
with paddle.base.dygraph.guard():
189+
seq = paddle.to_tensor(self.sorted_sequence)
190+
x = paddle.to_tensor(self.x1d)
191+
192+
out = paddle.bucketize(x, seq)
193+
ref = np.searchsorted(self.sorted_sequence, self.x1d)
194+
np.testing.assert_allclose(ref, out.numpy(), rtol=1e-05)
195+
paddle.enable_static()
196+
197+
def test_dup_elements_side_behavior(self):
198+
# Left/right difference when testing duplicate elements
199+
paddle.disable_static()
200+
for place in self.place:
201+
with paddle.base.dygraph.guard():
202+
seq = paddle.to_tensor(self.sorted_dup)
203+
x = paddle.to_tensor(self.x_dup)
204+
205+
out_left = paddle.bucketize(x, seq, right=False)
206+
out_right = paddle.bucketize(x, seq, right=True)
207+
208+
ref_left = np.searchsorted(
209+
self.sorted_dup, self.x_dup, side="left"
210+
)
211+
ref_right = np.searchsorted(
212+
self.sorted_dup, self.x_dup, side="right"
213+
)
214+
215+
np.testing.assert_allclose(
216+
ref_left, out_left.numpy(), rtol=1e-05
217+
)
218+
np.testing.assert_allclose(
219+
ref_right, out_right.numpy(), rtol=1e-05
220+
)
221+
paddle.enable_static()
222+
223+
def test_static_and_dygraph_sort_of_api_stability(self):
224+
# Simple coverage: Both static and dynamic calls can succeed (without checking for duplicate results)
225+
paddle.enable_static()
226+
for place in self.place:
227+
with paddle.static.program_guard(paddle.static.Program()):
228+
seq = paddle.static.data(
229+
name="seq",
230+
shape=self.sorted_sequence.shape,
231+
dtype="float64",
232+
)
233+
x = paddle.static.data(
234+
name="x", shape=self.x2d.shape, dtype="float64"
235+
)
236+
_ = paddle.bucketize(
237+
x, seq, out_int32=False, right=False, name="static_case"
238+
)
239+
exe = paddle.static.Executor(place)
240+
exe.run(
241+
feed={"seq": self.sorted_sequence, "x": self.x2d},
242+
fetch_list=[],
243+
)
244+
paddle.disable_static()
245+
246+
paddle.disable_static()
247+
for place in self.place:
248+
with paddle.base.dygraph.guard():
249+
seq = paddle.to_tensor(self.sorted_sequence)
250+
x = paddle.to_tensor(self.x2d)
251+
_ = paddle.bucketize(
252+
x, seq, out_int32=False, right=False, name="dy_case"
253+
)
254+
paddle.enable_static()
255+
256+
113257
if __name__ == "__main__":
114258
unittest.main()

test/legacy_test/test_searchsorted_op.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,8 @@ def test_searchsorted_sortedsequence_size_error():
304304
)
305305

306306
def test_check_type_error(self):
307+
paddle.enable_static()
308+
307309
def test_sortedsequence_values_type_error():
308310
with paddle.static.program_guard(paddle.static.Program()):
309311
sorted_sequence = paddle.static.data(
@@ -317,5 +319,123 @@ def test_sortedsequence_values_type_error():
317319
self.assertRaises(TypeError, test_sortedsequence_values_type_error)
318320

319321

322+
class TestSearchSortedAPI_Extended(unittest.TestCase):
323+
def init_test_case(self):
324+
self.sorted_sequence = np.array([2, 4, 6, 8, 10]).astype("float64")
325+
self.values_2d = np.array([[3, 6, 9], [3, 6, 9]]).astype("float64")
326+
self.values_1d = np.array([3, 6, 9]).astype("float64")
327+
self.unsorted_seq = np.array([6, 2, 10, 4, 8]).astype("float64")
328+
# sorter such that unsorted_seq[sorter] is sorted
329+
self.sorter = np.argsort(self.unsorted_seq).astype("int64")
330+
331+
def setUp(self):
332+
self.init_test_case()
333+
self.place = get_places()
334+
335+
def test_dygraph_side_and_right_priority_and_out_int32(self):
336+
# Test: side takes precedence over right, out_int32 controls the returned dtype
337+
paddle.disable_static()
338+
for place in self.place:
339+
with paddle.base.dygraph.guard():
340+
seq = paddle.to_tensor(self.sorted_sequence)
341+
vals = paddle.to_tensor(self.values_2d)
342+
# Mixed parameter passing: right=False, side='right' -> side takes precedence, should be interpreted as right=True
343+
out = paddle.searchsorted(
344+
seq, vals, right=False, side="right", out_int32=True
345+
)
346+
ref = np.searchsorted(
347+
self.sorted_sequence, self.values_2d, side="right"
348+
)
349+
self.assertEqual(out.dtype, paddle.int32)
350+
np.testing.assert_allclose(
351+
ref, out.numpy().astype("int64"), rtol=1e-05
352+
)
353+
354+
def test_dygraph_out_parameter_and_return_is_out(self):
355+
# Test out parameter: Pass in an existing tensor, write the function on it, and return the same Tensor
356+
paddle.disable_static()
357+
for place in self.place:
358+
with paddle.base.dygraph.guard():
359+
seq = paddle.to_tensor(self.sorted_sequence)
360+
vals = paddle.to_tensor(self.values_2d)
361+
out_tensor = paddle.empty(
362+
shape=self.values_2d.shape, dtype="int64"
363+
)
364+
ret = paddle.searchsorted(seq, vals, out=out_tensor)
365+
ref = np.searchsorted(self.sorted_sequence, self.values_2d)
366+
np.testing.assert_allclose(ref, out_tensor.numpy(), rtol=1e-05)
367+
368+
def test_dygraph_sorter_behavior(self):
369+
# Test sorter parameter: When the sequence is unsorted but a sorter is given, the behavior is consistent with numpy
370+
paddle.disable_static()
371+
for place in self.place:
372+
with paddle.base.dygraph.guard():
373+
seq = paddle.to_tensor(self.unsorted_seq)
374+
vals = paddle.to_tensor(self.values_1d)
375+
sorter_t = paddle.to_tensor(self.sorter)
376+
out = paddle.searchsorted(seq, vals, sorter=sorter_t)
377+
ref = np.searchsorted(
378+
self.unsorted_seq, self.values_1d, sorter=self.sorter
379+
)
380+
np.testing.assert_allclose(ref, out.numpy(), rtol=1e-05)
381+
382+
def test_static_side_and_sorter(self):
383+
# Test side parameters and sorter parameters under static images (aligned with numpy)
384+
paddle.enable_static()
385+
for place in self.place:
386+
with paddle.static.program_guard(paddle.static.Program()):
387+
seq = paddle.static.data(
388+
name="seq", shape=self.unsorted_seq.shape, dtype="float64"
389+
)
390+
vals = paddle.static.data(
391+
name="vals", shape=self.values_1d.shape, dtype="float64"
392+
)
393+
sorter = paddle.static.data(
394+
name="sorter", shape=self.sorter.shape, dtype="int64"
395+
)
396+
397+
out_left = paddle.searchsorted(
398+
seq, vals, side="left", sorter=sorter
399+
)
400+
out_right = paddle.searchsorted(
401+
seq, vals, side="right", sorter=sorter
402+
)
403+
404+
exe = paddle.static.Executor(place)
405+
(res_left, res_right) = exe.run(
406+
feed={
407+
"seq": self.unsorted_seq,
408+
"vals": self.values_1d,
409+
"sorter": self.sorter,
410+
},
411+
fetch_list=[out_left, out_right],
412+
)
413+
ref_left = np.searchsorted(
414+
self.unsorted_seq,
415+
self.values_1d,
416+
side="left",
417+
sorter=self.sorter,
418+
)
419+
ref_right = np.searchsorted(
420+
self.unsorted_seq,
421+
self.values_1d,
422+
side="right",
423+
sorter=self.sorter,
424+
)
425+
np.testing.assert_allclose(ref_left, res_left, rtol=1e-05)
426+
np.testing.assert_allclose(ref_right, res_right, rtol=1e-05)
427+
paddle.disable_static()
428+
429+
def test_dygraph_1d_values_and_name_param(self):
430+
paddle.disable_static()
431+
for place in self.place:
432+
with paddle.base.dygraph.guard():
433+
seq = paddle.to_tensor(self.sorted_sequence)
434+
vals = paddle.to_tensor(self.values_1d)
435+
out = paddle.searchsorted(seq, vals, name="my_search")
436+
ref = np.searchsorted(self.sorted_sequence, self.values_1d)
437+
np.testing.assert_allclose(ref, out.numpy(), rtol=1e-05)
438+
439+
320440
if __name__ == '__main__':
321441
unittest.main()

0 commit comments

Comments
 (0)