How to delete all the gradients in between operations #13728
Hi! I am currently working on a project where, for a given trained model, I run inference on batches of inputs and compute (and store) the gradients of the output with respect to the inputs:

```python
def test_step(self, batch, batch_idx):
    with torch.set_grad_enabled(True):
        gradients_list = []
        for batch_of_inputs in batches:
            batch_of_inputs.requires_grad_()
            output = self(batch_of_inputs)
            # torch.autograd.grad returns a tuple with one gradient per input
            gradients = torch.autograd.grad(
                outputs=output,
                inputs=batch_of_inputs,
                grad_outputs=torch.ones_like(output),
                retain_graph=False,
            )
            gradients_list.append(gradients[0].detach())
```

The thing is that the used memory keeps increasing until an OOM error is raised. I have tried to use …
If the OOM error is on the GPU, you can move your gradients from GPU to CPU while storing them, releasing some GPU memory.
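A minimal sketch of that suggestion, rewritten as a standalone function rather than a `test_step` method (the `model` and `batches` names stand in for your LightningModule and your iterable of input tensors): each gradient is detached and copied to the CPU before being stored, so the GPU tensor and its graph can be freed between iterations.

```python
import torch

def compute_input_gradients(model, batches):
    """Collect d(output)/d(input) per batch, storing results on the CPU."""
    gradients_list = []
    with torch.set_grad_enabled(True):
        for batch_of_inputs in batches:
            batch_of_inputs.requires_grad_()
            output = model(batch_of_inputs)
            # torch.autograd.grad returns a tuple, one entry per input tensor
            (grad,) = torch.autograd.grad(
                outputs=output,
                inputs=batch_of_inputs,
                grad_outputs=torch.ones_like(output),
                retain_graph=False,
            )
            # Detach and move to CPU so the GPU copy can be garbage-collected
            gradients_list.append(grad.detach().cpu())
    return gradients_list
```

Since `.cpu()` already returns a new tensor on host memory, only the small stored copies accumulate; the per-batch GPU allocations are released once each loop iteration ends.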