@@ -292,6 +292,89 @@ def expand_dims(
292292 return a
293293
294294
295+ def isin (
296+ element : Array ,
297+ test_elements : Array ,
298+ / ,
299+ * ,
300+ assume_unique : bool = False ,
301+ invert : bool = False ,
302+ xp : ModuleType ,
303+ ) -> Array :
304+ """Calculates ``element in test_elements``, broadcasting over `element`
305+ only.
306+
307+ Returns a boolean array of the same shape as `element` that is True
308+ where an element of `element` is in `test_elements` and False otherwise.
309+ """
310+
311+ original_element_shape = element .shape
312+ element = xp .reshape (element , (- 1 ,))
313+ test_elements = xp .reshape (test_elements , (- 1 ,))
314+ return xp .reshape (
315+ _in1d (
316+ element ,
317+ test_elements ,
318+ assume_unique = assume_unique ,
319+ invert = invert ,
320+ xp = xp ,
321+ ),
322+ original_element_shape ,
323+ )
324+
325+
326+ def _in1d (
327+ x1 : Array ,
328+ x2 : Array ,
329+ / ,
330+ * ,
331+ assume_unique : bool = False ,
332+ invert : bool = False ,
333+ xp : ModuleType ,
334+ ) -> Array :
335+ """Checks whether each element of an array is also present in a
336+ second array.
337+
338+ Returns a boolean array the same length as `x1` that is True
339+ where an element of `x1` is in `x2` and False otherwise.
340+
341+ This function has been adapted using the original implementation
342+ present in numpy:
343+ https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
344+ """
345+
346+ # This code is run to make the code significantly faster
347+ if x2 .shape [0 ] < 10 * x1 .shape [0 ] ** 0.145 :
348+ if invert :
349+ mask = xp .ones (x1 .shape [0 ], dtype = xp .bool , device = x1 .device )
350+ for a in x2 :
351+ mask &= x1 != a
352+ else :
353+ mask = xp .zeros (x1 .shape [0 ], dtype = xp .bool , device = x1 .device )
354+ for a in x2 :
355+ mask |= x1 == a
356+ return mask
357+
358+ if not assume_unique :
359+ x1 , rev_idx = xp .unique_inverse (x1 )
360+ x2 = xp .unique_values (x2 )
361+
362+ # ar = xp.concat((x1, x2))
363+ # device_ = device(ar)
364+ # # We need this to be a stable sort.
365+ # order = xp.argsort(ar, stable=True)
366+ # reverse_order = xp.argsort(order, stable=True)
367+ # sar = xp.take(ar, order, axis=0)
368+ # bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1]
369+ # flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
370+ # ret = xp.take(flag, reverse_order, axis=0)
371+
372+ # if assume_unique:
373+ # return ret[: x1.shape[0]]
374+ # return xp.take(ret, rev_idx, axis=0)
375+ return None
376+
377+
295378def kron (a : Array , b : Array , / , * , xp : ModuleType ) -> Array :
296379 """
297380 Kronecker product of two arrays.
@@ -399,6 +482,22 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
399482 return xp .reshape (result , tuple (xp .multiply (a_shape , b_shape )))
400483
401484
485+ def setdiff1d (
486+ x1 : Array , x2 : Array , / , * , assume_unique : bool = False , xp : ModuleType
487+ ) -> Array :
488+ """Find the set difference of two arrays.
489+
490+ Return the unique values in `x1` that are not in `x2`.
491+ """
492+
493+ if assume_unique :
494+ x1 = xp .reshape (x1 , (- 1 ,))
495+ else :
496+ x1 = xp .unique_values (x1 )
497+ x2 = xp .unique_values (x2 )
498+ return x1 [_in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
499+
500+
402501def sinc (x : Array , / , * , xp : ModuleType ) -> Array :
403502 r"""
404503 Return the normalized sinc function.
0 commit comments