diff --git a/detectron2/engine/train_loop.py b/detectron2/engine/train_loop.py
index 738a69de94..3d583d0233 100644
--- a/detectron2/engine/train_loop.py
+++ b/detectron2/engine/train_loop.py
@@ -469,7 +469,12 @@ def __init__(
         )

         if grad_scaler is None:
-            from torch.cuda.amp import GradScaler
+            if torch.__version__ < "2.4.0":
+                from torch.cuda.amp import GradScaler
+            else:
+                from torch.amp import GradScaler
+                from functools import partial
+                GradScaler = partial(GradScaler, device="cuda")

             grad_scaler = GradScaler()
         self.grad_scaler = grad_scaler
@@ -482,7 +487,12 @@ def run_step(self):
         """
         assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
         assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
-        from torch.cuda.amp import autocast
+        if torch.__version__ < "2.4.0":
+            from torch.cuda.amp import autocast
+        else:
+            from torch.amp import autocast
+            from functools import partial
+            autocast = partial(autocast, device_type="cuda")

         start = time.perf_counter()
         data = next(self._data_loader_iter)
diff --git a/tests/layers/test_blocks.py b/tests/layers/test_blocks.py
index 5a0488adbf..91edce5c64 100644
--- a/tests/layers/test_blocks.py
+++ b/tests/layers/test_blocks.py
@@ -24,7 +24,12 @@ def test_aspp(self):

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_frozen_batchnorm_fp16(self):
-        from torch.cuda.amp import autocast
+        if torch.__version__ < "2.4.0":
+            from torch.cuda.amp import autocast
+        else:
+            from torch.amp import autocast
+            from functools import partial
+            autocast = partial(autocast, device_type="cuda")

         C = 10
         input = torch.rand(1, C, 10, 10).cuda()
diff --git a/tests/modeling/test_model_e2e.py b/tests/modeling/test_model_e2e.py
index 8c07e6856d..9ba76c5920 100644
--- a/tests/modeling/test_model_e2e.py
+++ b/tests/modeling/test_model_e2e.py
@@ -155,7 +155,12 @@ def test_roiheads_inf_nan_data(self):

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_autocast(self):
-        from torch.cuda.amp import autocast
+        if torch.__version__ < "2.4.0":
+            from torch.cuda.amp import autocast
+        else:
+            from torch.amp import autocast
+            from functools import partial
+            autocast = partial(autocast, device_type="cuda")

         inputs = [{"image": torch.rand(3, 100, 100)}]
         self.model.eval()
@@ -195,7 +200,12 @@ def test_inf_nan_data(self):

     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_autocast(self):
-        from torch.cuda.amp import autocast
+        if torch.__version__ < "2.4.0":
+            from torch.cuda.amp import autocast
+        else:
+            from torch.amp import autocast
+            from functools import partial
+            autocast = partial(autocast, device_type="cuda")

         inputs = [{"image": torch.rand(3, 100, 100)}]
         self.model.eval()
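
The same version-gated shim is inlined at every call site in the patch above. As a standalone illustration, the sketch below factors it into two helpers; the names make_autocast and make_grad_scaler are hypothetical and not part of the patch. The idea is: torch.cuda.amp is deprecated as of PyTorch 2.4 in favor of the unified torch.amp namespace, whose autocast/GradScaler take an explicit device_type/device argument, so the new namespace is bound to CUDA with functools.partial. The string comparison works because torch.__version__ is a TorchVersion (a str subclass with version-aware comparison).

    from functools import partial

    import torch


    def make_autocast():
        # Hypothetical helper: returns an autocast factory pre-bound to CUDA.
        if torch.__version__ < "2.4.0":
            from torch.cuda.amp import autocast  # legacy namespace, deprecated in 2.4+

            return autocast
        from torch.amp import autocast  # unified namespace, requires device_type

        return partial(autocast, device_type="cuda")


    def make_grad_scaler():
        # Hypothetical helper: returns a GradScaler constructor pre-bound to CUDA.
        if torch.__version__ < "2.4.0":
            from torch.cuda.amp import GradScaler

            return GradScaler
        from torch.amp import GradScaler

        return partial(GradScaler, device="cuda")


    if __name__ == "__main__":
        if torch.cuda.is_available():
            autocast = make_autocast()
            scaler = make_grad_scaler()()  # constructing it exercises the shim
            with autocast(dtype=torch.float16):
                y = torch.rand(4, 4, device="cuda") @ torch.rand(4, 4, device="cuda")
            print(y.dtype)  # torch.float16 inside the autocast region

Either branch yields a callable with the same call signature as the old torch.cuda.amp APIs, which is why the patch can swap the import without touching the surrounding with autocast(...) and GradScaler() call sites.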