Skip to content

Commit 7d61e5c

Browse files
author
neutrino
committed
Fix memory leak in zeroGradients().
1 parent 963332d commit 7d61e5c

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxGradientCollector.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ public void zeroGradients() {
123123
NDManager systemManager = MxNDManager.getSystemManager();
124124
for (NDArray array : systemManager.getManagedArrays()) {
125125
if (array.hasGradient()) {
126-
array.getGradient().subi(array.getGradient());
126+
// To prevent memory leak we must close gradient after use.
127+
try (NDArray gradient = array.getGradient()) {
128+
gradient.subi(gradient);
129+
}
127130
}
128131
}
129132
}

engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtGradientCollector.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean c
7575
public void zeroGradients() {
7676
NDManager systemManager = PtNDManager.getSystemManager();
7777
for (NDArray array : systemManager.getManagedArrays()) {
78-
if (array.hasGradient()) {
79-
array.getGradient().subi(array.getGradient());
78+
// To prevent memory leak we must close gradient after use.
79+
try (NDArray gradient = array.getGradient()) {
80+
gradient.subi(gradient);
8081
}
8182
}
8283
}

0 commit comments

Comments
 (0)