File tree Expand file tree Collapse file tree 2 files changed +12
-1
lines changed
pytensor/link/mlx/dispatch Expand file tree Collapse file tree 2 files changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -212,7 +212,17 @@ def tensor_from_scalar(x):
212
212
@mlx_funcify .register (ScalarFromTensor )
213
213
def mlx_funcify_ScalarFromTensor (op , ** kwargs ):
214
214
def scalar_from_tensor (x ):
215
- return mx .array (x ).reshape (- 1 )[0 ]
215
+ arr = mx .array (x )
216
+ try :
217
+ # Try .item() first (cleaner and faster when possible)
218
+ return arr .item ()
219
+ except ValueError as e :
220
+ if "eval" in str (e ):
221
+ # Fall back to reshape approach for compiled contexts
222
+ return arr .reshape (- 1 )[0 ]
223
+ else :
224
+ # Re-raise if it's a different error
225
+ raise
216
226
217
227
return scalar_from_tensor
218
228
Original file line number Diff line number Diff line change 1
1
"""
2
2
Basic tests for the MLX backend.
3
3
"""
4
+
4
5
from collections .abc import Callable , Iterable
5
6
from functools import partial
6
7
You can’t perform that action at this time.
0 commit comments