12
12
from keras .src .backend .mlx .core import cast
13
13
from keras .src .backend .mlx .core import convert_to_tensor
14
14
from keras .src .backend .mlx .core import convert_to_tensors
15
+ from keras .src .backend .mlx .core import is_tensor
15
16
from keras .src .backend .mlx .core import slice
16
17
from keras .src .backend .mlx .core import to_mlx_dtype
17
18
@@ -272,8 +273,20 @@ def bitwise_xor(x, y):
272
273
273
274
def bitwise_left_shift (x , y ):
274
275
x = convert_to_tensor (x )
275
- y = convert_to_tensor (y )
276
- return mx .left_shift (x , y )
276
+ if not isinstance (y , int ):
277
+ y = convert_to_tensor (y )
278
+
279
+ # handle result dtype to match other backends
280
+ types = [x .dtype ]
281
+ if is_tensor (y ):
282
+ types .append (y .dtype )
283
+ result_dtype = result_type (* types )
284
+ mlx_result_dtype = to_mlx_dtype (result_dtype )
285
+
286
+ result = mx .left_shift (x , y )
287
+ if result .dtype != mlx_result_dtype :
288
+ return result .astype (mlx_result_dtype )
289
+ return result
277
290
278
291
279
292
def left_shift (x , y ):
@@ -282,8 +295,20 @@ def left_shift(x, y):
282
295
283
296
def bitwise_right_shift (x , y ):
284
297
x = convert_to_tensor (x )
285
- y = convert_to_tensor (y )
286
- return mx .right_shift (x , y )
298
+ if not isinstance (y , int ):
299
+ y = convert_to_tensor (y )
300
+
301
+ # handle result dtype to match other backends
302
+ types = [x .dtype ]
303
+ if is_tensor (y ):
304
+ types .append (y .dtype )
305
+ result_dtype = result_type (* types )
306
+ mlx_result_dtype = to_mlx_dtype (result_dtype )
307
+
308
+ result = mx .right_shift (x , y )
309
+ if result .dtype != mlx_result_dtype :
310
+ return result .astype (mlx_result_dtype )
311
+ return result
287
312
288
313
289
314
def right_shift (x , y ):
@@ -1567,3 +1592,34 @@ def rot90(array, k=1, axes=(0, 1)):
1567
1592
array = array [tuple (slices )]
1568
1593
1569
1594
return array
1595
+
1596
+
1597
+ def signbit (x ):
1598
+ x = convert_to_tensor (x )
1599
+
1600
+ if x .dtype in (
1601
+ mx .float16 ,
1602
+ mx .float32 ,
1603
+ mx .float64 ,
1604
+ mx .bfloat16 ,
1605
+ mx .complex64 ,
1606
+ ):
1607
+ if x .dtype == mx .complex64 :
1608
+ # check sign of real part for complex numbers
1609
+ real_part = mx .real (x )
1610
+ return signbit (real_part )
1611
+ zeros = x == 0
1612
+ # this works because in mlx 1/0=inf and 1/-0=-inf
1613
+ neg_zeros = (1 / x == mx .array (float ("-inf" ))) & zeros
1614
+ return mx .where (zeros , neg_zeros , x < 0 )
1615
+ elif x .dtype in (mx .uint8 , mx .uint16 , mx .uint32 , mx .uint64 ):
1616
+ # unsigned integers never negative
1617
+ return mx .zeros_like (x ).astype (mx .bool_ )
1618
+ elif x .dtype in (mx .int8 , mx .int16 , mx .int32 , mx .int64 ):
1619
+ # for integers, simple negative check
1620
+ return x < 0
1621
+ elif x .dtype == mx .bool_ :
1622
+ # for boolean array, return false
1623
+ return mx .zeros_like (x ).astype (mx .bool_ )
1624
+ else :
1625
+ raise ValueError (f"Unsupported dtype in `signbit`: { x .dtype } " )
0 commit comments