@@ -1,9 +1,10 @@
+import mlx.core as mx
+
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.scalar.basic import Add, Mul, Any, AND, OR, ScalarMaximum, ScalarMinimum
+from pytensor.tensor.special import Softmax, SoftmaxGrad
 
-import mlx.core as mx
 
 @mlx_funcify.register(DimShuffle)
 def mlx_funcify_DimShuffle(op, **kwargs):
@@ -19,42 +20,49 @@ def dimshuffle(x):
 
     return dimshuffle
 
+
 @mlx_funcify.register(CAReduce)
 def mlx_funcify_CAReduce(op, **kwargs):
     if isinstance(op.scalar_op, Add):
+
         def sum(x):
             return mx.sum(x, axis=op.axis)
 
         return sum
     elif isinstance(op.scalar_op, Mul):
+
         def prod(x):
             return mx.prod(x, axis=op.axis)
 
         return prod
     elif isinstance(op.scalar_op, AND):
+
         def all(x):
             return mx.all(x, axis=op.axis)
 
         return all
     elif isinstance(op.scalar_op, OR):
+
         def any(x):
             return mx.any(x, axis=op.axis)
 
         return any
     elif isinstance(op.scalar_op, ScalarMaximum):
+
         def max(x):
             return mx.max(x, axis=op.axis)
 
         return max
     elif isinstance(op.scalar_op, ScalarMinimum):
+
         def min(x):
             return mx.min(x, axis=op.axis)
 
         return min
-
+
     else:
         raise NotImplementedError(f"MLX does not support {op.scalar_op}")
-
+
 
 @mlx_funcify.register(Softmax)
 def mlx_funcify_Softmax(op, **kwargs):
@@ -74,4 +82,4 @@ def softmax_grad(dy, sm):
         dy_times_sm = dy * sm
         return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm
 
-    return softmax_grad
+    return softmax_grad
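
For context (not part of the commit): a minimal sketch of how the new CAReduce dispatch can be exercised directly. The module path and the SimpleNamespace stand-in op are assumptions for illustration; the stand-in exposes only the two attributes the dispatcher reads, scalar_op and axis.

from types import SimpleNamespace

import mlx.core as mx
from pytensor.scalar.basic import add  # built-in scalar Add instance, so isinstance(..., Add) holds

# Assumed module path; the diff above does not show the file name.
from pytensor.link.mlx.dispatch.elemwise import mlx_funcify_CAReduce

# Stand-in for a CAReduce Op: only `scalar_op` and `axis` are read by the dispatcher.
fake_sum_op = SimpleNamespace(scalar_op=add, axis=(0,))
sum_fn = mlx_funcify_CAReduce(fake_sum_op)
print(sum_fn(mx.ones((3, 4))))  # reduces over axis 0 -> four values of 3.0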