Skip to content

Commit 67a74fb

Browse files
committed
add extension
1 parent 7c8eae7 commit 67a74fb

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed
Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import mlx.core as mx
2+
13
from pytensor.link.mlx.dispatch.basic import mlx_funcify
4+
from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum
25
from pytensor.tensor.elemwise import CAReduce, DimShuffle
3-
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
4-
from pytensor.scalar.basic import Add, Mul, Any, AND, OR, ScalarMaximum, ScalarMinimum
6+
from pytensor.tensor.special import Softmax, SoftmaxGrad
57

6-
import mlx.core as mx
78

89
@mlx_funcify.register(DimShuffle)
910
def mlx_funcify_DimShuffle(op, **kwargs):
@@ -19,42 +20,49 @@ def dimshuffle(x):
1920

2021
return dimshuffle
2122

23+
2224
@mlx_funcify.register(CAReduce)
2325
def mlx_funcify_CAReduce(op, **kwargs):
2426
if isinstance(op.scalar_op, Add):
27+
2528
def sum(x):
2629
return mx.sum(x, axis=op.axis)
2730

2831
return sum
2932
elif isinstance(op.scalar_op, Mul):
33+
3034
def prod(x):
3135
return mx.prod(x, axis=op.axis)
3236

3337
return prod
3438
elif isinstance(op.scalar_op, AND):
39+
3540
def all(x):
3641
return mx.all(x, axis=op.axis)
3742

3843
return all
3944
elif isinstance(op.scalar_op, OR):
45+
4046
def any(x):
4147
return mx.any(x, axis=op.axis)
4248

4349
return any
4450
elif isinstance(op.scalar_op, ScalarMaximum):
51+
4552
def max(x):
4653
return mx.max(x, axis=op.axis)
4754

4855
return max
4956
elif isinstance(op.scalar_op, ScalarMinimum):
57+
5058
def min(x):
5159
return mx.min(x, axis=op.axis)
5260

5361
return min
54-
62+
5563
else:
5664
raise NotImplementedError(f"MLX does not support {op.scalar_op}")
57-
65+
5866

5967
@mlx_funcify.register(Softmax)
6068
def mlx_funcify_Softmax(op, **kwargs):
@@ -74,4 +82,4 @@ def softmax_grad(dy, sm):
7482
dy_times_sm = dy * sm
7583
return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm
7684

77-
return softmax_grad
85+
return softmax_grad

0 commit comments

Comments
 (0)