Skip to content

Commit 285ef85

Browse files
committed
wip
1 parent f8a2a90 commit 285ef85

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

src/array_api_extra/_funcs.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,89 @@ def expand_dims(
292292
return a
293293

294294

295+
def isin(
296+
element: Array,
297+
test_elements: Array,
298+
/,
299+
*,
300+
assume_unique: bool = False,
301+
invert: bool = False,
302+
xp: ModuleType,
303+
) -> Array:
304+
"""Calculates ``element in test_elements``, broadcasting over `element`
305+
only.
306+
307+
Returns a boolean array of the same shape as `element` that is True
308+
where an element of `element` is in `test_elements` and False otherwise.
309+
"""
310+
311+
original_element_shape = element.shape
312+
element = xp.reshape(element, (-1,))
313+
test_elements = xp.reshape(test_elements, (-1,))
314+
return xp.reshape(
315+
_in1d(
316+
element,
317+
test_elements,
318+
assume_unique=assume_unique,
319+
invert=invert,
320+
xp=xp,
321+
),
322+
original_element_shape,
323+
)
324+
325+
326+
def _in1d(
327+
x1: Array,
328+
x2: Array,
329+
/,
330+
*,
331+
assume_unique: bool = False,
332+
invert: bool = False,
333+
xp: ModuleType,
334+
) -> Array:
335+
"""Checks whether each element of an array is also present in a
336+
second array.
337+
338+
Returns a boolean array the same length as `x1` that is True
339+
where an element of `x1` is in `x2` and False otherwise.
340+
341+
This function has been adapted using the original implementation
342+
present in numpy:
343+
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
344+
"""
345+
346+
# This code is run to make the code significantly faster
347+
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
348+
if invert:
349+
mask = xp.ones(x1.shape[0], dtype=xp.bool, device=x1.device)
350+
for a in x2:
351+
mask &= x1 != a
352+
else:
353+
mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=x1.device)
354+
for a in x2:
355+
mask |= x1 == a
356+
return mask
357+
358+
if not assume_unique:
359+
x1, rev_idx = xp.unique_inverse(x1)
360+
x2 = xp.unique_values(x2)
361+
362+
# ar = xp.concat((x1, x2))
363+
# device_ = device(ar)
364+
# # We need this to be a stable sort.
365+
# order = xp.argsort(ar, stable=True)
366+
# reverse_order = xp.argsort(order, stable=True)
367+
# sar = xp.take(ar, order, axis=0)
368+
# bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
369+
# flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
370+
# ret = xp.take(flag, reverse_order, axis=0)
371+
372+
# if assume_unique:
373+
# return ret[: x1.shape[0]]
374+
# return xp.take(ret, rev_idx, axis=0)
375+
return None
376+
377+
295378
def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
296379
"""
297380
Kronecker product of two arrays.
@@ -399,6 +482,22 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
399482
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
400483

401484

485+
def setdiff1d(
486+
x1: Array, x2: Array, /, *, assume_unique: bool = False, xp: ModuleType
487+
) -> Array:
488+
"""Find the set difference of two arrays.
489+
490+
Return the unique values in `x1` that are not in `x2`.
491+
"""
492+
493+
if assume_unique:
494+
x1 = xp.reshape(x1, (-1,))
495+
else:
496+
x1 = xp.unique_values(x1)
497+
x2 = xp.unique_values(x2)
498+
return x1[_in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
499+
500+
402501
def sinc(x: Array, /, *, xp: ModuleType) -> Array:
403502
r"""
404503
Return the normalized sinc function.

0 commit comments

Comments
 (0)