Skip to content

Commit 13a700a

Browse files
cetagostinijessegrabowski
authored andcommitted
Optimizing reshape
1 parent 9527f6c commit 13a700a

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pytensor/link/mlx/dispatch/core.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,17 @@ def tensor_from_scalar(x):
212212
@mlx_funcify.register(ScalarFromTensor)
213213
def mlx_funcify_ScalarFromTensor(op, **kwargs):
214214
def scalar_from_tensor(x):
215-
return mx.array(x).reshape(-1)[0]
215+
arr = mx.array(x)
216+
try:
217+
# Try .item() first (cleaner and faster when possible)
218+
return arr.item()
219+
except ValueError as e:
220+
if "eval" in str(e):
221+
# Fall back to reshape approach for compiled contexts
222+
return arr.reshape(-1)[0]
223+
else:
224+
# Re-raise if it's a different error
225+
raise
216226

217227
return scalar_from_tensor
218228

tests/link/mlx/test_basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Basic tests for the MLX backend.
33
"""
4+
45
from collections.abc import Callable, Iterable
56
from functools import partial
67

0 commit comments

Comments
 (0)