
Commit e484ba4

move all elemwise dispatches to elemwise.py
1 parent 5940630 commit e484ba4

File tree

4 files changed: +363 −354 lines

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 285 additions & 5 deletions
@@ -5,17 +5,41 @@

 from pytensor.link.mlx.dispatch.basic import mlx_funcify
 from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
-from pytensor.scalar import Softplus
 from pytensor.scalar.basic import (
     AND,
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NEQ,
     OR,
+    Abs,
     Add,
     Cast,
+    Cos,
+    Exp,
+    IntDiv,
+    Invert,
+    IsInf,
+    IsNan,
+    Log,
+    Log1p,
     Mul,
+    Neg,
+    Pow,
     ScalarMaximum,
     ScalarMinimum,
+    Sign,
+    Sin,
+    Sqr,
+    Sqrt,
+    Sub,
+    Switch,
+    TrueDiv,
 )
-from pytensor.tensor.elemwise import CAReduce, DimShuffle
+from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
+from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.special import Softmax, SoftmaxGrad

@@ -50,23 +74,23 @@ def mlx_funcify_CAReduce(op, **kwargs):


 @mlx_funcify_CAReduce_scalar_op.register(Add)
-def mlx_funcify_Elemwise_scalar_Add(scalar_op, axis):
+def mlx_funcify_CAReduce_scalar_Add(scalar_op, axis):
     def sum_reduce(x):
         return mx.sum(x, axis=axis)

     return sum_reduce


 @mlx_funcify_CAReduce_scalar_op.register(Mul)
-def mlx_funcify_Elemwise_scalar_Mul(scalar_op, axis):
+def mlx_funcify_CAReduce_scalar_Mul(scalar_op, axis):
     def prod_reduce(x):
         return mx.prod(x, axis=axis)

     return prod_reduce


 @mlx_funcify_CAReduce_scalar_op.register(AND)
-def mlx_funcify_Elemwise_scalar_AND(scalar_op, axis):
+def mlx_funcify_CAReduce_scalar_AND(scalar_op, axis):
     def all_reduce(x):
         return x.all(axis=axis)
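This hunk only renames the CAReduce helpers from mlx_funcify_Elemwise_scalar_* to mlx_funcify_CAReduce_scalar_*, freeing the Elemwise names for the new dispatches below; dispatch behavior is unchanged. A minimal sketch of what the Add registration builds (a hypothetical direct call, run where the helper is importable; scalar_op is unused by this helper, so None stands in):

    import mlx.core as mx

    # The helper closes over `axis` and ignores `scalar_op`.
    sum_reduce = mlx_funcify_CAReduce_scalar_Add(scalar_op=None, axis=0)
    print(sum_reduce(mx.ones((2, 3))))  # shape (3,), all entries 2.0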

@@ -164,3 +188,259 @@ def cast(x):
             raise

     return cast
+
+
+@singledispatch
+def mlx_funcify_Elemwise_scalar_op(scalar_op):
+    """Simplified implementation for MLX scalar operations."""
+
+    # Try using the operation name directly (most common case)
+    op_name = getattr(scalar_op, "name", None)
+    if op_name is not None:
+        try:
+            mlx_func = getattr(mx, op_name)
+            # Handle variadic functions like Add
+            if hasattr(scalar_op, "inputs") and len(scalar_op.inputs) > 2:
+
+                def variadic_func(*args):
+                    result = args[0]
+                    for arg in args[1:]:
+                        result = mlx_func(result, arg)
+                    return result
+
+                return variadic_func
+            else:
+                return mlx_func
+        except AttributeError:
+            pass
+
+    raise NotImplementedError(f"MLX does not support Elemwise scalar op {scalar_op}")
+
+
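The generic branch above is a name-based fallback: a scalar op with no explicit registration is resolved by looking up an mlx.core function with the same name. A small sketch using a hypothetical op class (not part of PyTensor; run alongside the module above):

    import mlx.core as mx

    class FakeTanh:
        # Any scalar op whose `name` matches an mlx.core attribute resolves.
        name = "tanh"

    fn = mlx_funcify_Elemwise_scalar_op(FakeTanh())
    assert fn is mx.tanh  # found via getattr(mx, "tanh")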
+@mlx_funcify_Elemwise_scalar_op.register(Add)
+def mlx_funcify_Elemwise_scalar_Add(scalar_op):
+    def add(*args):
+        result = args[0]
+        for arg in args[1:]:
+            result = mx.add(result, arg)
+        return result
+
+    return add
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Sub)
+def mlx_funcify_Elemwise_scalar_Sub(scalar_op):
+    return mx.subtract
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Mul)
+def mlx_funcify_Elemwise_scalar_Mul(scalar_op):
+    def mul(*args):
+        result = args[0]
+        for arg in args[1:]:
+            result = mx.multiply(result, arg)
+        return result
+
+    return mul
+
+
+@mlx_funcify_Elemwise_scalar_op.register(TrueDiv)
+def mlx_funcify_Elemwise_scalar_TrueDiv(scalar_op):
+    return mx.divide
+
+
+@mlx_funcify_Elemwise_scalar_op.register(IntDiv)
+def mlx_funcify_Elemwise_scalar_IntDiv(scalar_op):
+    return mx.floor_divide
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Pow)
+def mlx_funcify_Elemwise_scalar_Pow(scalar_op):
+    return mx.power
+
+
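Add and Mul get hand-written wrappers because PyTensor's Add and Mul are n-ary while mx.add and mx.multiply are strictly binary, so the wrapper folds the arguments pairwise. A self-contained illustration of the same fold:

    import mlx.core as mx

    def add(*args):
        # Fold a variadic argument list through the binary mx.add.
        result = args[0]
        for arg in args[1:]:
            result = mx.add(result, arg)
        return result

    print(add(mx.array(1.0), mx.array(2.0), mx.array(3.0)))  # array(6, dtype=float32)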
+@mlx_funcify_Elemwise_scalar_op.register(Exp)
+def mlx_funcify_Elemwise_scalar_Exp(scalar_op):
+    return mx.exp
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Log)
+def mlx_funcify_Elemwise_scalar_Log(scalar_op):
+    return mx.log
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Log1p)
+def mlx_funcify_Elemwise_scalar_Log1p(scalar_op):
+    return mx.log1p
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Sin)
+def mlx_funcify_Elemwise_scalar_Sin(scalar_op):
+    return mx.sin
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Cos)
+def mlx_funcify_Elemwise_scalar_Cos(scalar_op):
+    return mx.cos
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Sqrt)
+def mlx_funcify_Elemwise_scalar_Sqrt(scalar_op):
+    return mx.sqrt
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Sqr)
+def mlx_funcify_Elemwise_scalar_Sqr(scalar_op):
+    return mx.square
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Abs)
+def mlx_funcify_Elemwise_scalar_Abs(scalar_op):
+    return mx.abs
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Neg)
+def mlx_funcify_Elemwise_scalar_Neg(scalar_op):
+    return mx.negative
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Sign)
+def mlx_funcify_Elemwise_scalar_Sign(scalar_op):
+    return mx.sign
+
+
+@mlx_funcify_Elemwise_scalar_op.register(LE)
+def mlx_funcify_Elemwise_scalar_LE(scalar_op):
+    return mx.less_equal
+
+
+@mlx_funcify_Elemwise_scalar_op.register(LT)
+def mlx_funcify_Elemwise_scalar_LT(scalar_op):
+    return mx.less
+
+
+@mlx_funcify_Elemwise_scalar_op.register(GE)
+def mlx_funcify_Elemwise_scalar_GE(scalar_op):
+    return mx.greater_equal
+
+
+@mlx_funcify_Elemwise_scalar_op.register(GT)
+def mlx_funcify_Elemwise_scalar_GT(scalar_op):
+    return mx.greater
+
+
+@mlx_funcify_Elemwise_scalar_op.register(EQ)
+def mlx_funcify_Elemwise_scalar_EQ(scalar_op):
+    return mx.equal
+
+
+@mlx_funcify_Elemwise_scalar_op.register(NEQ)
+def mlx_funcify_Elemwise_scalar_NEQ(scalar_op):
+    return mx.not_equal
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Switch)
+def mlx_funcify_Elemwise_scalar_Switch(scalar_op):
+    return mx.where
+
+
+@mlx_funcify_Elemwise_scalar_op.register(AND)
+def mlx_funcify_Elemwise_scalar_AND(scalar_op):
+    return mx.bitwise_and
+
+
+@mlx_funcify_Elemwise_scalar_op.register(OR)
+def mlx_funcify_Elemwise_scalar_OR(scalar_op):
+    return mx.bitwise_or
+
+
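Switch maps to mx.where, which broadcasts the condition against both branches, matching PyTensor's elementwise switch semantics. For example:

    import mlx.core as mx

    cond = mx.array([True, False, True])
    print(mx.where(cond, mx.array(1.0), mx.array(0.0)))  # [1, 0, 1]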
+@mlx_funcify_Elemwise_scalar_op.register(ScalarMaximum)
+def mlx_funcify_Elemwise_scalar_ScalarMaximum(scalar_op):
+    return mx.maximum
+
+
+@mlx_funcify_Elemwise_scalar_op.register(ScalarMinimum)
+def mlx_funcify_Elemwise_scalar_ScalarMinimum(scalar_op):
+    return mx.minimum
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Cast)
+def mlx_funcify_Elemwise_scalar_Cast(scalar_op):
+    def cast(x):
+        dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype)
+        try:
+            return x.astype(dtype)
+        except ValueError as e:
+            if "is not supported on the GPU" in str(e):
+                import warnings
+
+                warnings.warn(
+                    f"MLX GPU limitation: {e}. Attempting automatic fallback casting.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+                fallback_dtype = convert_dtype_to_mlx(
+                    scalar_op.o_type.dtype, auto_cast_unsupported=True
+                )
+                return x.astype(fallback_dtype)
+            else:
+                raise e
+
+    return cast
+
+
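The Cast fallback exists because MLX does not support some dtypes on the GPU (notably float64). When astype raises, the handler warns and retries with whatever dtype convert_dtype_to_mlx selects under auto_cast_unsupported=True. The same pattern in isolation (a sketch; behavior depends on the device):

    import mlx.core as mx

    x = mx.array([1.0, 2.0])
    try:
        y = x.astype(mx.float64)  # may raise ValueError on GPU-only streams
    except ValueError:
        y = x.astype(mx.float32)  # fall back to a supported dtype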
+@mlx_funcify_Elemwise_scalar_op.register(Sigmoid)
+def mlx_funcify_Elemwise_scalar_Sigmoid(scalar_op):
+    return mx.sigmoid
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Invert)
+def mlx_funcify_Elemwise_scalar_Invert(scalar_op):
+    return mx.bitwise_invert
+
+
+@mlx_funcify_Elemwise_scalar_op.register(IsNan)
+def mlx_funcify_Elemwise_scalar_IsNan(scalar_op):
+    return mx.isnan
+
+
+@mlx_funcify_Elemwise_scalar_op.register(IsInf)
+def mlx_funcify_Elemwise_scalar_IsInf(scalar_op):
+    return mx.isinf
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Erfc)
+def mlx_funcify_Elemwise_scalar_Erfc(scalar_op):
+    def erfc(x):
+        return 1.0 - mx.erf(x)
+
+    return erfc
+
+
+@mlx_funcify_Elemwise_scalar_op.register(Erfcx)
+def mlx_funcify_Elemwise_scalar_Erfcx(scalar_op):
+    def erfcx(x):
+        return mx.exp(x * x) * (1.0 - mx.erf(x))
+
+    return erfcx
+
+
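mlx.core exposes erf but not erfc or erfcx, hence the identities erfc(x) = 1 - erf(x) and erfcx(x) = exp(x^2) * erfc(x) above; note the 1 - erf(x) form loses precision for large positive x. A spot check against SciPy (assumes SciPy is installed):

    import mlx.core as mx
    import numpy as np
    from scipy import special

    x = np.linspace(-2.0, 2.0, 9, dtype=np.float32)
    xm = mx.array(x)
    assert np.allclose(np.array(1.0 - mx.erf(xm)), special.erfc(x), atol=1e-6)
    assert np.allclose(
        np.array(mx.exp(xm * xm) * (1.0 - mx.erf(xm))), special.erfcx(x), rtol=1e-3
    )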
+@mlx_funcify_Elemwise_scalar_op.register(Softplus)
+def mlx_funcify_Elemwise_scalar_softplus(scalar_op):
+    def softplus(x):
+        # Numerically stable implementation of log(1 + exp(x))
+        # Following the same logic as the original PyTensor implementation
+        return mx.where(
+            x < -37.0,
+            mx.exp(x),
+            mx.where(
+                x < 18.0, mx.log1p(mx.exp(x)), mx.where(x < 33.3, x + mx.exp(-x), x)
+            ),
+        )
+
+    return softplus
+
+
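The thresholds (-37, 18, 33.3) follow PyTensor's reference softplus: exp(x) for very negative x, log1p(exp(x)) in the middle range, then x + exp(-x) and finally x itself as x grows, so exp never overflows. A quick numerical check against NumPy's stably computed log(1 + exp(x)) (run alongside the module; scalar_op is unused by this helper):

    import mlx.core as mx
    import numpy as np

    softplus = mlx_funcify_Elemwise_scalar_softplus(scalar_op=None)
    x = np.linspace(-50.0, 50.0, 101, dtype=np.float32)
    ref = np.logaddexp(0.0, x)  # log(1 + exp(x)) without overflow
    assert np.allclose(np.array(softplus(mx.array(x))), ref, rtol=1e-5, atol=1e-6)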
+@mlx_funcify.register(Elemwise)
+def mlx_funcify_Elemwise(op, node, **kwargs):
+    return mlx_funcify_Elemwise_scalar_op(op.scalar_op)
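With this final registration every Elemwise node is funcified through its scalar op, either via an explicit registration above or via the name-based fallback. An end-to-end sketch, assuming the MLX linker is exposed as a Mode named "MLX" (the mode name is an assumption, not shown in this diff):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    f = pytensor.function([x], pt.exp(x) + 1.0, mode="MLX")  # assumed mode name
    print(f([0.0, 1.0]))  # expected: [2.0, 1 + e]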
