Skip to content

Commit 880dd5c

Browse files
committed
refactor to use getattr
1 parent e7cf10e commit 880dd5c

File tree

1 file changed

+47
-39
lines changed

1 file changed

+47
-39
lines changed

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from pytensor.link.mlx.dispatch.basic import mlx_funcify
44
from pytensor.scalar import Softplus
5-
from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum
65
from pytensor.tensor.elemwise import CAReduce, DimShuffle
76
from pytensor.tensor.special import Softmax, SoftmaxGrad
87

@@ -24,44 +23,53 @@ def dimshuffle(x):
2423

2524
@mlx_funcify.register(CAReduce)
2625
def mlx_funcify_CAReduce(op, **kwargs):
27-
if isinstance(op.scalar_op, Add):
28-
29-
def sum(x):
30-
return mx.sum(x, axis=op.axis)
31-
32-
return sum
33-
elif isinstance(op.scalar_op, Mul):
34-
35-
def prod(x):
36-
return mx.prod(x, axis=op.axis)
37-
38-
return prod
39-
elif isinstance(op.scalar_op, AND):
40-
41-
def all(x):
42-
return x.all(axis=op.axis)
43-
44-
return all
45-
elif isinstance(op.scalar_op, OR):
46-
47-
def any(x):
48-
return mx.any(x, axis=op.axis)
49-
50-
return any
51-
elif isinstance(op.scalar_op, ScalarMaximum):
52-
53-
def max(x):
54-
return mx.max(x, axis=op.axis)
55-
56-
return max
57-
elif isinstance(op.scalar_op, ScalarMinimum):
58-
59-
def min(x):
60-
return mx.min(x, axis=op.axis)
61-
62-
return min
63-
else:
64-
raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
26+
axis = op.axis
27+
op_nfunc_spec = getattr(op, "nfunc_spec", None)
28+
scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
29+
scalar_op_name = getattr(op.scalar_op, "name", None)
30+
scalar_op_identity = getattr(op.scalar_op, "identity", None)
31+
acc_dtype = getattr(op, "acc_dtype", None)
32+
33+
def careduce(x):
34+
nonlocal \
35+
axis, \
36+
op_nfunc_spec, \
37+
scalar_nfunc_spec, \
38+
scalar_op_name, \
39+
scalar_op_identity, \
40+
acc_dtype
41+
42+
if axis is None:
43+
axis = list(range(x.ndim))
44+
45+
if acc_dtype is None:
46+
acc_dtype = x.dtype.type
47+
48+
if op_nfunc_spec:
49+
mlx_op = getattr(mx, op_nfunc_spec[0])
50+
return mlx_op(x, axis=axis)
51+
return mlx_op(x, axis=axis).astype(acc_dtype)
52+
53+
# The PyTensor `Op` didn't tell us which NumPy equivalent to use (or
54+
# there isn't one), so we use this fallback approach
55+
if scalar_nfunc_spec:
56+
scalar_fn_name = scalar_nfunc_spec[0]
57+
elif scalar_op_name:
58+
scalar_fn_name = scalar_op_name
59+
60+
to_reduce = sorted(axis, reverse=True)
61+
62+
if to_reduce:
63+
raise NotImplementedError("Not implemented yet")
64+
# In this case, we need to use the `jax.lax` function (if there
65+
# is one), and not the `jnp` version.
66+
mlx_op = getattr(mx, scalar_fn_name)
67+
init_value = mx.array(scalar_op_identity, dtype=acc_dtype)
68+
return mx.reduce(x, init_value, mlx_op, to_reduce).astype(acc_dtype)
69+
else:
70+
return x
71+
72+
return careduce
6573

6674

6775
@mlx_funcify.register(Softmax)

0 commit comments

Comments
 (0)