Skip to content

Commit 15f524b

Browse files
edpizziBordarohitgr7justusschockawaelchli
authored andcommitted
Avoid non-blocking GPU->CPU copies. (#11288)
Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 6a00066 commit 15f524b

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

CHANGELOG.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Fixed
1111

12-
-
13-
1412
- Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
15-
1613
- Fixed data fetcher selection ([#11294](https://github.com/PyTorchLightning/pytorch-lightning/pull/11294))
14+
- Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
15+
1716
## [1.5.7] - 2021-12-21
1817

1918
### Fixed

pytorch_lightning/utilities/apply_func.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
Batch = type(None)
3535

3636

37+
_CPU_DEVICES = ("cpu", torch.device("cpu"))
38+
39+
3740
def to_dtype_tensor(
3841
value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device]
3942
) -> torch.Tensor:
@@ -268,7 +271,10 @@ def batch_to(data: Any) -> Any:
268271
setattr(device_data, field, device_field)
269272
return device_data
270273

271-
kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
274+
kwargs = {}
275+
# Don't issue non-blocking transfers to CPU
276+
if isinstance(data, torch.Tensor) and device not in _CPU_DEVICES:
277+
kwargs["non_blocking"] = True
272278
data_output = data.to(device, **kwargs)
273279
if data_output is not None:
274280
return data_output

0 commit comments

Comments
 (0)