Skip to content

Commit 5710f95

Browse files
committed
Reintroduce inline printing for short single output Composites
1 parent 1f2542e commit 5710f95

File tree

4 files changed

+35
-15
lines changed

4 files changed

+35
-15
lines changed

pytensor/scalar/basic.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4143,6 +4143,7 @@ class Composite(ScalarInnerGraphOp):
41434143

41444144
def __init__(self, inputs, outputs, name="Composite"):
41454145
self.name = name
4146+
self._name = None
41464147
# We need to clone the graph as sometimes its nodes already
41474148
# contain a reference to an fgraph. As we want the Composite
41484149
# to be pickable, we can't have reference to fgraph.
@@ -4189,7 +4190,26 @@ def __init__(self, inputs, outputs, name="Composite"):
41894190
super().__init__()
41904191

41914192
def __str__(self):
4192-
return self.name
4193+
if self._name is not None:
4194+
return self._name
4195+
4196+
# Rename internal variables
4197+
for i, r in enumerate(self.fgraph.inputs):
4198+
r.name = f"i{int(i)}"
4199+
for i, r in enumerate(self.fgraph.outputs):
4200+
r.name = f"o{int(i)}"
4201+
io = set(self.fgraph.inputs + self.fgraph.outputs)
4202+
for i, r in enumerate(self.fgraph.variables):
4203+
if r not in io and len(self.fgraph.clients[r]) > 1:
4204+
r.name = f"t{int(i)}"
4205+
4206+
if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
4207+
self._name = "Composite{...}"
4208+
else:
4209+
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
4210+
self._name = f"Composite{{{outputs_str}}}"
4211+
4212+
return self._name
41934213

41944214
def make_new_inplace(self, output_types_preference=None, name=None):
41954215
"""

tests/scalar/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_composite_printing(self):
183183
make_function(DualLinker().accept(g))
184184

185185
assert str(g) == (
186-
"FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
186+
"FunctionGraph(*1 -> Composite{...}(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
187187
)
188188

189189
def test_non_scalar_error(self):

tests/scan/test_printing.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def no_shared_fn(n, x_tm1, M):
654654
Inner graphs:
655655
656656
forall_inplace,cpu,scan_fn} [id A]
657-
← Elemwise{Composite} [id I] (inner_out_sit_sot-0)
657+
← Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
658658
├─ TensorConstant{0} [id J]
659659
├─ Subtensor{int64, int64, uint8} [id K]
660660
│ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
@@ -665,13 +665,13 @@ def no_shared_fn(n, x_tm1, M):
665665
│ └─ ScalarConstant{0} [id Q]
666666
└─ TensorConstant{1} [id R]
667667
668-
Elemwise{Composite} [id I]
669-
← Switch [id S]
668+
Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I]
669+
← Switch [id S] 'o0'
670670
├─ LT [id T]
671-
│ ├─ <int64> [id U]
672-
│ └─ <float64> [id V]
673-
├─ <int64> [id W]
674-
└─ <int64> [id U]
671+
│ ├─ i0 [id U]
672+
│ └─ i1 [id V]
673+
├─ i2 [id W]
674+
└─ i0 [id U]
675675
"""
676676

677677
output_str = debugprint(out, file="str", print_op_info=True)

tests/test_printing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def test_debugprint():
274274
s = s.getvalue()
275275
exp_res = dedent(
276276
r"""
277-
Elemwise{Composite} 4
277+
Elemwise{Composite{(i2 + (i0 - i1))}} 4
278278
├─ InplaceDimShuffle{x,0} v={0: [0]} 3
279279
│ └─ CGemv{inplace} d={0: [0]} 2
280280
│ ├─ AllocEmpty{dtype='float64'} 1
@@ -289,12 +289,12 @@ def test_debugprint():
289289
290290
Inner graphs:
291291
292-
Elemwise{Composite}
293-
← add
294-
├─ <float64>
292+
Elemwise{Composite{(i2 + (i0 - i1))}}
293+
← add 'o0'
294+
├─ i2
295295
└─ sub
296-
├─ <float64>
297-
└─ <float64>
296+
├─ i0
297+
└─ i1
298298
"""
299299
).lstrip()
300300

0 commit comments

Comments
 (0)