@@ -192,6 +192,11 @@ def _cast_to_bool(operand: ArrayLike) -> Array:
192192def _cast_to_numeric (operand : ArrayLike ) -> Array :
193193 return promote_dtypes_numeric (operand )[0 ]
194194
195+ def _require_integer (operand : ArrayLike ) -> Array :
196+ arr = lax_internal .asarray (operand )
197+ if not dtypes .isdtype (arr , ("bool" , "integral" )):
198+ raise ValueError (f"integer argument required; got dtype={ arr .dtype } " )
199+ return arr
195200
196201def _ensure_optional_axes (x : Axis ) -> Axis :
197202 def force (x ):
@@ -652,6 +657,63 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None,
652657 return _reduce_any (a , axis = _ensure_optional_axes (axis ), out = out ,
653658 keepdims = keepdims , where = where )
654659
660+
661+ @partial (api .jit , static_argnames = ('axis' , 'keepdims' , 'dtype' ), inline = True )
662+ def _reduce_bitwise_and (a : ArrayLike , axis : Axis = None , dtype : DTypeLike | None = None ,
663+ out : None = None , keepdims : bool = False ,
664+ initial : ArrayLike | None = None , where : ArrayLike | None = None ) -> Array :
665+ arr = lax_internal .asarray (a )
666+ init_val = np .array (- 1 , dtype = dtype or arr .dtype )
667+ return _reduction (arr , name = "reduce_bitwise_and" , np_fun = None , op = lax .bitwise_and , init_val = init_val , preproc = _require_integer ,
668+ axis = _ensure_optional_axes (axis ), dtype = dtype , out = out , keepdims = keepdims ,
669+ initial = initial , where_ = where )
670+
671+
672+ @partial (api .jit , static_argnames = ('axis' , 'keepdims' , 'dtype' ), inline = True )
673+ def _reduce_bitwise_or (a : ArrayLike , axis : Axis = None , dtype : DTypeLike | None = None ,
674+ out : None = None , keepdims : bool = False ,
675+ initial : ArrayLike | None = None , where : ArrayLike | None = None ) -> Array :
676+ return _reduction (a , name = "reduce_bitwise_or" , np_fun = None , op = lax .bitwise_or , init_val = 0 , preproc = _require_integer ,
677+ axis = _ensure_optional_axes (axis ), dtype = dtype , out = out , keepdims = keepdims ,
678+ initial = initial , where_ = where )
679+
680+
681+ @partial (api .jit , static_argnames = ('axis' , 'keepdims' , 'dtype' ), inline = True )
682+ def _reduce_bitwise_xor (a : ArrayLike , axis : Axis = None , dtype : DTypeLike | None = None ,
683+ out : None = None , keepdims : bool = False ,
684+ initial : ArrayLike | None = None , where : ArrayLike | None = None ) -> Array :
685+ return _reduction (a , name = "reduce_bitwise_xor" , np_fun = None , op = lax .bitwise_xor , init_val = 0 , preproc = _require_integer ,
686+ axis = _ensure_optional_axes (axis ), dtype = dtype , out = out , keepdims = keepdims ,
687+ initial = initial , where_ = where )
688+
689+
690+ @partial (api .jit , static_argnames = ('axis' , 'keepdims' , 'dtype' ), inline = True )
691+ def _reduce_logical_and (a : ArrayLike , axis : Axis = None , dtype : DTypeLike | None = None ,
692+ out : None = None , keepdims : bool = False ,
693+ initial : ArrayLike | None = None , where : ArrayLike | None = None ) -> Array :
694+ return _reduction (a , name = "reduce_logical_and" , np_fun = None , op = lax .bitwise_and , init_val = True , preproc = _cast_to_bool ,
695+ axis = _ensure_optional_axes (axis ), dtype = dtype , out = out , keepdims = keepdims ,
696+ initial = initial , where_ = where )
697+
698+
699+ @partial (api .jit , static_argnames = ('axis' , 'keepdims' , 'dtype' ), inline = True )
700+ def _reduce_logical_or (a : ArrayLike , axis : Axis = None , dtype : DTypeLike | None = None ,
701+ out : None = None , keepdims : bool = False ,
702+ initial : ArrayLike | None = None , where : ArrayLike | None = None ) -> Array :
703+ return _reduction (a , name = "reduce_logical_or" , np_fun = None , op = lax .bitwise_or , init_val = False , preproc = _cast_to_bool ,
704+ axis = _ensure_optional_axes (axis ), dtype = dtype , out = out , keepdims = keepdims ,
705+ initial = initial , where_ = where )
706+
707+
708+ @partial (api .jit , static_argnames = ('axis' , 'keepdims' , 'dtype' ), inline = True )
709+ def _reduce_logical_xor (a : ArrayLike , axis : Axis = None , dtype : DTypeLike | None = None ,
710+ out : None = None , keepdims : bool = False ,
711+ initial : ArrayLike | None = None , where : ArrayLike | None = None ) -> Array :
712+ return _reduction (a , name = "reduce_logical_xor" , np_fun = None , op = lax .bitwise_xor , init_val = False , preproc = _cast_to_bool ,
713+ axis = _ensure_optional_axes (axis ), dtype = dtype , out = out , keepdims = keepdims ,
714+ initial = initial , where_ = where )
715+
716+
655717def amin (a : ArrayLike , axis : Axis = None , out : None = None ,
656718 keepdims : bool = False , initial : ArrayLike | None = None ,
657719 where : ArrayLike | None = None ) -> Array :
0 commit comments