Commit 12daeac

A lot of new code

1 parent 5abd32d commit 12daeac

File tree

pytensor/link/mlx/dispatch/__init__.py
pytensor/link/mlx/dispatch/blockwise.py
pytensor/link/mlx/dispatch/elemwise.py
pytensor/link/mlx/dispatch/math.py
pytensor/link/mlx/dispatch/signal/conv.py

5 files changed: +157 −13 lines changed

pytensor/link/mlx/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -9,4 +9,5 @@
 import pytensor.link.mlx.dispatch.core
 import pytensor.link.mlx.dispatch.signal
 import pytensor.link.mlx.dispatch.signal.conv
+import pytensor.link.mlx.dispatch.blockwise
 # isort: on
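This import is load-bearing: mlx_funcify appears to be a singledispatch-style registry, so merely importing pytensor.link.mlx.dispatch.blockwise runs its @mlx_funcify.register(Blockwise) decorator and makes the handler available. A minimal sketch of the pattern, using local stand-in names (funcify and Blockwise below are not PyTensor's actual objects):

import functools

@functools.singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No MLX conversion for {op}")

class Blockwise:  # stand-in for pytensor.tensor.blockwise.Blockwise
    pass

# Importing the module containing this decorator is enough to register
# the handler; no call site has to change.
@funcify.register(Blockwise)
def funcify_blockwise(op, **kwargs):
    return lambda *inputs: inputs

print(funcify(Blockwise()))  # dispatches to funcify_blockwise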
pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 59 additions & 9 deletions

@@ -1,16 +1,66 @@
 import mlx.core as mx
 
+from pytensor.graph import FunctionGraph
 from pytensor.link.mlx.dispatch import mlx_funcify
 from pytensor.tensor.blockwise import Blockwise
 
+
 @mlx_funcify.register(Blockwise)
 def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
-    core_f = mlx_funcify(op.core_op)
-    batched_f = core_f
-    for _ in range(op.batch_ndim(node)):
-        batched_f = mx.vmap(batched_f)
-
-    def wrapped_blockwise_f(*inputs):
-        return batched_f(*inputs)
-
-    return wrapped_blockwise_f
+    # Create a function graph for the core operation
+    core_node = op._create_dummy_core_node(node.inputs)
+    core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
+
+    # Convert the core function graph to an MLX function
+    tuple_core_fn = mlx_funcify(core_fgraph, **kwargs)
+
+    # If there's only one output, unwrap it from the tuple
+    if len(node.outputs) == 1:
+
+        def core_fn(*inputs):
+            return tuple_core_fn(*inputs)[0]
+    else:
+        core_fn = tuple_core_fn
+
+    # Apply vmap for each batch dimension
+    batch_ndims = op.batch_ndim(node)
+    vmap_fn = core_fn
+    for _ in range(batch_ndims):
+        vmap_fn = mx.vmap(vmap_fn)
+
+    def blockwise_fn(*inputs):
+        # Check for runtime broadcasting compatibility
+        op._check_runtime_broadcast(node, inputs)
+
+        # Handle broadcasting for batched dimensions
+        if batch_ndims > 0:
+            # Get batch shapes for broadcasting
+            batch_shapes = [inp.shape[:batch_ndims] for inp in inputs]
+
+            # Calculate the broadcasted batch shape
+            from functools import reduce
+
+            def broadcast_shapes(shape1, shape2):
+                return tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=True))
+
+            if batch_shapes:
+                broadcasted_shape = reduce(broadcast_shapes, batch_shapes)
+
+                # Broadcast inputs to the common batch shape
+                broadcasted_inputs = []
+                for inp in inputs:
+                    if inp.shape[:batch_ndims] != broadcasted_shape:
+                        # Create the full target shape
+                        target_shape = broadcasted_shape + inp.shape[batch_ndims:]
+                        # Broadcast the input
+                        broadcasted_inputs.append(mx.broadcast_to(inp, target_shape))
+                    else:
+                        broadcasted_inputs.append(inp)
+
+                # Apply the vectorized function to the broadcasted inputs
+                return vmap_fn(*broadcasted_inputs)
+
+        # No broadcasting needed
+        return vmap_fn(*inputs)
+
+    return blockwise_fn
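The explicit broadcast step exists because mx.vmap maps over axis 0 of every argument and expects matching sizes there; unlike NumPy-style broadcasting, it will not stretch a size-1 batch dimension on its own, which is exactly what blockwise_fn compensates for. A standalone sketch of the same trick with illustrative shapes (assumes an installed mlx):

import mlx.core as mx

def add(x, y):
    return x + y

# Vectorize over two leading batch dimensions, as funcify_Blockwise does
# when batch_ndims == 2.
batched_add = mx.vmap(mx.vmap(add))

a = mx.ones((1, 4, 3))  # batch shape (1, 4), core shape (3,)
b = mx.ones((5, 1, 3))  # batch shape (5, 1), core shape (3,)

# Broadcast both batch shapes to their common shape (5, 4) before vmapping,
# mirroring the mx.broadcast_to calls in blockwise_fn.
common = (5, 4)
a_b = mx.broadcast_to(a, common + a.shape[2:])
b_b = mx.broadcast_to(b, common + b.shape[2:])

print(batched_add(a_b, b_b).shape)  # (5, 4, 3)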

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 22 additions & 2 deletions

@@ -1,6 +1,7 @@
 import mlx.core as mx
 
 from pytensor.link.mlx.dispatch.basic import mlx_funcify
+from pytensor.scalar import Softplus
 from pytensor.scalar.basic import AND, OR, Add, Mul, ScalarMaximum, ScalarMinimum
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
 from pytensor.tensor.special import Softmax, SoftmaxGrad

@@ -59,9 +60,8 @@ def min(x):
             return mx.min(x, axis=op.axis)
 
         return min
-
     else:
-        raise NotImplementedError(f"MLX does not support {op.scalar_op}")
+        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
 
 
 @mlx_funcify.register(Softmax)

@@ -83,3 +83,23 @@ def softmax_grad(dy, sm):
         return dy_times_sm - mx.sum(dy_times_sm, axis=axis, keepdims=True) * sm
 
     return softmax_grad
+
+
+@mlx_funcify.register(Softplus)
+def mlx_funcify_Softplus(op, **kwargs):
+    def softplus(x):
+        return mx.where(
+            x < -37.0,
+            mx.exp(x),
+            mx.where(
+                x < 18.0,
+                mx.log1p(mx.exp(x)),
+                mx.where(
+                    x < 33.3,
+                    x + mx.exp(-x),
+                    x,
+                ),
+            ),
+        )
+
+    return softplus
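The branch thresholds follow the standard numerically stable softplus recipe: below −37, log1p(exp(x)) is indistinguishable from exp(x); between 18 and 33.3, it is approximated by x + exp(−x); above 33.3, that correction falls below float resolution and softplus(x) is just x. A quick sanity check of the same piecewise scheme against a stable reference, written in plain NumPy so it runs without MLX:

import numpy as np

def softplus_stable(x):
    # Same piecewise scheme as the MLX kernel above.
    return np.where(
        x < -37.0,
        np.exp(x),                    # softplus(x) ~ exp(x) for very negative x
        np.where(
            x < 18.0,
            np.log1p(np.exp(x)),      # direct formula is safe in the middle range
            np.where(
                x < 33.3,
                x + np.exp(-x),       # softplus(x) ~ x + exp(-x) for large x
                x,                    # correction underflows past ~33.3
            ),
        ),
    )

x = np.linspace(-50.0, 50.0, 1001)
reference = np.logaddexp(0.0, x)      # log(1 + exp(x)), computed stably
print(np.max(np.abs(softplus_stable(x) - reference)))  # ~0, float64 rounding only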

pytensor/link/mlx/dispatch/math.py

Lines changed: 74 additions & 1 deletion

@@ -8,21 +8,29 @@
     LE,
     LT,
     NEQ,
+    Abs,
     Add,
     Cos,
     Exp,
     Log,
     Mul,
     Pow,
     Sin,
+    Sqr,
+    Sqrt,
     Sub,
     Switch,
     TrueDiv,
+    Neg,
+    AND,
+    OR,
+    ScalarMaximum,
+    ScalarMinimum,
 )
 from pytensor.scalar.math import Sigmoid
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot
-
+from pytensor.scalar import Softplus
 
 @mlx_funcify.register(Dot)
 def mlx_funcify_Dot(op, **kwargs):

@@ -142,5 +150,70 @@ def true_div(x, y):
         return mx.divide(x, y)
 
     return true_div
+    elif isinstance(op.scalar_op, Sqr):
+
+        def sqr(x):
+            return mx.square(x)
+
+        return sqr
+    elif isinstance(op.scalar_op, Sqrt):
+
+        def sqrt(x):
+            return mx.sqrt(x)
+
+        return sqrt
+    elif isinstance(op.scalar_op, Abs):
+
+        def abs(x):
+            return mx.abs(x)
+
+        return abs
+    elif isinstance(op.scalar_op, Softplus):
+        def softplus(x):
+            return mx.where(
+                x < -37.0,
+                mx.exp(x),
+                mx.where(
+                    x < 18.0,
+                    mx.log1p(mx.exp(x)),
+                    mx.where(
+                        x < 33.3,
+                        x + mx.exp(-x),
+                        x,
+                    ),
+                ),
+            )
+
+        return softplus
+    elif isinstance(op.scalar_op, Neg):
+
+        def neg(x):
+            return mx.negative(x)
+
+        return neg
+    elif isinstance(op.scalar_op, AND):
+
+        def all(x):
+            return mx.all(x, axis=op.axis)
+
+        return all
+    elif isinstance(op.scalar_op, OR):
+
+        def any(x):
+            return mx.any(x, axis=op.axis)
+
+        return any
+    elif isinstance(op.scalar_op, ScalarMaximum):
+
+        def max(x):
+            return mx.max(x, axis=op.axis)
+
+        return max
+    elif isinstance(op.scalar_op, ScalarMinimum):
+
+        def min(x):
+            return mx.min(x, axis=op.axis)
+
+        return min
     else:
         raise NotImplementedError(f"MLX does not support {op.scalar_op}")
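With Sqr, Sqrt, Abs, Neg, and Softplus dispatched, simple elementwise graphs can round-trip through the MLX backend. A hedged usage sketch, assuming this branch registers the linker under mode="MLX" (the mode name is an assumption; the diff only shows the dispatch functions):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.sqrt(pt.sqr(x)) + pt.softplus(x)  # exercises Sqr, Sqrt, and Softplus

# mode="MLX" is assumed here, by analogy with the JAX and Numba backends.
f = pytensor.function([x], y, mode="MLX")
print(f(np.array([-1.0, 0.0, 2.0], dtype="float32")))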

pytensor/link/mlx/dispatch/signal/conv.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 
 
 @mlx_funcify.register(Conv1d)
-def mlx_funcify_Conv1d(op, node, **kwargs):
+def mlx_funcify_Conv1d(op, node=None, **kwargs):
     mode = op.mode
 
     def conv1d(data, kernel):
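The node=None default matters because dispatch helpers are sometimes called with only the Op: the pre-change Blockwise code did exactly that with mlx_funcify(op.core_op), which would raise TypeError against a signature requiring node. A minimal reproduction of that failure mode, using stand-in functions rather than PyTensor's real dispatchers:

def old_style(op, node, **kwargs):       # node is required
    return lambda *args: (op, node)

def new_style(op, node=None, **kwargs):  # node is optional
    return lambda *args: (op, node)

class FakeOp:  # stand-in for a PyTensor Op
    pass

op = FakeOp()
new_style(op)      # fine: nested dispatchers can omit node
try:
    old_style(op)  # TypeError: missing required positional argument 'node'
except TypeError as e:
    print(e)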
