Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 4134b22

Browse files
Add missing argument to ShapeFeature.get_shape
1 parent 1706bb3 commit 4134b22

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

aeppl/printing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def process_shape_info(cls, output: Variable, pstate: Optional[PrinterStateType]
380380
try:
381381
old_precedence = getattr(pstate, "precedence", None)
382382
pstate.precedence = new_precedence
383-
_s_i_out = shape_feature.get_shape(output, i)
383+
_s_i_out = shape_feature.get_shape(pstate.fgraph, output, i)
384384

385385
if not isinstance(_s_i_out, (Constant, TensorVariable)):
386386
s_i_out = pstate.pprinter.process(_s_i_out, pstate)

aeppl/scan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,8 @@ def update_scan_value_vars(
439439
# graph, so we use the shape feature to (hopefully) get the shape
440440
# without the entire `Scan` itself.
441441
full_out_shape = tuple(
442-
fgraph.shape_feature.get_shape(full_out, i) for i in range(full_out.ndim)
442+
fgraph.shape_feature.get_shape(fgraph, full_out, i)
443+
for i in range(full_out.ndim)
443444
)
444445
new_val_var = at.empty(full_out_shape, dtype=full_out.dtype)
445446

0 commit comments

Comments
 (0)