Skip to content

Commit 777dd53

Browse files
committed
update collate_param function call in dict_learning
1 parent 32572d0 commit 777dd53

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch_staintools/functional/optimization/dict_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def dict_learning(x: torch.Tensor,
202202
# initialize
203203
z0 = validate_code(algorithm, init, z0=None, x=x, weight=weight, rng=rng)
204204
assert z0 is not None
205-
lr, alpha, tol = collate_params(z0, x, lr, weight, alpha, tol)
205+
lr, alpha, tol = collate_params(x, lr, weight, alpha, tol)
206206
return dict_learning_loop(x, z0, weight, alpha, algorithm, lambd_ridge=lambd_ridge,
207207
steps=steps, rng=rng, init=init, lr=lr, maxiter=maxiter, tol=tol)
208208

0 commit comments

Comments
 (0)