Skip to content

Commit c2ede26

Browse files
committed
Simplify python implementation of ScalarFromTensor
1 parent 47a15c6 commit c2ede26

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

pytensor/tensor/basic.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,10 +678,9 @@ def make_node(self, t):
678678
self, [t], [ps.get_scalar_type(dtype=t.type.dtype).make_variable()]
679679
)
680680

681-
def perform(self, node, inp, out_):
682-
(s,) = inp
683-
(out,) = out_
684-
out[0] = s.flatten()[0]
681+
def perform(self, node, inputs, output_storage):
682+
# not using .item() because that returns a Python scalar, not a numpy scalar
683+
output_storage[0][0] = inputs[0][()]
685684

686685
def infer_shape(self, fgraph, node, in_shapes):
687686
return [()]

0 commit comments

Comments
 (0)