Skip to content

Commit 0121ec3

Browse files
committed
1 parent c2f2fce commit 0121ec3

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pytensor/optimise/fixed_point.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def fixed_point_solver(
4545
):
4646
args = [pt.as_tensor(arg) for arg in args]
4747

48-
def _scan_step(x, *args, func, solver, tol):
48+
def _scan_step(x, n_steps, *args, func, solver, tol):
4949
x, is_converged = solver(x, *args, func=func, tol=tol)
50-
return x, until(is_converged)
50+
return (x, n_steps + 1), until(is_converged)
5151

5252
partial_step = partial(
5353
_scan_step,
@@ -58,13 +58,13 @@ def _scan_step(x, *args, func, solver, tol):
5858

5959
outputs, updates = pytensor.scan(
6060
partial_step,
61-
outputs_info=[x0],
61+
outputs_info=[x0, pt.constant(0, dtype="int64")],
6262
non_sequences=list(args),
6363
n_steps=max_iter,
6464
strict=True,
6565
)
6666

67-
x_trace = outputs
67+
x_trace, n_steps_trace = outputs
6868
assert not updates
6969

70-
return x_trace[-1], x_trace.shape[0]
70+
return x_trace[-1], n_steps_trace[-1]

0 commit comments

Comments
 (0)