We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 47a15c6 commit c2ede26Copy full SHA for c2ede26
pytensor/tensor/basic.py
@@ -678,10 +678,9 @@ def make_node(self, t):
678
self, [t], [ps.get_scalar_type(dtype=t.type.dtype).make_variable()]
679
)
680
681
- def perform(self, node, inp, out_):
682
- (s,) = inp
683
- (out,) = out_
684
- out[0] = s.flatten()[0]
+ def perform(self, node, inputs, output_storage):
+ # not using .item() because that returns a Python scalar, not a numpy scalar
+ output_storage[0][0] = inputs[0][()]
685
686
def infer_shape(self, fgraph, node, in_shapes):
687
return [()]
0 commit comments