Skip to content

Commit 1e5c007

Browse files
committed
Fix gradient norm check to match new parameter behavivor
1 parent a528305 commit 1e5c007

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

stochman/geodesic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ def closure():
8989

9090
for k in range(max_iter):
9191
opt.step(closure=closure)
92-
if torch.max(torch.abs(curve.parameters.grad)) < thresh:
92+
max_grad = max([p.grad.abs().max() for p in curve.parameters()])
93+
if max_grad < thresh:
9394
break
9495
# if k % (max_iter // 10) == 0:
9596
# curve.constant_speed(manifold)
96-
max_grad = torch.max(torch.abs(curve.parameters.grad))
97-
curve.constant_speed(manifold)
97+
#curve.constant_speed(manifold)
9898
return max_grad < thresh
9999

100100

0 commit comments

Comments
 (0)