Skip to content

Commit 85e0fc8

Browse files
committed
refactor!: s/n_iter/max_iter
1 parent 667a89f commit 85e0fc8

File tree

5 files changed

+12
-12
lines changed

5 files changed

+12
-12
lines changed

examples/GPU/example_pinv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def process_cb_results(cb_results):
9898
for optim in OPTIM:
9999
image, iter_cb = nufft.pinv_solver(
100100
kspace_data=kspace_data_gpu,
101-
n_iter=1000,
101+
max_iter=1000,
102102
callback=mixed_cb,
103103
optim=optim,
104104
)
@@ -180,7 +180,7 @@ def process_cb_results(cb_results):
180180
for optim in OPTIM:
181181
image, iter_cb = nufft.pinv_solver(
182182
kspace_data=kspace_data_gpu,
183-
n_iter=1000,
183+
max_iter=1000,
184184
callback=mixed_cb,
185185
damp=0.1,
186186
optim=optim,

src/mrinufft/extras/optim.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
If provided, a callback function will be called at the end of each
2828
iteration with the current estimate. It should have the following signature
2929
``callback(operator, kspace_data, damp, x0)``
30-
n_iter: int, optional
30+
max_iter: int, optional
3131
Maximum number of iterations. Default is 100.
3232
progressbar: bool, optional
3333
If True (default) display a progress bar to track iterations.
@@ -230,7 +230,7 @@ def lsqr(
230230
atol: float = 1e-6,
231231
btol: float = 1e-6,
232232
conlim: float = 1e8,
233-
n_iter: int = 100,
233+
max_iter: int = 100,
234234
x0: NDArray | None = None,
235235
x_init: NDArray | None = None,
236236
callback: Callable | None = None,
@@ -332,7 +332,7 @@ def lsqr(
332332
sn2 = 0.0
333333
istop = 0
334334
callback_returns = []
335-
for _ in tqdm(range(n_iter), disable=not progressbar):
335+
for _ in tqdm(range(max_iter), disable=not progressbar):
336336
u *= -_bc_left(alpha, u)
337337
u += operator.op(v).reshape(operator.ksp_full_shape)
338338
beta = norm_batched(u)
@@ -463,7 +463,7 @@ def lsmr(
463463
atol: float = 1e-6,
464464
btol: float = 1e-6,
465465
conlim: float = 1e8,
466-
n_iter: int = 100,
466+
max_iter: int = 100,
467467
x0: NDArray | None = None,
468468
x_init: NDArray | None = None,
469469
callback: Callable | None = None,
@@ -612,7 +612,7 @@ def A(x):
612612
normr = beta
613613

614614
callback_returns = []
615-
for _ in tqdm(range(n_iter), disable=not progressbar):
615+
for _ in tqdm(range(max_iter), disable=not progressbar):
616616

617617
u *= -_bc_left(alpha, u)
618618
u += A(v)
@@ -751,7 +751,7 @@ def cg(
751751
damp: float = 0.0,
752752
x0: NDArray | None = None,
753753
x_init: NDArray | None = None,
754-
n_iter: int = 10,
754+
max_iter: int = 10,
755755
tol: float = 1e-4,
756756
progressbar: bool = True,
757757
callback: Callable | None = None,
@@ -796,7 +796,7 @@ def cg(
796796
image = image - velocity
797797

798798
callbacks_results = []
799-
for _ in tqdm(range(n_iter), disable=not progressbar):
799+
for _ in tqdm(range(max_iter), disable=not progressbar):
800800
grad_new = operator.data_consistency(image, kspace_data).reshape(
801801
operator.img_full_shape
802802
)

tests/operators/test_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_pinv_solver(operator, array_interface, image_data, kspace_data, optim):
235235
from mrinufft.extras.optim import loss_l2_reg
236236

237237
img, res = operator.pinv_solver(
238-
kspace_data, optim=optim, callback=loss_l2_reg, n_iter=5
238+
kspace_data, optim=optim, callback=loss_l2_reg, max_iter=5
239239
)
240240

241241
assert img.shape == operator.img_full_shape

tests/operators/test_density_for_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_pipe(backend, traj, shape, osf):
3333
pytest.skip("OSF < 2 not supported for tensorflow.")
3434
if osf == 1 and "finufft" in backend:
3535
pytest.skip("cufinufft and finufft dont support OSF=1")
36-
result = pipe(traj, shape, backend=backend, osf=osf, num_iterations=10)
36+
result = pipe(traj, shape, backend=backend, osf=osf, max_iter=10)
3737
if backend == "cufinufft":
3838
result = result.get()
3939
result = result / np.mean(result)

tests/operators/test_optim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,6 @@ def test_pinv(operator, image_data, optim):
6262
kspace_nufft = operator.op(image_data).squeeze()
6363

6464
_, residuals = operator.pinv_solver(
65-
kspace_nufft, optim=optim, n_iter=10, callback=loss_l2_reg
65+
kspace_nufft, optim=optim, max_iter=10, callback=loss_l2_reg
6666
)
6767
assert residuals[-1] <= residuals[0]

0 commit comments

Comments
 (0)