Skip to content

Commit 13cb73d

Browse files
committed
fix: cast to elemwise outputs to their respective dtypes
1 parent 067ed32 commit 13cb73d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytensor/scalar/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,12 +1142,12 @@ def output_types(self, types):
11421142

11431143
def perform(self, node, inputs, output_storage):
11441144
if self.nout == 1:
1145-
output_storage[0][0] = self.impl(*inputs)
1145+
output_storage[0][0] = self.impl(*inputs).astype(node.outputs[0].dtype, copy=False)
11461146
else:
11471147
variables = from_return_values(self.impl(*inputs))
11481148
assert len(variables) == len(output_storage)
1149-
for storage, variable in zip(output_storage, variables):
1150-
storage[0] = variable
1149+
for out, storage, variable in zip(node.outputs, output_storage, variables):
1150+
storage[0] = variable.astype(out.dtype, copy=False)
11511151

11521152
def impl(self, *inputs):
11531153
raise MethodNotDefined("impl", type(self), self.__class__.__name__)

0 commit comments

Comments
 (0)