Skip to content

Commit 227a468

Browse files
author
Luca Citi
committed
Changed PatternNodeRewriter::transform to allow types that do not contain dtype
like MyType in the tests
1 parent 6277546 commit 227a468

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,9 +1668,11 @@ def transform(self, fgraph, node, get_nodes=True):
16681668
):
16691669
return False
16701670
else:
1671-
out_dtype = node.outputs[0].type.dtype
1672-
if self.allow_cast and ret.owner.outputs[0].type.dtype != out_dtype:
1673-
ret = pytensor.tensor.basic.cast(ret, out_dtype)
1671+
if self.allow_cast:
1672+
out_dtype = getattr(node.outputs[0].type, "dtype", None)
1673+
ret_dtype = getattr(ret.owner.outputs[0].type, "dtype", None)
1674+
if ret_dtype != out_dtype:
1675+
ret = pytensor.tensor.basic.cast(ret, out_dtype)
16741676
if not node.outputs[0].type.is_super(ret.owner.outputs[0].type):
16751677
return False
16761678
else:

0 commit comments

Comments
 (0)