
Commit 8c442e4

zeshengzong authored and albanD committed
Fix LBFGS warning when converting a tensor with requires_grad=True to a scalar (pytorch#160389)
Fixes pytorch#160197

## Test Result

```python
In [1]: import warnings
   ...: warnings.simplefilter('error')
   ...: import torch
   ...: print(torch.__version__)
   ...: a, b = torch.rand((2, 32, 32))
   ...: a.requires_grad_()
   ...: optimizer = torch.optim.LBFGS([a])
   ...: loss_fn = lambda x, y: (x-y).pow(2).mean()
   ...:
   ...: def closure():
   ...:     optimizer.zero_grad()
   ...:     loss = loss_fn(a, b)
   ...:     loss.backward()
   ...:     return loss
   ...:
   ...: for i in range(100):
   ...:     optimizer.step(closure)
   ...:     print(i, loss_fn(a, b))
   ...:
2.9.0a0+gitf33f3f8
0 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
1 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
2 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
3 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
4 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
5 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
6 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
7 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
8 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
9 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
10 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
...
```

```bash
pytest test/test_optim.py -vv
...
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_NAdam_cuda_float32 PASSED [2.7192s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_RAdam_cuda_float32 PASSED [2.5370s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_RMSprop_cuda_float32 PASSED [2.0190s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_Rprop_cuda_float32 PASSED [1.8554s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_SGD_cuda_float32 PASSED [2.0433s] [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_SparseAdam_cuda_float32 PASSED [1.1788s] [100%]
================== 1471 passed, 242 skipped in 2440.52s (0:40:40) ============
```

Pull Request resolved: pytorch#160389
Approved by: https://github.com/janeyx99

Co-authored-by: albanD <[email protected]>
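For context, a minimal sketch of the failure mode this commit addresses. The variable names are illustrative, and the sketch assumes the scalar-conversion warning is raised only while grad mode is enabled, which is why deferring `float()` until after the `torch.enable_grad()` block helps (`LBFGS.step()` itself runs under no-grad):

```python
import torch

x = torch.ones(3, requires_grad=True)

with torch.enable_grad():
    loss = x.pow(2).mean()  # a tensor with requires_grad=True
    # float(loss) here would convert a grad-requiring tensor to a
    # scalar while grad mode is on -- the pattern that warned.

# After leaving the enable_grad block (e.g. under the no_grad context
# that wraps optimizer.step()), the same conversion is expected to be silent.
with torch.no_grad():
    loss_val = float(loss)

print(loss_val)  # 1.0
```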
1 parent e34b6a0 commit 8c442e4

File tree

1 file changed: 2 additions, 1 deletion


torch/optim/lbfgs.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -454,7 +454,8 @@ def obj_func(x, t, d):
                     # the reason we do this: in a stochastic setting,
                     # no use to re-evaluate that function here
                     with torch.enable_grad():
-                        loss = float(closure())
+                        loss = closure()
+                    loss = float(loss)
                     flat_grad = self._gather_flat_grad()
                     opt_cond = flat_grad.abs().max() <= tolerance_grad
                     ls_func_evals = 1
```
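Why two statements instead of `float(closure())`: as I read the diff, the closure still has to run with gradients enabled because it calls `backward()`, but the conversion to a Python float has no need for grad mode, so it moves outside the `with` block, where `step()`'s surrounding no-grad mode is back in effect. A sketch of that structure in isolation, with a hypothetical stand-in closure:

```python
import torch

def closure():
    # stand-in for the user-provided closure; body is hypothetical
    p.grad = None
    loss = p.pow(2).sum()
    loss.backward()
    return loss

p = torch.ones(2, requires_grad=True)

# LBFGS.step() runs under torch.no_grad(); only the closure is
# re-enabled for autograd. Mirroring that nesting:
with torch.no_grad():
    with torch.enable_grad():
        loss = closure()  # needs grad mode: it calls backward()
    loss = float(loss)    # grad mode is off again here, so no warning expected

print(loss)  # 2.0
```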
