
Commit bcbba3b

Simplify GPU and TPU accelerator (#5024)
1 parent: 90d1d9f

2 files changed: +18 −54 lines


pytorch_lightning/accelerators/gpu_accelerator.py

Lines changed: 10 additions & 38 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch

@@ -66,53 +66,25 @@ def train(self):
         results = self.train_or_test()
         return results

-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
+        args[0] = self.to_device(args[0])
+
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.__training_step(args)
+                output = model_step(*args)
         else:
-            output = self.__training_step(args)
+            output = model_step(*args)

         return output

-    def __training_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.training_step(*args)
-        return output
+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)

     def validation_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.__validation_step(args)
-        else:
-            output = self.__validation_step(args)
-
-        return output
-
-    def __validation_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)

     def test_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.__test_step(args)
-        else:
-            output = self.__test_step(args)
-
-        return output
-
-    def __test_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)

     def to_device(self, batch):
         gpu_id = 0
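The diff above folds three copies of the batch-transfer-plus-autocast boilerplate into a single `_step` dispatcher that takes the model's step method as an argument. A minimal, self-contained sketch of the same pattern (the `Model` and `Accelerator` classes and the `use_native_amp` flag are illustrative stand-ins, not Lightning's actual API):

from typing import Any, Callable, List

import torch


class Model:
    # Stand-in for a LightningModule: one method per step type.
    def training_step(self, batch: torch.Tensor) -> torch.Tensor:
        return batch.sum()

    def validation_step(self, batch: torch.Tensor) -> torch.Tensor:
        return batch.mean()


class Accelerator:
    # Sketch of the consolidated _step dispatcher shown in the diff.
    def __init__(self, model: Model, use_native_amp: bool = False):
        self.model = model
        self.use_native_amp = use_native_amp

    def to_device(self, batch: torch.Tensor) -> torch.Tensor:
        # The real accelerator moves the batch to the GPU; a no-op
        # here so the sketch runs on CPU.
        return batch

    def _step(self, model_step: Callable, args: List[Any]) -> Any:
        # Move the batch (by convention args[0]) to the device once,
        # then call the given step, under autocast when AMP is enabled.
        args[0] = self.to_device(args[0])
        if self.use_native_amp:
            with torch.cuda.amp.autocast():
                return model_step(*args)
        return model_step(*args)

    def training_step(self, args: List[Any]) -> Any:
        return self._step(self.model.training_step, args)

    def validation_step(self, args: List[Any]) -> Any:
        return self._step(self.model.validation_step, args)


acc = Accelerator(Model())
print(acc.training_step([torch.ones(4)]))    # tensor(4.)
print(acc.validation_step([torch.ones(4)]))  # tensor(1.)

Passing the bound method (`self.model.training_step`) rather than a string keeps the dispatcher type-checkable and avoids a lookup on every step.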

pytorch_lightning/accelerators/tpu_accelerator.py

Lines changed: 8 additions & 16 deletions

@@ -14,7 +14,7 @@
 import io
 import os
 import re
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.multiprocessing as mp

@@ -145,26 +145,18 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         # persist info in spawn
         self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

+    def _step(self, model_step: Callable, args):
+        args[0] = self.to_device(args[0])
+        return model_step(*args)
+
     def training_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.training_step(*args)
-        return output
+        return self._step(self.trainer.model.training_step, args)

     def validation_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)

     def test_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)

     def process_dataloader(self, dataloader):
         device = xm.xla_device(self.trainer.tpu_id)
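The TPU version of `_step` is simpler: there is no native-AMP branch, so the method reduces to moving the batch and dispatching. One subtlety worth noting is that `args` is mutated in place, since `args[0]` is replaced with the device-side batch, so callers see the moved batch afterwards. A small illustrative sketch (the `to_device` stub stands in for the real XLA transfer):

from typing import Any, Callable, List

import torch


def to_device(batch: torch.Tensor) -> torch.Tensor:
    # The real TPU accelerator sends the batch to xm.xla_device(...);
    # kept as a no-op so this runs anywhere.
    return batch


def _step(model_step: Callable, args: List[Any]) -> Any:
    # TPU variant: no AMP branch, just transfer + dispatch.
    args[0] = to_device(args[0])
    return model_step(*args)


args = [torch.arange(3.0), 0]  # [batch, batch_idx]
out = _step(lambda batch, batch_idx: batch * 2, args)
print(out)      # tensor([0., 2., 4.])
print(args[0])  # args[0] now holds the (moved) batch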
