
Commit e068af9

Ampt (#2572)
* remove grad scaling tpu
1 parent c197b74 commit e068af9

File tree (6 files changed, +15 -10 lines)

  pytorch_lightning/core/memory.py
  pytorch_lightning/trainer/distrib_parts.py
  pytorch_lightning/trainer/evaluation_loop.py
  pytorch_lightning/trainer/trainer.py
  pytorch_lightning/trainer/training_io.py
  pytorch_lightning/trainer/training_loop.py

pytorch_lightning/core/memory.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ def _forward_example_input(self) -> None:
         input_ = model.example_input_array
         input_ = model.transfer_batch_to_device(input_, model.device)

-        if trainer is not None and trainer.use_amp:
+        if trainer is not None and trainer.use_amp and not trainer.use_tpu:
             if NATIVE_AMP_AVALAIBLE:
                 model.forward = torch.cuda.amp.autocast()(model.forward)
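
Every hunk in this commit applies the same guard: torch.cuda.amp is a CUDA-only API, so autocast wrapping and grad scaling have to be skipped when the trainer targets TPU. A minimal sketch of the pattern, with a stand-in for Lightning's NATIVE_AMP_AVALAIBLE flag and a bare trainer/model pair (the helper name is illustrative, not part of the diff):

import torch

# Rough stand-in for Lightning's NATIVE_AMP_AVALAIBLE flag (kept with the
# project's original spelling); native AMP ships with torch.cuda.amp in PyTorch 1.6+.
NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")


def wrap_forward_for_amp(model, trainer):
    """Illustrative helper: wrap ``model.forward`` in autocast only when it is safe."""
    # On TPU there is no CUDA autocast, so the forward pass is left untouched;
    # this mirrors the `and not trainer.use_tpu` guard added above.
    if trainer is not None and trainer.use_amp and not trainer.use_tpu:
        if NATIVE_AMP_AVALAIBLE:
            model.forward = torch.cuda.amp.autocast()(model.forward)
    return model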

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 2 additions & 2 deletions
@@ -240,14 +240,14 @@ def dp_train(self, model):

         # hack forward to do autocast for the user
         model_autocast_original_forward = model.forward
-        if self.use_amp and NATIVE_AMP_AVALAIBLE:
+        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             # wrap the user's forward in autocast and give it back at the end
             model.forward = torch.cuda.amp.autocast()(model.forward)

         # TODO: remove with dropping NVIDIA AMP support
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
-        if self.use_dp and self.use_amp and not NATIVE_AMP_AVALAIBLE:
+        if self.use_dp and self.use_amp and not NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             if self.amp_level == 'O2':
                 raise MisconfigurationException(
                     f'Amp level {self.amp_level} with DataParallel is not supported.'
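
torch.cuda.amp.autocast() works both as a context manager and as a decorator, which is why the DP path above can wrap the user's forward and hand the original back later via model_autocast_original_forward. A small standalone sketch of that wrap-and-restore dance (TinyModel and the use_amp/use_tpu flags are assumptions for illustration, not Trainer attributes):

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()
use_amp, use_tpu = torch.cuda.is_available(), False

# Keep a handle to the original forward so it can be given back at the end,
# just like `model_autocast_original_forward` in the diff above.
original_forward = model.forward

if use_amp and not use_tpu:
    # autocast() returns an object that can also decorate a callable.
    model.forward = torch.cuda.amp.autocast()(model.forward)

out = model.forward(torch.randn(3, 4))  # runs under autocast when CUDA AMP is active

# Undo the wrapping once (data-parallel) training is done.
model.forward = original_forward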

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 1 addition & 1 deletion
@@ -286,7 +286,7 @@ def _evaluate(
                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
-                if self.use_amp and NATIVE_AMP_AVALAIBLE:
+                if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                     with torch.cuda.amp.autocast():
                         output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
                 else:

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 1 deletion
@@ -1118,7 +1118,7 @@ def run_pretrain_routine(self, model: LightningModule):
         self.copy_trainer_model_properties(ref_model)

         # init amp. Must be done here instead of __init__ to allow ddp to work
-        if NATIVE_AMP_AVALAIBLE and self.precision == 16:
+        if NATIVE_AMP_AVALAIBLE and self.precision == 16 and not self.use_tpu:
             self.scaler = torch.cuda.amp.GradScaler()

         # log hyper-parameters
@@ -1300,6 +1300,11 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         if ckpt_path == 'best':
             ckpt_path = self.checkpoint_callback.best_model_path

+        if len(ckpt_path) == 0:
+            rank_zero_warn(f'.test() found no path for the best weights, {ckpt_path}. Please '
+                           f'specify a path for a checkpoint .test(ckpt_path=PATH)')
+            return {}
+
         ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
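
Beyond skipping GradScaler creation on TPU, the second hunk makes .test(ckpt_path='best') fail soft: when no best checkpoint was ever written (for example, checkpointing was disabled), best_model_path is an empty string and torch.load('') would raise, so the trainer now warns and returns an empty dict instead. A hedged sketch of that guard outside the Trainer class, with warnings.warn standing in for Lightning's rank_zero_warn (the function name is illustrative):

import warnings

import torch


def load_best_checkpoint(best_model_path: str):
    """Illustrative version of the guard added to __test_using_best_weights."""
    if len(best_model_path) == 0:
        # Nothing was ever checkpointed, so there is nothing to test against.
        warnings.warn(
            ".test() found no path for the best weights. "
            "Please specify a checkpoint via .test(ckpt_path=PATH)"
        )
        return {}
    # map_location keeps the load device-agnostic, matching the diff above.
    return torch.load(best_model_path, map_location=lambda storage, loc: storage)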

pytorch_lightning/trainer/training_io.py

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         checkpoint['lr_schedulers'] = lr_schedulers

         # save native amp scaling
-        if self.use_amp and NATIVE_AMP_AVALAIBLE:
+        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()

         # add the module_arguments and state_dict from the model
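
The checkpoint dict only receives a 'native_amp_scaling_state' entry when a CUDA GradScaler actually exists, so TPU checkpoints neither carry nor expect scaler state. A rough sketch of the symmetric save/restore, assuming a scaler created as in the trainer.py hunk above (both helper names are illustrative):

def dump_scaler_state(checkpoint: dict, scaler, use_amp: bool, use_tpu: bool) -> dict:
    # Only CUDA native AMP has a GradScaler; TPU runs skip this entirely.
    if use_amp and not use_tpu and scaler is not None:
        checkpoint['native_amp_scaling_state'] = scaler.state_dict()
    return checkpoint


def restore_scaler_state(checkpoint: dict, scaler) -> None:
    # Inverse of the dump above; tolerate checkpoints written without AMP.
    if scaler is not None and 'native_amp_scaling_state' in checkpoint:
        scaler.load_state_dict(checkpoint['native_amp_scaling_state'])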

pytorch_lightning/trainer/training_loop.py

Lines changed: 4 additions & 4 deletions
@@ -702,7 +702,7 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # ------------------
         # CLIP GRADS
         # ------------------
-        if self.use_amp and NATIVE_AMP_AVALAIBLE:
+        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             self.scaler.unscale_(optimizer)
         self.clip_gradients()

@@ -750,7 +750,7 @@ def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
                                  using_native_amp=native_amp)

         # in native 16-bit we need to update scaler after optimizer step
-        if self.use_amp and NATIVE_AMP_AVALAIBLE:
+        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             self.scaler.update()

         # model hook
@@ -767,7 +767,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
             # FORWARD
             # ---------------------------
             with self.profiler.profile('model_forward'):
-                if self.use_amp and NATIVE_AMP_AVALAIBLE:
+                if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                     with torch.cuda.amp.autocast():
                         training_step_output = self.training_forward(split_batch, batch_idx,
                                                                      opt_idx, hiddens)
@@ -817,7 +817,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
             model_ref.backward(self, closure_loss, optimizer, opt_idx)

         # exit amp context
-        if self.precision == 16 and not NATIVE_AMP_AVALAIBLE:
+        if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu:
             a, b, c = None, None, None
             error = context.__exit__(a, b, c)
             if error:
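
Taken together, the training_loop hunks guard every stage of the native AMP step: autocast the forward, unscale before gradient clipping, and update the scaler after the optimizer step, none of which applies on TPU. A condensed, plain-PyTorch sketch of that sequence under the same guard (the function and its arguments are illustrative, not the Trainer's actual hooks):

import torch


def run_amp_training_step(model, batch, optimizer, scaler, use_tpu=False):
    """Condensed illustration of the guarded native AMP flow in training_loop.py."""
    # Same condition as the diffs: native AMP available, and not running on TPU.
    native_amp = hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast')
    use_native_amp = native_amp and not use_tpu

    optimizer.zero_grad()

    if use_native_amp:
        # FORWARD under autocast, backward on the scaled loss.
        with torch.cuda.amp.autocast():
            loss = model(batch).mean()
        scaler.scale(loss).backward()
        # CLIP GRADS: unscale first so the threshold applies to the real gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        # The scale factor is updated only after the optimizer step.
        scaler.update()
    else:
        # TPU / CPU / Apex paths: plain forward, backward, clip, step.
        loss = model(batch).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    return loss.detach()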
