Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,9 +1668,11 @@ def transform(self, fgraph, node, get_nodes=True):
):
return False
else:
out_dtype = node.outputs[0].type.dtype
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
ret = pytensor.tensor.basic.cast(ret, out_dtype)
if self.allow_cast:
out_dtype = getattr(node.outputs[0].type, "dtype", None)
ret_dtype = getattr(ret.owner.outputs[0].type, "dtype", None)
if ret_dtype != out_dtype:
ret = pytensor.tensor.basic.cast(ret, out_dtype)
if not node.outputs[0].type.is_super(ret.owner.outputs[0].type):
return False
else:
Expand Down