@@ -38,67 +38,60 @@ def dimshuffle(x):
38
38
39
39
# Second-level dispatch for scalar operations in CAReduce
40
40
@singledispatch
41
- def mlx_funcify_CAReduce_scalar_op (scalar_op ):
41
+ def mlx_funcify_CAReduce_scalar_op (scalar_op , axis ):
42
42
raise NotImplementedError (
43
43
f"MLX does not support CAReduce with scalar op { scalar_op } "
44
44
)
45
45
46
46
47
47
@mlx_funcify .register (CAReduce )
48
48
def mlx_funcify_CAReduce (op , ** kwargs ):
49
- # Dispatch to the appropriate scalar op handler
50
- scalar_reduce_fn = mlx_funcify_CAReduce_scalar_op (op .scalar_op )
51
- axis = op .axis
52
-
53
- def reduce (x ):
54
- return scalar_reduce_fn (x , axis )
55
-
56
- return reduce
49
+ return mlx_funcify_CAReduce_scalar_op (op .scalar_op , op .axis )
57
50
58
51
59
52
@mlx_funcify_CAReduce_scalar_op .register (Add )
60
- def mlx_funcify_Elemwise_scalar_Add (scalar_op ):
61
- def sum_reduce (x , axis ):
53
+ def mlx_funcify_Elemwise_scalar_Add (scalar_op , axis ):
54
+ def sum_reduce (x ):
62
55
return mx .sum (x , axis = axis )
63
56
64
57
return sum_reduce
65
58
66
59
67
60
@mlx_funcify_CAReduce_scalar_op .register (Mul )
68
- def mlx_funcify_Elemwise_scalar_Mul (scalar_op ):
69
- def prod_reduce (x , axis ):
61
+ def mlx_funcify_Elemwise_scalar_Mul (scalar_op , axis ):
62
+ def prod_reduce (x ):
70
63
return mx .prod (x , axis = axis )
71
64
72
65
return prod_reduce
73
66
74
67
75
68
@mlx_funcify_CAReduce_scalar_op .register (AND )
76
- def mlx_funcify_Elemwise_scalar_AND (scalar_op ):
77
- def all_reduce (x , axis ):
69
+ def mlx_funcify_Elemwise_scalar_AND (scalar_op , axis ):
70
+ def all_reduce (x ):
78
71
return x .all (axis = axis )
79
72
80
73
return all_reduce
81
74
82
75
83
76
@mlx_funcify_CAReduce_scalar_op .register (OR )
84
- def mlx_funcify_CARreduce_OR (scalar_op ):
85
- def any_reduce (x , axis ):
77
+ def mlx_funcify_CARreduce_OR (scalar_op , axis ):
78
+ def any_reduce (x ):
86
79
return mx .any (x , axis = axis )
87
80
88
81
return any_reduce
89
82
90
83
91
84
@mlx_funcify_CAReduce_scalar_op .register (ScalarMaximum )
92
- def mlx_funcify_CARreduce_Maximum (scalar_op ):
93
- def max_reduce (x , axis ):
85
+ def mlx_funcify_CARreduce_Maximum (scalar_op , axis ):
86
+ def max_reduce (x ):
94
87
return mx .max (x , axis = axis )
95
88
96
89
return max_reduce
97
90
98
91
99
92
@mlx_funcify_CAReduce_scalar_op .register (ScalarMinimum )
100
- def mlx_funcify_CARreduce_Minimum (scalar_op ):
101
- def min_reduce (x , axis ):
93
+ def mlx_funcify_CARreduce_Minimum (scalar_op , axis ):
94
+ def min_reduce (x ):
102
95
return mx .min (x , axis = axis )
103
96
104
97
return min_reduce
0 commit comments