Skip to content

Commit a43f1cf

Browse files
cetagostini and jessegrabowski
authored and committed
Changes with Ricardo
1 parent 70734c9 commit a43f1cf

File tree

4 files changed

+13
-29
lines changed

4 files changed

+13
-29
lines changed

doc/_drafts/benchmark_mlx_v_jax_corrected.ipynb

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 15,
5+
"execution_count": 1,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -21,7 +21,7 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": 23,
24+
"execution_count": 2,
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
@@ -296,7 +296,7 @@
296296
},
297297
{
298298
"cell_type": "code",
299-
"execution_count": 24,
299+
"execution_count": null,
300300
"metadata": {},
301301
"outputs": [
302302
{
@@ -306,16 +306,14 @@
306306
"Verifying computational correctness...\n",
307307
"Max difference between JAX and MLX results: 0.00e+00\n",
308308
"✅ Results match within tolerance\n",
309-
"Running benchmarks with N=1000 repetitions per test...\n",
310-
"Testing 128x128 matrices...\n",
311-
"Testing 256x256 matrices...\n",
312-
"Testing 512x512 matrices...\n",
313-
"Testing 1024x1024 matrices...\n"
309+
"Running benchmarks with N=20 repetitions per test...\n",
310+
"Testing 128x128 matrices...\n"
314311
]
315312
}
316313
],
317314
"source": [
318-
"_, results = main()"
315+
"iteration=20\n",
316+
"_, results = main(N=iteration)"
319317
]
320318
},
321319
{
@@ -346,7 +344,7 @@
346344
}
347345
],
348346
"source": [
349-
"print(\"\\nBenchmark Results over 1000 repetitions:\")\n",
347+
"print(f\"\\nBenchmark Results over {iteration} repetitions:\")\n",
350348
"print(results.to_string(index=False))"
351349
]
352350
},

pytensor/link/mlx/dispatch/core.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,17 +212,8 @@ def tensor_from_scalar(x):
212212
@mlx_funcify.register(ScalarFromTensor)
213213
def mlx_funcify_ScalarFromTensor(op, **kwargs):
214214
def scalar_from_tensor(x):
215-
arr = mx.array(x)
216-
try:
217-
# Try .item() first (cleaner and faster when possible)
218-
return arr.item()
219-
except ValueError as e:
220-
if "eval" in str(e):
221-
# Fall back to reshape approach for compiled contexts
222-
return arr.reshape(-1)[0]
223-
else:
224-
# Re-raise if it's a different error
225-
raise
215+
"We can't not return a scalar in MLX without trigger evaluation"
216+
return x
226217

227218
return scalar_from_tensor
228219

pytensor/link/mlx/dispatch/shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
@mlx_funcify.register(Shape)
88
def mlx_funcify_Shape(op, **kwargs):
99
def shape(x):
10-
return x.shape
10+
return mx.array(x.shape, dtype=mx.int64)
1111

1212
return shape
1313

tests/link/mlx/test_basic.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Basic tests for the MLX backend.
33
"""
4+
import pytest
45

56
from collections.abc import Callable, Iterable
67
from functools import partial
@@ -224,12 +225,7 @@ def test_alloc_pytensor_integration():
224225
x = pt.scalar("x", dtype="float32")
225226
result = pt.alloc(x, 3, 4)
226227

227-
# Use MLX mode
228-
from pytensor.compile import mode
229-
230-
mlx_mode = mode.get_mode("MLX")
231-
232-
f = pytensor.function([x], result, mode=mlx_mode)
228+
f = pytensor.function([x], result, mode="MLX")
233229
output = f(5.0)
234230

235231
assert output.shape == (3, 4)
@@ -238,7 +234,6 @@ def test_alloc_pytensor_integration():
238234

239235
def test_alloc_compilation_limitation():
240236
"""Test that Alloc operations with dynamic shapes provide helpful error in compiled contexts."""
241-
import pytest
242237

243238
# Create variables
244239
x = pt.scalar("x", dtype="float32")

0 commit comments

Comments (0)