Skip to content

Commit 481e3ad

Browse files
cetagostinijessegrabowski
authored andcommitted
Correcting test by Ricardo
1 parent cfcb910 commit 481e3ad

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/link/mlx/test_basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,15 @@ def test_scalar_from_tensor_with_scalars():
123123

124124
def test_scalar_from_tensor_pytensor_integration():
125125
"""Test ScalarFromTensor in a PyTensor graph context."""
126-
# Create a 0-d tensor (scalar tensor)
127-
x = pt.as_tensor_variable(42)
126+
# Create a symbolic scalar input to actually test MLX execution
127+
x = pt.scalar("x", dtype="int64")
128128

129129
# Apply ScalarFromTensor
130130
scalar_result = pt.scalar_from_tensor(x)
131131

132132
# Create function and test
133-
f = pytensor.function([], scalar_result, mode="MLX")
134-
result = f()
133+
f = pytensor.function([x], scalar_result, mode="MLX")
134+
result = f(42)
135135

136136
assert result == 42
137137

0 commit comments

Comments
 (0)