Skip to content

Commit 9f41a4e

Browse files
Add function names and remove wrappers
1 parent d16d245 commit 9f41a4e

File tree

2 files changed

+75
-162
lines changed

2 files changed

+75
-162
lines changed

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,47 +57,47 @@ def reduce(x):
5757

5858

5959
@mlx_funcify_CAReduce_scalar_op.register(Add)
60-
def _(scalar_op):
60+
def mlx_funcify_Elemwise_scalar_Add(scalar_op):
6161
def sum_reduce(x, axis):
6262
return mx.sum(x, axis=axis)
6363

6464
return sum_reduce
6565

6666

6767
@mlx_funcify_CAReduce_scalar_op.register(Mul)
68-
def _(scalar_op):
68+
def mlx_funcify_Elemwise_scalar_Mul(scalar_op):
6969
def prod_reduce(x, axis):
7070
return mx.prod(x, axis=axis)
7171

7272
return prod_reduce
7373

7474

7575
@mlx_funcify_CAReduce_scalar_op.register(AND)
76-
def _(scalar_op):
76+
def mlx_funcify_Elemwise_scalar_AND(scalar_op):
7777
def all_reduce(x, axis):
7878
return x.all(axis=axis)
7979

8080
return all_reduce
8181

8282

8383
@mlx_funcify_CAReduce_scalar_op.register(OR)
84-
def _(scalar_op):
84+
def mlx_funcify_CARreduce_OR(scalar_op):
8585
def any_reduce(x, axis):
8686
return mx.any(x, axis=axis)
8787

8888
return any_reduce
8989

9090

9191
@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum)
92-
def _(scalar_op):
92+
def mlx_funcify_CARreduce_Maximum(scalar_op):
9393
def max_reduce(x, axis):
9494
return mx.max(x, axis=axis)
9595

9696
return max_reduce
9797

9898

9999
@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum)
100-
def _(scalar_op):
100+
def mlx_funcify_CARreduce_Minimum(scalar_op):
101101
def min_reduce(x, axis):
102102
return mx.min(x, axis=axis)
103103

0 commit comments

Comments
 (0)