|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -from typing import Any, Optional, Union |
| 14 | +from typing import Any, Callable, Optional, Union |
15 | 15 |
|
16 | 16 | import torch |
17 | 17 |
|
@@ -66,53 +66,25 @@ def train(self): |
66 | 66 | results = self.train_or_test() |
67 | 67 | return results |
68 | 68 |
|
def _step(self, model_step: Callable, args):
    """Move the batch to the device, then invoke ``model_step``.

    Wraps the call in ``torch.cuda.amp.autocast`` when the trainer is
    configured for native AMP; otherwise calls it directly.

    Args:
        model_step: the bound model hook to run (``training_step``,
            ``validation_step`` or ``test_step``).
        args: positional args for the hook; ``args[0]`` is the batch,
            which is replaced in place by its on-device copy.

    Returns:
        Whatever ``model_step`` returns.
    """
    # Batch is always the first positional argument by Lightning convention.
    args[0] = self.to_device(args[0])

    if self.trainer.amp_backend == AMPType.NATIVE:
        # Native mixed precision: run the hook under autocast.
        with torch.cuda.amp.autocast():
            return model_step(*args)

    return model_step(*args)
|
def training_step(self, args):
    """Dispatch one training step through the shared ``_step`` helper."""
    step_fn = self.trainer.model.training_step
    return self._step(step_fn, args)
|
def validation_step(self, args):
    """Dispatch one validation step through the shared ``_step`` helper."""
    step_fn = self.trainer.model.validation_step
    return self._step(step_fn, args)
|
101 | 86 | def test_step(self, args): |
102 | | - if self.trainer.amp_backend == AMPType.NATIVE: |
103 | | - with torch.cuda.amp.autocast(): |
104 | | - output = self.__test_step(args) |
105 | | - else: |
106 | | - output = self.__test_step(args) |
107 | | - |
108 | | - return output |
109 | | - |
110 | | - def __test_step(self, args): |
111 | | - batch = args[0] |
112 | | - batch = self.to_device(batch) |
113 | | - args[0] = batch |
114 | | - output = self.trainer.model.test_step(*args) |
115 | | - return output |
| 87 | + return self._step(self.trainer.model.test_step, args) |
116 | 88 |
|
117 | 89 | def to_device(self, batch): |
118 | 90 | gpu_id = 0 |
|
0 commit comments