 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.scalar import Softplus
-from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
 from pytensor.tensor.special import Softmax, SoftmaxGrad
 
@@ -24,44 +23,53 @@ def dimshuffle(x):
 
 @mlx_funcify.register(CAReduce)
 def mlx_funcify_CAReduce(op, **kwargs):
-    if isinstance(op.scalar_op, Add):
-
-        def sum(x):
-            return mx.sum(x, axis=op.axis)
-
-        return sum
-    elif isinstance(op.scalar_op, Mul):
-
-        def prod(x):
-            return mx.prod(x, axis=op.axis)
-
-        return prod
-    elif isinstance(op.scalar_op, AND):
-
-        def all(x):
-            return x.all(axis=op.axis)
-
-        return all
-    elif isinstance(op.scalar_op, OR):
-
-        def any(x):
-            return mx.any(x, axis=op.axis)
-
-        return any
-    elif isinstance(op.scalar_op, ScalarMaximum):
-
-        def max(x):
-            return mx.max(x, axis=op.axis)
-
-        return max
-    elif isinstance(op.scalar_op, ScalarMinimum):
-
-        def min(x):
-            return mx.min(x, axis=op.axis)
-
-        return min
-    else:
-        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
+    axis = op.axis
+    op_nfunc_spec = getattr(op, "nfunc_spec", None)
+    scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
+    scalar_op_name = getattr(op.scalar_op, "name", None)
+    scalar_op_identity = getattr(op.scalar_op, "identity", None)
+    acc_dtype = getattr(op, "acc_dtype", None)
+
+    def careduce(x):
+        nonlocal \
+            axis, \
+            op_nfunc_spec, \
+            scalar_nfunc_spec, \
+            scalar_op_name, \
+            scalar_op_identity, \
+            acc_dtype
+
+        if axis is None:
+            axis = list(range(x.ndim))
+
+        if acc_dtype is None:
+            # mx arrays expose an mx.Dtype directly; there is no numpy-style
+            # `.type` attribute to unwrap here.
+            acc_dtype = x.dtype
+
+        if op_nfunc_spec:
+            # Fast path: the Op names its NumPy equivalent, which MLX mirrors
+            # (e.g. "sum" -> mx.sum).
+            mlx_op = getattr(mx, op_nfunc_spec[0])
+            return mlx_op(x, axis=axis)
+
+        # The PyTensor `Op` didn't tell us which NumPy equivalent to use (or
+        # there isn't one), so we use this fallback approach
+        if scalar_nfunc_spec:
+            scalar_fn_name = scalar_nfunc_spec[0]
+        elif scalar_op_name:
+            scalar_fn_name = scalar_op_name
+
+        to_reduce = sorted(axis, reverse=True)
+
+        if to_reduce:
+            raise NotImplementedError(
+                f"MLX CAReduce is not implemented for scalar op {scalar_op_name}"
+            )
+            # Unreachable sketch carried over from the JAX backend: MLX has no
+            # generic reduce primitive equivalent to `jax.lax.reduce`, so this
+            # branch raises above instead of running the lines below.
+            mlx_op = getattr(mx, scalar_fn_name)
+            init_value = mx.array(scalar_op_identity, dtype=acc_dtype)
+            return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype)
+        else:
+            return x
+
+    return careduce
 
 
 @mlx_funcify.register(Softmax)
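For context, a minimal standalone sketch (not part of the commit) of the fast path the new careduce helper takes when the reducing Op carries an nfunc_spec: the NumPy function name is looked up on mlx.core and called with the reduction axes. The ("sum", 1, 1) spec below is illustrative of what a Sum-like CAReduce would provide.

import mlx.core as mx

# Illustrative nfunc_spec in PyTensor's ("numpy_name", n_inputs, n_outputs)
# form; a Sum CAReduce carries something along these lines.
op_nfunc_spec = ("sum", 1, 1)
axis = None

x = mx.arange(6).reshape(2, 3)

# Mirror careduce's fast path: a missing axis means "reduce everything",
# and the reduction is resolved by name on mlx.core.
if axis is None:
    axis = list(range(x.ndim))
mlx_op = getattr(mx, op_nfunc_spec[0])  # -> mx.sum
print(mlx_op(x, axis=axis))  # sums over both axes -> 15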