Skip to content

Commit 9f33f9e

Browse files
float32 elemwise
1 parent 14d2dcd commit 9f33f9e

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

tests/tensor/test_elemwise.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111
import pytensor.scalar as ps
1212
import pytensor.tensor as pt
1313
import tests.unittest_tools as utt
14-
from pytensor import In, Out, grad
14+
from pytensor import In, Out, config, grad
1515
from pytensor.compile.function import function
1616
from pytensor.compile.mode import Mode
17-
from pytensor.configdefaults import config
1817
from pytensor.graph.basic import Apply, Variable
1918
from pytensor.graph.fg import FunctionGraph
2019
from pytensor.graph.replace import vectorize_node
@@ -1088,7 +1087,7 @@ def L_op(self, inputs, outputs, output_gradients):
10881087
x = vector("x")
10891088
y, _ = op(x)
10901089
np.testing.assert_array_equal(
1091-
grad(y.sum(), x).eval({x: np.full((12,), np.nan)}),
1092-
np.ones((12,)),
1090+
grad(y.sum(), x).eval({x: np.full((12,), np.nan, dtype=config.floatX)}),
1091+
np.ones((12,), dtype=config.floatX),
10931092
strict=True,
10941093
)

0 commit comments

Comments
 (0)