@@ -490,42 +490,86 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
490
490
)
491
491
492
492
493
- def extract (condition , x ):
493
+ def extract (condition , a ):
494
494
"""
495
495
Return the elements of an array that satisfy some condition.
496
496
497
+ This is equivalent to
498
+ ``dpnp.compress(dpnp.ravel(condition), dpnp.ravel(a))``. If `condition`
499
+ is boolean :obj:`dpnp.extract` is equivalent to ``a[condition]``.
500
+
501
+ Note that :obj:`dpnp.place` does the exact opposite of :obj:`dpnp.extract`.
502
+
497
503
For full documentation refer to :obj:`numpy.extract`.
498
504
505
+ Parameters
506
+ ----------
507
+ condition : {array_like, scalar}
508
+ An array whose non-zero or ``True`` entries indicate the element of `a`
509
+ to extract.
510
+ a : {dpnp_array, usm_ndarray}
511
+ Input array of the same size as `condition`.
512
+
499
513
Returns
500
514
-------
501
515
out : dpnp.ndarray
502
- Rank 1 array of values from `x` where `condition` is True.
516
+ Rank 1 array of values from `a` where `condition` is ``True``.
517
+
518
+ See Also
519
+ --------
520
+ :obj:`dpnp.take` : Take elements from an array along an axis.
521
+ :obj:`dpnp.put` : Replaces specified elements of an array with given values.
522
+ :obj:`dpnp.copyto` : Copies values from one array to another, broadcasting
523
+ as necessary.
524
+ :obj:`dpnp.compress` : eturn selected slices of an array along given axis.
525
+ :obj:`dpnp.place` : Change elements of an array based on conditional and
526
+ input values.
527
+
528
+ Examples
529
+ --------
530
+ >>> import dpnp as np
531
+ >>> a = np.arange(12).reshape((3, 4))
532
+ >>> a
533
+ array([[ 0, 1, 2, 3],
534
+ [ 4, 5, 6, 7],
535
+ [ 8, 9, 10, 11]])
536
+ >>> condition = np.mod(a, 3) == 0
537
+ >>> condition
538
+ array([[ True, False, False, True],
539
+ [False, False, True, False],
540
+ [False, True, False, False]])
541
+ >>> np.extract(condition, a)
542
+ array([0, 3, 6, 9])
543
+
544
+ If `condition` is boolean:
545
+
546
+ >>> a[condition]
547
+ array([0, 3, 6, 9])
503
548
504
- Limitations
505
- -----------
506
- Parameters `condition` and `x` are supported either as
507
- :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
508
- Parameter `x` must be the same shape as `condition`.
509
- Otherwise the function will be executed sequentially on CPU.
510
549
"""
511
550
512
- if dpnp .is_supported_array_type (condition ) and dpnp .is_supported_array_type (
513
- x
514
- ):
515
- if condition .shape != x .shape :
516
- pass
517
- else :
518
- dpt_condition = (
519
- condition .get_array ()
520
- if isinstance (condition , dpnp_array )
521
- else condition
522
- )
523
- dpt_array = x .get_array () if isinstance (x , dpnp_array ) else x
524
- return dpnp_array ._create_from_usm_ndarray (
525
- dpt .extract (dpt_condition , dpt_array )
526
- )
551
+ usm_a = dpnp .get_usm_ndarray (a )
552
+ if not dpnp .is_supported_array_type (condition ):
553
+ usm_cond = dpt .asarray (
554
+ condition , usm_type = a .usm_type , sycl_queue = a .sycl_queue
555
+ )
556
+ else :
557
+ usm_cond = dpnp .get_usm_ndarray (condition )
558
+
559
+ if usm_cond .size != usm_a .size :
560
+ usm_a = dpt .reshape (usm_a , - 1 )
561
+ usm_cond = dpt .reshape (usm_cond , - 1 )
562
+
563
+ usm_res = dpt .take (usm_a , dpt .nonzero (usm_cond )[0 ])
564
+ else :
565
+ if usm_cond .shape != usm_a .shape :
566
+ usm_a = dpt .reshape (usm_a , - 1 )
567
+ usm_cond = dpt .reshape (usm_cond , - 1 )
568
+
569
+ usm_res = dpt .extract (usm_cond , usm_a )
527
570
528
- return call_origin (numpy .extract , condition , x )
571
+ dpnp .synchronize_array_data (usm_res )
572
+ return dpnp_array ._create_from_usm_ndarray (usm_res )
529
573
530
574
531
575
def fill_diagonal (a , val , wrap = False ):
0 commit comments