Skip to content

Commit acd22f3

Browse files
committed
add test to extract function
1 parent b2a3d1e commit acd22f3

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tests/test_pytensorf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ def test_minibatch_variable(self):
199199
assert isinstance(res, np.ndarray)
200200
np.testing.assert_array_equal(res, y)
201201

202+
def test_pytensor_operations(self):
203+
x = np.array([1, 2, 3])
204+
target = 1 + 3 * pt.as_tensor_variable(x)
205+
206+
res = extract_obs_data(target)
207+
assert isinstance(res, np.ndarray)
208+
np.testing.assert_array_equal(res, np.array([4, 7, 10]))
209+
202210

203211
@pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
204212
def test_convert_data(input_dtype):

0 commit comments

Comments
 (0)