Skip to content

Commit 2421a6f

Browse files
Handle dynamic shapes to AllocEmpty in non-compiled mode
1 parent 433a2cb commit 2421a6f

File tree

2 files changed

+53
-17
lines changed

2 files changed

+53
-17
lines changed

pytensor/link/mlx/dispatch/core.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,18 +247,29 @@ def tri(*args):
247247

248248

249249
@mlx_funcify.register(AllocEmpty)
250-
def mlx_funcify_AllocEmpty(op, **kwargs):
250+
def mlx_funcify_AllocEmpty(op, node, **kwargs):
251251
dtype = convert_dtype_to_mlx(op.dtype)
252+
node_inputs = node.inputs
253+
static_dims = (
254+
_extract_static_dims(node_inputs)
255+
if node_inputs and len(node_inputs) > 1
256+
else None
257+
)
252258

253259
def allocempty(*shape):
254-
return mx.zeros(shape, dtype=dtype)
260+
resolved_shape = (
261+
_resolve_shape(static_dims, shape)
262+
if static_dims is not None
263+
else tuple(_coerce_to_int(dim) for dim in shape)
264+
)
265+
return mx.zeros(resolved_shape, dtype=dtype)
255266

256267
return allocempty
257268

258269

259270
@mlx_funcify.register(Alloc)
260271
def mlx_funcify_Alloc(op, node, **kwargs):
261-
node_inputs = getattr(node, "inputs", None)
272+
node_inputs = node.inputs
262273
static_dims = (
263274
_extract_static_dims(node_inputs[1:])
264275
if node_inputs and len(node_inputs) > 1

tests/link/mlx/test_core.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,18 @@ def test_alloc_compilation_limitation():
7373
# Test that it works with concrete values (non-compiled context)
7474
output = f(5.0, 3, 4)
7575
assert output.shape == (3, 4)
76-
assert np.allclose(output, 5.0)
76+
np.testing.assert_allclose(output, 5.0)
7777

7878
# Test that compilation fails with helpful error
7979
compiled_f = pytensor.function([x, s1, s2], result, mode=compile_mode)
8080

81-
with pytest.raises(ValueError) as exc_info:
81+
with pytest.raises(
82+
ValueError,
83+
match="MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
84+
"used inside compiled functions",
85+
):
8286
compiled_f(5.0, 3, 4)
8387

84-
error_msg = str(exc_info.value)
85-
assert "MLX compilation limitation" in error_msg
86-
assert "Alloc operations with dynamic shapes" in error_msg
87-
assert "cannot be used inside compiled functions" in error_msg
88-
assert "Workarounds:" in error_msg
89-
assert "Avoid using Alloc with dynamic shapes in compiled contexts" in error_msg
90-
assert "Use static shapes when possible" in error_msg
91-
assert "Move Alloc operations outside compiled functions" in error_msg
92-
9388

9489
def test_alloc_static_shapes_compilation():
9590
"""Test that Alloc operations with static shapes work fine in compiled contexts."""
@@ -109,6 +104,36 @@ def test_alloc_static_shapes_compilation():
109104

110105
assert output_normal.shape == (3, 4)
111106
assert output_compiled.shape == (3, 4)
112-
assert np.allclose(output_normal, 5.0)
113-
assert np.allclose(output_compiled, 5.0)
114-
assert np.allclose(output_normal, output_compiled)
107+
np.testing.assert_allclose(output_normal, 5.0)
108+
np.testing.assert_allclose(output_compiled, 5.0)
109+
np.testing.assert_allclose(output_normal, output_compiled)
110+
111+
112+
def test_empty_static_shape():
113+
result = pt.empty((3, 4), dtype="float32")
114+
115+
f = pytensor.function([], result, mode="MLX")
116+
output = f()
117+
118+
assert output.shape == (3, 4)
119+
np.testing.assert_allclose(output, 0.0)
120+
121+
122+
def test_empty_dynamic_shape():
123+
s1 = pt.scalar("s1", dtype="int64")
124+
s2 = pt.scalar("s2", dtype="int64")
125+
result = pt.empty((s1, s2), dtype="float32")
126+
127+
f = pytensor.function([s1, s2], result, mode=mlx_mode_no_compile)
128+
output = f(3, 4)
129+
130+
assert output.shape == (3, 4)
131+
np.testing.assert_allclose(output, 0.0)
132+
133+
f_compiled = pytensor.function([s1, s2], result, mode=compile_mode)
134+
with pytest.raises(
135+
ValueError,
136+
match="MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
137+
"used inside compiled functions",
138+
):
139+
f_compiled(3, 4)

0 commit comments

Comments
 (0)