We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cfcb910 commit 481e3adCopy full SHA for 481e3ad
tests/link/mlx/test_basic.py
@@ -123,15 +123,15 @@ def test_scalar_from_tensor_with_scalars():
123
124
def test_scalar_from_tensor_pytensor_integration():
125
"""Test ScalarFromTensor in a PyTensor graph context."""
126
- # Create a 0-d tensor (scalar tensor)
127
- x = pt.as_tensor_variable(42)
+ # Create a symbolic scalar input to actually test MLX execution
+ x = pt.scalar("x", dtype="int64")
128
129
# Apply ScalarFromTensor
130
scalar_result = pt.scalar_from_tensor(x)
131
132
# Create function and test
133
- f = pytensor.function([], scalar_result, mode="MLX")
134
- result = f()
+ f = pytensor.function([x], scalar_result, mode="MLX")
+ result = f(42)
135
136
assert result == 42
137
0 commit comments