
Commit 87f01ca

Break hpu graphs into two for better performance (#14656)

Authored by jerome-habana with pre-commit-ci[bot], Borda, and kaushikb11, and committed.

Signed-off-by: Jerome <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Kaushik B <[email protected]>

1 parent: 39f899a

File tree: 3 files changed (+38, -8 lines)


src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fixed the availability check for the neptune-client package ([#14714](https://github.com/Lightning-AI/lightning/pull/14714))
+- Break HPU Graphs into two parts (forward + backward as one and optimizer as another) for better performance ([#14656](https://github.com/Lightning-AI/lightning/pull/14656))
 
 
 ## [1.7.6] - 2022-09-13
```
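For context: in Habana's lazy execution mode, operations accumulate into a graph until `htcore.mark_step()` is called, at which point the graph is compiled and executed. Previously the HPU strategies called `mark_step()` once per training step; this commit splits each step into two graphs, one for forward + backward and one for the optimizer update. Below is a minimal sketch (not from this commit) of what that split means in a raw PyTorch loop; `loader` is a hypothetical DataLoader of `(x, y)` batches, and the model and optimizer are placeholders.

```python
# Minimal sketch of the two-graph split in raw Habana lazy mode.
# Assumes an HPU device is available; `loader` is hypothetical.
import torch
import habana_frameworks.torch.core as htcore

device = torch.device("hpu")
net = torch.nn.Linear(10, 2).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(net.parameters(), lr=0.1)

for x, y in loader:
    x, y = x.to(device), y.to(device)
    opt.zero_grad()
    loss = loss_fn(net(x), y)
    loss.backward()
    htcore.mark_step()  # graph 1: forward + backward
    opt.step()
    htcore.mark_step()  # graph 2: optimizer update
```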

src/pytorch_lightning/strategies/hpu_parallel.py

Lines changed: 18 additions & 4 deletions

```diff
@@ -13,9 +13,11 @@
 # limitations under the License.
 import logging
 import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch.distributed
+from torch.nn import Module
+from torch.optim.optimizer import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
@@ -137,10 +139,22 @@ def broadcast(self, obj: object, src: int = 0) -> object:  # type: ignore
         broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
-    def training_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
-        # Break lazy accumulation of graph after every step
+    def on_after_backward(self) -> None:
+        # Break lazy accumulation of graph after fwd+bwd
         htcore.mark_step()
-        return step_output
+
+    def optimizer_step(
+        self,
+        optimizer: Optimizer,
+        opt_idx: int,
+        closure: Callable[[], Any],
+        model: Optional[Union["pl.LightningModule", Module]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
+        # Break lazy accumulation of graph after optimizer
+        htcore.mark_step()
+        return optimizer_output
 
     def validation_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
         # Break lazy accumulation of graph after every step
```
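In Lightning's hook order, `on_after_backward` runs right after `loss.backward()` and before the strategy's `optimizer_step`, so the two `mark_step()` calls bracket exactly the forward + backward graph and the optimizer graph. A hedged usage sketch for the parallel strategy follows; `LitModel` and `dm` are hypothetical stand-ins for a LightningModule and a datamodule.

```python
# Sketch: multi-device HPU training with PyTorch Lightning 1.7.x,
# using the strategy modified in this diff.
import pytorch_lightning as pl
from pytorch_lightning.strategies import HPUParallelStrategy

trainer = pl.Trainer(
    accelerator="hpu",
    devices=8,  # one process per HPU
    strategy=HPUParallelStrategy(),
)
trainer.fit(LitModel(), datamodule=dm)  # LitModel and dm are hypothetical
```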

src/pytorch_lightning/strategies/single_hpu.py

Lines changed: 19 additions & 4 deletions

```diff
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union
+
+from torch.nn import Module
+from torch.optim.optimizer import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -78,10 +81,22 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
     def model_to_device(self) -> None:
         self.model.to(self.root_device)  # type: ignore
 
-    def training_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
-        # Break lazy accumulation of graph after every step
+    def on_after_backward(self) -> None:
+        # Break lazy accumulation of graph after fwd+bwd
         htcore.mark_step()
-        return step_output
+
+    def optimizer_step(
+        self,
+        optimizer: Optimizer,
+        opt_idx: int,
+        closure: Callable[[], Any],
+        model: Optional[Union["pl.LightningModule", Module]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
+        # Break lazy accumulation of graph after optimizer
+        htcore.mark_step()
+        return optimizer_output
 
     def validation_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
         # Break lazy accumulation of graph after every step
```
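The single-device strategy receives the identical split. As a hedged sketch, with `accelerator="hpu"` and one device the Trainer resolves to `SingleHPUStrategy`, so the hooks above run once per batch in the order shown in the comments; `LitModel` is again hypothetical.

```python
# Sketch: single-HPU training; Trainer picks SingleHPUStrategy
# automatically for accelerator="hpu" with one device.
import pytorch_lightning as pl

trainer = pl.Trainer(accelerator="hpu", devices=1, max_epochs=1)
trainer.fit(LitModel())  # LitModel is a hypothetical LightningModule
# Per batch: forward -> backward -> on_after_backward (mark_step #1)
#            -> optimizer_step (mark_step #2)
```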
