
Commit ac93949

almost there baby william

1 parent 12daeac

5 files changed: +66 −83 lines changed


pytensor/link/mlx/dispatch/basic.py

Lines changed: 16 additions & 0 deletions
@@ -1,3 +1,4 @@
+import warnings
 from functools import singledispatch
 from types import NoneType
 
@@ -7,6 +8,7 @@
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python
+from pytensor.raise_op import Assert, CheckAndRaise
 
 
 @singledispatch
@@ -59,3 +61,17 @@ def deepcopyop(x):
         return x.copy()
 
     return deepcopyop
+
+
+@mlx_funcify.register(Assert)
+@mlx_funcify.register(CheckAndRaise)
+def mlx_funcify_CheckAndRaise(op, **kwargs):
+    warnings.warn(
+        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
+        stacklevel=2,
+    )
+
+    def assert_fn(x, *inputs):
+        return x
+
+    return assert_fn
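
The new dispatch warns once, at funcify time, and then compiles the check away to an identity function. A minimal sketch of that behavior, with a hypothetical DummyOp standing in for a real CheckAndRaise instance:

import warnings


class DummyOp:  # hypothetical stand-in for a CheckAndRaise instance
    msg = "x must be positive"


def mlx_funcify_CheckAndRaise(op, **kwargs):
    warnings.warn(
        f"Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x  # the condition inputs are ignored entirely

    return assert_fn


fn = mlx_funcify_CheckAndRaise(DummyOp())  # the warning fires here, once
assert fn(42, False) == 42  # even a failed condition passes the value through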
pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 7 additions & 55 deletions

@@ -1,66 +1,18 @@
 import mlx.core as mx
 
-from pytensor.graph import FunctionGraph
 from pytensor.link.mlx.dispatch import mlx_funcify
 from pytensor.tensor.blockwise import Blockwise
 
 
 @mlx_funcify.register(Blockwise)
 def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
-    # Create a function graph for the core operation
     core_node = op._create_dummy_core_node(node.inputs)
-    core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
+    core_f = mlx_funcify(op.core_op, core_node)
+    blockwise_f = core_f
+    for i in range(op.batch_ndim(node)):
+        blockwise_f = mx.vmap(blockwise_f)
 
-    # Convert the core function graph to an MLX function
-    tuple_core_fn = mlx_funcify(core_fgraph, **kwargs)
+    def blockwise_fun(*inputs):
+        return blockwise_f(*inputs)
 
-    # If there's only one output, unwrap it from the tuple
-    if len(node.outputs) == 1:
-
-        def core_fn(*inputs):
-            return tuple_core_fn(*inputs)[0]
-    else:
-        core_fn = tuple_core_fn
-
-    # Apply vmap for each batch dimension
-    batch_ndims = op.batch_ndim(node)
-    vmap_fn = core_fn
-    for _ in range(batch_ndims):
-        vmap_fn = mx.vmap(vmap_fn)
-
-    def blockwise_fn(*inputs):
-        # Check for runtime broadcasting compatibility
-        op._check_runtime_broadcast(node, inputs)
-
-        # Handle broadcasting for batched dimensions
-        if batch_ndims > 0:
-            # Get batch shapes for broadcasting
-            batch_shapes = [inp.shape[:batch_ndims] for inp in inputs]
-
-            # Calculate the broadcasted batch shape
-            from functools import reduce
-
-            def broadcast_shapes(shape1, shape2):
-                return tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2, strict=True))
-
-            if batch_shapes:
-                broadcasted_shape = reduce(broadcast_shapes, batch_shapes)
-
-                # Broadcast inputs to the common batch shape
-                broadcasted_inputs = []
-                for inp in inputs:
-                    if inp.shape[:batch_ndims] != broadcasted_shape:
-                        # Create the full target shape
-                        target_shape = broadcasted_shape + inp.shape[batch_ndims:]
-                        # Broadcast the input
-                        broadcasted_inputs.append(mx.broadcast_to(inp, target_shape))
-                    else:
-                        broadcasted_inputs.append(inp)
-
-                # Apply the vectorized function to the broadcasted inputs
-                return vmap_fn(*broadcasted_inputs)
-
-        # No broadcasting needed
-        return vmap_fn(*inputs)
-
-    return blockwise_fn
+    return blockwise_fun
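
The rewrite replaces the hand-rolled broadcasting logic with nested mx.vmap calls: the core function is wrapped once per batch dimension. A small illustration of that stacking idea, with made-up shapes that are not taken from the commit:

import mlx.core as mx


def core_matvec(A, x):
    # core signature: (m, n) @ (n,) -> (m,)
    return A @ x


# one mx.vmap per batch dimension, mirroring the loop over op.batch_ndim(node)
batched = mx.vmap(core_matvec)

A = mx.ones((4, 3, 2))  # a batch of four (3, 2) matrices
x = mx.ones((4, 2))     # a batch of four length-2 vectors
print(batched(A, x).shape)  # (4, 3)

Note that unlike the deleted code, this path no longer calls op._check_runtime_broadcast or pre-broadcasts mismatched batch shapes, so inputs presumably need to arrive with matching batch extents.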

pytensor/link/mlx/dispatch/math.py

Lines changed: 30 additions & 7 deletions
@@ -1,36 +1,40 @@
 import mlx.core as mx
 
 from pytensor.link.mlx.dispatch import mlx_funcify
+from pytensor.scalar import Softplus
 from pytensor.scalar.basic import (
+    AND,
     EQ,
     GE,
     GT,
     LE,
     LT,
     NEQ,
+    OR,
     Abs,
     Add,
+    Cast,
     Cos,
     Exp,
     Log,
     Mul,
+    Neg,
     Pow,
+    ScalarMaximum,
+    ScalarMinimum,
+    Sign,
     Sin,
     Sqr,
     Sqrt,
     Sub,
     Switch,
     TrueDiv,
-    Neg,
-    AND,
-    OR,
-    ScalarMaximum,
-    ScalarMinimum,
+    Log1p,
 )
 from pytensor.scalar.math import Sigmoid
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot
-from pytensor.scalar import Softplus
+
 
 @mlx_funcify.register(Dot)
 def mlx_funcify_Dot(op, **kwargs):
@@ -169,6 +173,7 @@ def abs(x):
 
         return abs
     elif isinstance(op.scalar_op, Softplus):
+
         def softplus(x):
             return mx.where(
                 x < -37.0,
@@ -194,7 +199,7 @@ def neg(x):
     elif isinstance(op.scalar_op, AND):
 
         def all(x):
-            return mx.all(x, axis=op.axis)
+            return mx.all(x)
 
         return all
    elif isinstance(op.scalar_op, OR):
@@ -215,5 +220,23 @@ def min(x):
             return mx.min(x, axis=op.axis)
 
         return min
+    elif isinstance(op.scalar_op, Cast):
+
+        def cast(x):
+            return mx.cast(x, op.dtype)
+
+        return cast
+    elif isinstance(op.scalar_op, Sign):
+
+        def sign(x):
+            return mx.sign(x)
+
+        return sign
+    elif isinstance(op.scalar_op, Log1p):
+
+        def log1p(x):
+            return mx.log1p(x)
+
+        return log1p
     else:
         raise NotImplementedError(f"MLX does not support {op.scalar_op}")
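
One caveat on the new Cast branch: as far as I can tell, MLX exposes casting through array.astype rather than an mx.cast function, and the Elemwise op itself has no .dtype attribute. A hedged sketch of what a working variant might look like; the getattr-based dtype lookup is an assumption, not the commit's code:

import mlx.core as mx


def cast_like(x, dtype_name):
    # assumption: a dtype name like "float32" maps to mx.float32 via getattr
    return x.astype(getattr(mx, dtype_name))


print(cast_like(mx.array([1, 2, 3]), "float32").dtype)  # mlx.core.float32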

pytensor/link/mlx/dispatch/shape.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def specifyshape(x, *shape):
 
 @mlx_funcify.register(Shape_i)
 def mlx_funcify_Shape_i(op, node, **kwargs):
-    def shape_i(x, i):
+    def shape_i(x):
         return x.shape[op.i]
 
     return shape_i
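
The fix drops the unused i argument: the dimension index is baked into the op as op.i, so the closure only needs the array. A quick self-contained check of the tightened signature, with SimpleNamespace as a stand-in for a real Shape_i op:

from types import SimpleNamespace

import mlx.core as mx


def mlx_funcify_Shape_i(op, node=None, **kwargs):
    def shape_i(x):
        return x.shape[op.i]  # index comes from the op, not a call argument

    return shape_i


fn = mlx_funcify_Shape_i(SimpleNamespace(i=1))  # hypothetical op with i=1
print(fn(mx.zeros((3, 5))))  # 5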

pytensor/link/mlx/dispatch/subtensor.py

Lines changed: 12 additions & 20 deletions
@@ -11,40 +11,32 @@
 from pytensor.tensor.type_other import MakeSlice
 
 
-BOOLEAN_MASK_ERROR = """MLX does not support resizing arrays with boolean
-masks. In some cases, however, it is possible to re-express your model
-in a form that MLX can compile:
-
->>> import pytensor.tensor as pt
->>> x_pt = pt.vector('x')
->>> y_pt = x_pt[x_pt > 0].sum()
-
-can be re-expressed as:
+@mlx_funcify.register(Subtensor)
+def mlx_funcify_Subtensor(op, node, **kwargs):
+    idx_list = getattr(op, "idx_list", None)
 
->>> import pytensor.tensor as pt
->>> x_pt = pt.vector('x')
->>> y_pt = pt.where(x_pt > 0, x_pt, 0).sum()
-"""
+    def subtensor(x, *ilists):
+        indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
+        if len(indices) == 1:
+            indices = indices[0]
 
-DYNAMIC_SLICE_LENGTH_ERROR = """MLX does not support slicing arrays with a dynamic
-slice length.
-"""
+        return x.__getitem__(indices)
 
+    return subtensor
 
-@mlx_funcify.register(Subtensor)
 @mlx_funcify.register(AdvancedSubtensor)
 @mlx_funcify.register(AdvancedSubtensor1)
-def mlx_funcify_Subtensor(op, node, **kwargs):
+def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
     idx_list = getattr(op, "idx_list", None)
 
-    def subtensor(x, *ilists):
+    def advanced_subtensor(x, *ilists):
         indices = indices_from_subtensor(ilists, idx_list)
         if len(indices) == 1:
             indices = indices[0]
 
         return x.__getitem__(indices)
 
-    return subtensor
+    return advanced_subtensor
 
 
 @mlx_funcify.register(IncSubtensor)
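
Subtensor now gets its own funcify path that coerces the runtime index inputs to Python ints before calling indices_from_subtensor, while array-valued indices stay on the AdvancedSubtensor path. The split mirrors basic versus advanced indexing on plain MLX arrays, roughly:

import mlx.core as mx

x = mx.arange(10)

# Basic indexing (Subtensor): plain Python ints and slices, known statically
print(x[2:5])  # array([2, 3, 4], dtype=int32)

# Advanced indexing (AdvancedSubtensor): an integer array selects positions
idx = mx.array([0, 3, 7])
print(x[idx])  # array([0, 3, 7], dtype=int32)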
