Skip to content

Commit d33cda0

Browse files
Simplify mlx_funcify_CAReduce
1 parent 2421a6f commit d33cda0

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,67 +38,60 @@ def dimshuffle(x):
3838

3939
# Second-level dispatch for scalar operations in CAReduce
4040
@singledispatch
41-
def mlx_funcify_CAReduce_scalar_op(scalar_op):
41+
def mlx_funcify_CAReduce_scalar_op(scalar_op, axis):
4242
raise NotImplementedError(
4343
f"MLX does not support CAReduce with scalar op {scalar_op}"
4444
)
4545

4646

4747
@mlx_funcify.register(CAReduce)
4848
def mlx_funcify_CAReduce(op, **kwargs):
49-
# Dispatch to the appropriate scalar op handler
50-
scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op(op.scalar_op)
51-
axis = op.axis
52-
53-
def reduce(x):
54-
return scalar_reduce_fn(x, axis)
55-
56-
return reduce
49+
return mlx_funcify_CAReduce_scalar_op(op.scalar_op, op.axis)
5750

5851

5952
@mlx_funcify_CAReduce_scalar_op.register(Add)
60-
def mlx_funcify_Elemwise_scalar_Add(scalar_op):
61-
def sum_reduce(x, axis):
53+
def mlx_funcify_Elemwise_scalar_Add(scalar_op, axis):
54+
def sum_reduce(x):
6255
return mx.sum(x, axis=axis)
6356

6457
return sum_reduce
6558

6659

6760
@mlx_funcify_CAReduce_scalar_op.register(Mul)
68-
def mlx_funcify_Elemwise_scalar_Mul(scalar_op):
69-
def prod_reduce(x, axis):
61+
def mlx_funcify_Elemwise_scalar_Mul(scalar_op, axis):
62+
def prod_reduce(x):
7063
return mx.prod(x, axis=axis)
7164

7265
return prod_reduce
7366

7467

7568
@mlx_funcify_CAReduce_scalar_op.register(AND)
76-
def mlx_funcify_Elemwise_scalar_AND(scalar_op):
77-
def all_reduce(x, axis):
69+
def mlx_funcify_Elemwise_scalar_AND(scalar_op, axis):
70+
def all_reduce(x):
7871
return x.all(axis=axis)
7972

8073
return all_reduce
8174

8275

8376
@mlx_funcify_CAReduce_scalar_op.register(OR)
84-
def mlx_funcify_CARreduce_OR(scalar_op):
85-
def any_reduce(x, axis):
77+
def mlx_funcify_CARreduce_OR(scalar_op, axis):
78+
def any_reduce(x):
8679
return mx.any(x, axis=axis)
8780

8881
return any_reduce
8982

9083

9184
@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum)
92-
def mlx_funcify_CARreduce_Maximum(scalar_op):
93-
def max_reduce(x, axis):
85+
def mlx_funcify_CARreduce_Maximum(scalar_op, axis):
86+
def max_reduce(x):
9487
return mx.max(x, axis=axis)
9588

9689
return max_reduce
9790

9891

9992
@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum)
100-
def mlx_funcify_CARreduce_Minimum(scalar_op):
101-
def min_reduce(x, axis):
93+
def mlx_funcify_CARreduce_Minimum(scalar_op, axis):
94+
def min_reduce(x):
10295
return mx.min(x, axis=axis)
10396

10497
return min_reduce

0 commit comments

Comments
 (0)