 # limitations under the License.
 import logging
 import os
-from typing import Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

-import torch
 import torch.distributed

 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -45,9 +45,15 @@ def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
         parallel_devices: Optional[List[torch.device]] = None,
+        cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
+        ddp_comm_state: Optional[object] = None,
+        ddp_comm_hook: Optional[Callable] = None,
+        ddp_comm_wrapper: Optional[Callable] = None,
+        model_averaging_period: Optional[int] = None,
         process_group_backend: Optional[str] = "hccl",
+        **kwargs: Any,
     ) -> None:

         if not _HPU_AVAILABLE:
@@ -56,9 +62,15 @@ def __init__(
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
+            cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io or HPUCheckpointIO(),
             precision_plugin=precision_plugin,
+            ddp_comm_state=ddp_comm_state,
+            ddp_comm_hook=ddp_comm_hook,
+            ddp_comm_wrapper=ddp_comm_wrapper,
+            model_averaging_period=model_averaging_period,
             process_group_backend=process_group_backend,
+            **kwargs,
         )

     def setup_environment(self) -> None:
@@ -75,7 +87,7 @@ def setup_environment(self) -> None:
     def determine_ddp_device_ids(self) -> None:
         return None

-    def pre_configure_ddp(self):  # type: ignore
+    def _pre_configure_ddp(self) -> None:
         # if unset, default `find_unused_parameters` `True`
         # Many models require setting this parameter to True, as there are corner cases
         # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
@@ -97,7 +109,7 @@ def configure_ddp(self) -> None:
         # DDP does not accept static graph as param with torch < 1.11
         if _TORCH_LESSER_EQUAL_1_10_2:
             log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
-            self.pre_configure_ddp()
+            self._pre_configure_ddp()
             self.model = self._setup_model(LightningDistributedModule(self.model))  # type: ignore
             if self.root_device.type == "hpu" and self._static_graph:
                 self._model._set_static_graph()  # type: ignore
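For context, a minimal usage sketch (not part of this diff) of how the newly plumbed-through arguments could be exercised once this change lands. The `fp16_compress_hook` import is the stock hook shipped with torch.distributed, and whether that particular hook is appropriate on HPU is an assumption here, not something the PR asserts; the extra `find_unused_parameters` kwarg relies on the parent DDPStrategy forwarding `**kwargs` to DistributedDataParallel.

# Sketch only: exercises the constructor arguments added in this diff.
import pytorch_lightning as pl
from pytorch_lightning.strategies import HPUParallelStrategy
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

strategy = HPUParallelStrategy(
    ddp_comm_hook=default_hooks.fp16_compress_hook,  # gradient-compression hook forwarded to DDP
    find_unused_parameters=False,  # extra DDP kwarg absorbed by the new **kwargs
)
trainer = pl.Trainer(accelerator="hpu", devices=8, strategy=strategy)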