File tree Expand file tree Collapse file tree 4 files changed +13
-29
lines changed
pytensor/link/mlx/dispatch Expand file tree Collapse file tree 4 files changed +13
-29
lines changed Original file line number Diff line number Diff line change 2
2
"cells" : [
3
3
{
4
4
"cell_type" : " code" ,
5
- "execution_count" : 15 ,
5
+ "execution_count" : 1 ,
6
6
"metadata" : {},
7
7
"outputs" : [],
8
8
"source" : [
21
21
},
22
22
{
23
23
"cell_type" : " code" ,
24
- "execution_count" : 23 ,
24
+ "execution_count" : 2 ,
25
25
"metadata" : {},
26
26
"outputs" : [],
27
27
"source" : [
296
296
},
297
297
{
298
298
"cell_type" : " code" ,
299
- "execution_count" : 24 ,
299
+ "execution_count" : null ,
300
300
"metadata" : {},
301
301
"outputs" : [
302
302
{
306
306
" Verifying computational correctness...\n " ,
307
307
" Max difference between JAX and MLX results: 0.00e+00\n " ,
308
308
" ✅ Results match within tolerance\n " ,
309
- " Running benchmarks with N=1000 repetitions per test...\n " ,
310
- " Testing 128x128 matrices...\n " ,
311
- " Testing 256x256 matrices...\n " ,
312
- " Testing 512x512 matrices...\n " ,
313
- " Testing 1024x1024 matrices...\n "
309
+ " Running benchmarks with N=20 repetitions per test...\n " ,
310
+ " Testing 128x128 matrices...\n "
314
311
]
315
312
}
316
313
],
317
314
"source" : [
318
- " _, results = main()"
315
+ " iteration=20\n " ,
316
+ " _, results = main(N=iteration)"
319
317
]
320
318
},
321
319
{
346
344
}
347
345
],
348
346
"source" : [
349
- " print(\"\\ nBenchmark Results over 1000 repetitions:\" )\n " ,
347
+ " print(f \"\\ nBenchmark Results over {iteration} repetitions:\" )\n " ,
350
348
" print(results.to_string(index=False))"
351
349
]
352
350
},
Original file line number Diff line number Diff line change @@ -212,17 +212,8 @@ def tensor_from_scalar(x):
212
212
@mlx_funcify .register (ScalarFromTensor )
213
213
def mlx_funcify_ScalarFromTensor (op , ** kwargs ):
214
214
def scalar_from_tensor (x ):
215
- arr = mx .array (x )
216
- try :
217
- # Try .item() first (cleaner and faster when possible)
218
- return arr .item ()
219
- except ValueError as e :
220
- if "eval" in str (e ):
221
- # Fall back to reshape approach for compiled contexts
222
- return arr .reshape (- 1 )[0 ]
223
- else :
224
- # Re-raise if it's a different error
225
- raise
215
+ "We can't not return a scalar in MLX without trigger evaluation"
216
+ return x
226
217
227
218
return scalar_from_tensor
228
219
Original file line number Diff line number Diff line change 7
7
@mlx_funcify .register (Shape )
8
8
def mlx_funcify_Shape (op , ** kwargs ):
9
9
def shape (x ):
10
- return x .shape
10
+ return mx . array ( x .shape , dtype = mx . int64 )
11
11
12
12
return shape
13
13
Original file line number Diff line number Diff line change 1
1
"""
2
2
Basic tests for the MLX backend.
3
3
"""
4
+ import pytest
4
5
5
6
from collections .abc import Callable , Iterable
6
7
from functools import partial
@@ -224,12 +225,7 @@ def test_alloc_pytensor_integration():
224
225
x = pt .scalar ("x" , dtype = "float32" )
225
226
result = pt .alloc (x , 3 , 4 )
226
227
227
- # Use MLX mode
228
- from pytensor .compile import mode
229
-
230
- mlx_mode = mode .get_mode ("MLX" )
231
-
232
- f = pytensor .function ([x ], result , mode = mlx_mode )
228
+ f = pytensor .function ([x ], result , mode = "MLX" )
233
229
output = f (5.0 )
234
230
235
231
assert output .shape == (3 , 4 )
@@ -238,7 +234,6 @@ def test_alloc_pytensor_integration():
238
234
239
235
def test_alloc_compilation_limitation ():
240
236
"""Test that Alloc operations with dynamic shapes provide helpful error in compiled contexts."""
241
- import pytest
242
237
243
238
# Create variables
244
239
x = pt .scalar ("x" , dtype = "float32" )
You can’t perform that action at this time.
0 commit comments