diff --git a/README.md b/README.md
index 14b317c..b362898 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,8 @@ for i in range(1, EPOCHS+1):
         running_loss += loss.item() * features_.size(0)
     train_loss = running_loss / len(dl_train.dataset)
     print("EPOCH: {}, TRAIN LOSS: {}".format(i, train_loss))
+# set grad to None to avoid memory leak
+mw.end()
 ```
 Note that on the first step of the train loop PyTorch will return the following warning:
 ```
diff --git a/src/gradient_descent_the_ultimate_optimizer/gdtuo.py b/src/gradient_descent_the_ultimate_optimizer/gdtuo.py
index 8c8ea98..e99cb35 100644
--- a/src/gradient_descent_the_ultimate_optimizer/gdtuo.py
+++ b/src/gradient_descent_the_ultimate_optimizer/gdtuo.py
@@ -51,6 +51,14 @@ def step(self):
         ''' Update parameters '''
         pass
 
+    def end(self):
+        if hasattr(self, "all_params_with_gradients"):
+            for param in self.all_params_with_gradients:
+                param.grad = None
+            self.all_params_with_gradients.clear()
+        if hasattr(self, "optimizer"):
+            self.optimizer.end()
+
 class NoOpOptimizer(Optimizable):
     '''
     NoOpOptimizer sits on top of a stack, and does not affect what lies below.
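
For context, here is a minimal sketch of where the new `end()` call fits in a training run. It assumes the `gdtuo.ModuleWrapper`/`gdtuo.SGD` wrappers and the `begin()`/`zero_grad()`/`step()` loop from the project README; the model, data, and loss below are placeholders, not part of the patch.

```python
import torch
from gradient_descent_the_ultimate_optimizer import gdtuo

model = torch.nn.Linear(4, 2)                               # placeholder model
# hyperoptimizer stack: SGD whose step size is itself tuned by SGD
optim = gdtuo.SGD(alpha=0.01, optimizer=gdtuo.SGD(alpha=1e-5))
mw = gdtuo.ModuleWrapper(model, optimizer=optim)
mw.initialize()

for _ in range(10):
    mw.begin()                          # re-attach parameters for this step
    x = torch.randn(8, 4)               # placeholder batch
    loss = mw.forward(x).pow(2).mean()  # placeholder loss
    mw.zero_grad()
    loss.backward(create_graph=True)    # create_graph=True keeps graphs alive via .grad
    mw.step()

# New in this patch: walk the optimizer stack and set each .grad to None so the
# retained graphs can be garbage-collected; the hasattr checks let the recursion
# stop at NoOpOptimizer, which holds no parameters.
mw.end()
```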