Skip to content

Commit ed32581

Browse files
committed
Added print_shape option to debugprint and simplify __str__ logic in TensorType
1 parent 5d4e9e0 commit ed32581

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

pytensor/printing.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def debugprint(
8989
| Sequence[Variable | Apply | Function | FunctionGraph],
9090
depth: int = -1,
9191
print_type: bool = False,
92+
print_shape: bool = False,
9293
file: Literal["str"] | TextIO | None = None,
9394
id_type: IDTypesType = "CHAR",
9495
stop_on_name: bool = False,
@@ -123,6 +124,8 @@ def debugprint(
123124
Print graph to this depth (``-1`` for unlimited).
124125
print_type
125126
If ``True``, print the `Type`\s of each `Variable` in the graph.
127+
print_shape
128+
If ``True``, print the shape of each `Variable` in the graph.
126129
file
127130
When `file` extends `TextIO`, print to it; when `file` is
128131
equal to ``"str"``, return a string; when `file` is ``None``, print to
@@ -265,6 +268,7 @@ def debugprint(
265268
depth=depth,
266269
done=done,
267270
print_type=print_type,
271+
print_shape=print_shape,
268272
file=_file,
269273
id_type=id_type,
270274
inner_graph_ops=inner_graph_vars,
@@ -295,6 +299,7 @@ def debugprint(
295299
depth=depth,
296300
done=done,
297301
print_type=print_type,
302+
print_shape=print_shape,
298303
file=_file,
299304
topo_order=topo_order,
300305
id_type=id_type,
@@ -365,6 +370,7 @@ def debugprint(
365370
depth=depth,
366371
done=done,
367372
print_type=print_type,
373+
print_shape=print_shape,
368374
file=_file,
369375
id_type=id_type,
370376
inner_graph_ops=inner_graph_vars,
@@ -387,6 +393,7 @@ def debugprint(
387393
depth=depth,
388394
done=done,
389395
print_type=print_type,
396+
print_shape=print_shape,
390397
file=_file,
391398
id_type=id_type,
392399
stop_on_name=stop_on_name,
@@ -421,6 +428,7 @@ def debugprint(
421428
depth=depth,
422429
done=done,
423430
print_type=print_type,
431+
print_shape=print_shape,
424432
file=_file,
425433
id_type=id_type,
426434
stop_on_name=stop_on_name,
@@ -452,6 +460,7 @@ def _debugprint(
452460
depth: int = -1,
453461
done: dict[Literal["output"] | Variable | Apply, str] | None = None,
454462
print_type: bool = False,
463+
print_shape: bool = False,
455464
file: TextIO = sys.stdout,
456465
print_destroy_map: bool = False,
457466
print_view_map: bool = False,
@@ -484,6 +493,8 @@ def _debugprint(
484493
See `debugprint`.
485494
print_type
486495
See `debugprint`.
496+
print_shape
497+
See `debugprint`.
487498
file
488499
File-like object to which to print.
489500
print_destroy_map
@@ -532,6 +543,11 @@ def _debugprint(
532543
else:
533544
type_str = ""
534545

546+
if print_shape:
547+
shape_str = f" shape={str(var.type.shape).replace("None", "?")}"
548+
else:
549+
shape_str = ""
550+
535551
if prefix_child is None:
536552
prefix_child = prefix
537553

@@ -612,7 +628,7 @@ def get_id_str(
612628
if is_inner_graph_header:
613629
var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}"
614630
else:
615-
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
631+
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{shape_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
616632

617633
if print_op_info and node not in op_information:
618634
op_information.update(op_debug_information(node.op, node))
@@ -662,6 +678,7 @@ def get_id_str(
662678
depth=depth - 1,
663679
done=_done,
664680
print_type=print_type,
681+
print_shape=print_shape,
665682
file=file,
666683
topo_order=topo_order,
667684
id_type=id_type,
@@ -692,7 +709,7 @@ def get_id_str(
692709
else:
693710
data = ""
694711

695-
var_output = f"{prefix}{var}{id_str}{type_str}{data}"
712+
var_output = f"{prefix}{var}{id_str}{type_str}{shape_str}{data}"
696713

697714
if print_op_info and var.owner and var.owner not in op_information:
698715
op_information.update(op_debug_information(var.owner.op, var.owner))

pytensor/tensor/type.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -399,22 +399,13 @@ def __str__(self):
399399
else:
400400
shape = self.shape
401401
len_shape = len(shape)
402-
403-
def shape_str(s):
404-
if s is None:
405-
return "?"
406-
else:
407-
return str(s)
408-
409-
formatted_shape = ", ".join(shape_str(s) for s in shape)
410-
if len_shape == 1:
411-
formatted_shape += ","
402+
formatted_shape = str(shape).replace("None", "?")
412403

413404
if len_shape > 2:
414405
name = f"Tensor{len_shape}"
415406
else:
416407
name = ("Scalar", "Vector", "Matrix")[len_shape]
417-
return f"{name}({self.dtype}, shape=({formatted_shape}))"
408+
return f"{name}({self.dtype}, shape={formatted_shape})"
418409

419410
def __repr__(self):
420411
return f"TensorType({self.dtype}, shape={self.shape})"

0 commit comments

Comments
 (0)