
Commit 7312d2f

tchaton authored and lexierule committed
[bugfix] Prevent a DDP failure using copy (#9239)
1 parent 7cefa86 commit 7312d2f

3 files changed: +11 lines, -6 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
 - Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))
 - Fixed the CometLogger, no longer modifies the metrics in place. Instead creates a copy of metrics before performing any operations ([#9150](https://github.com/PyTorchLightning/pytorch-lightning/pull/9150))
+- Fixed `DDP` "CUDA error: initialization error" due to a `copy` instead of `deepcopy` on `ResultCollection` ([#9239](https://github.com/PyTorchLightning/pytorch-lightning/pull/9239))
 
 
 ## [1.4.4] - 2021-08-24
```

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from collections import OrderedDict
 from contextlib import contextmanager
-from copy import copy
+from copy import deepcopy
 from functools import partial, update_wrapper
 from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple
 
@@ -147,12 +146,12 @@ def advance(self, batch, batch_idx, dataloader_idx):
 
                 result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
                 if result:
-                    self.batch_outputs[opt_idx].append(copy(result.training_step_output))
+                    self.batch_outputs[opt_idx].append(deepcopy(result.training_step_output))
         else:
             # in manual optimization, there is no looping over optimizers
             result = self._run_optimization(batch_idx, split_batch)
            if result:
-                self.batch_outputs[0].append(copy(result.training_step_output))
+                self.batch_outputs[0].append(deepcopy(result.training_step_output))
 
     def teardown(self) -> None:
         # release memory
```
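
Why the switch from `copy` to `deepcopy` matters: `copy.copy` on the training-step output creates a new top-level container but keeps references to the very same tensor objects, so the cached `batch_outputs` stay coupled to whatever is later done to those tensors. The snippet below is a minimal sketch, not taken from the repository, using a plain dict as a stand-in for `result.training_step_output`:

```python
from copy import copy, deepcopy

import torch

# Stand-in for a training_step output: a container holding tensors.
training_step_output = {"loss": torch.tensor(0.5), "extra": torch.ones(3)}

shallow = copy(training_step_output)   # new dict, same tensor objects
deep = deepcopy(training_step_output)  # new dict, new tensor objects

assert shallow["loss"] is training_step_output["loss"]   # references are shared
assert deep["loss"] is not training_step_output["loss"]  # fully independent

# An in-place change to the original leaks into the shallow copy only.
training_step_output["extra"].add_(1)
print(shallow["extra"])  # tensor([2., 2., 2.]) - follows the original
print(deep["extra"])     # tensor([1., 1., 1.]) - unaffected
```

With `deepcopy`, the stored batch outputs no longer share storage with the tensors the live `ResultCollection` keeps around, which is what the changelog entry above credits with avoiding the DDP "CUDA error: initialization error".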

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -17,6 +17,7 @@
 from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import torch
+from torch.functional import Tensor
 from torchmetrics import Metric
 
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
@@ -434,8 +435,12 @@ def log(
     ) -> None:
         """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
         # no metrics should be logged with graphs
-        if not enable_graph and isinstance(value, torch.Tensor):
-            value = value.detach()
+        if not enable_graph:
+
+            def detach_fn(tensor: Tensor) -> Tensor:
+                return tensor.detach()
+
+            value = apply_to_collection(value, Tensor, detach_fn)
 
         # move metrics to cpu on TPU.
         if isinstance(value, torch.Tensor) and value.device.type == "xla":
```
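
The second part of the fix widens the detach logic in `ResultCollection.log`: instead of detaching only a bare `torch.Tensor`, `apply_to_collection` walks nested dicts, lists, and tuples and detaches every tensor it finds. A rough usage sketch follows; the import path shown is the one used in Lightning 1.4.x and is an assumption here, as it may differ in other versions:

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection  # path assumed (1.4.x)


def detach_fn(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.detach()


# A logged value may be a bare tensor or a nested collection of tensors.
value = {
    "loss": torch.ones(1, requires_grad=True) * 2,
    "aux": [torch.zeros(2, requires_grad=True)],
}

detached = apply_to_collection(value, torch.Tensor, detach_fn)

assert not detached["loss"].requires_grad    # graph dropped on the top-level tensor
assert not detached["aux"][0].requires_grad  # and on tensors nested inside the list
```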
