
Commit ed5e1b7

Merge pull request #38 from OpenBMB/UPD_0708
Upd 0708
2 parents 3ed3e3c + 05a4804

7 files changed: +46 −20 lines


bmtrain/block_layer.py

Lines changed: 4 additions & 4 deletions
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, Iterator, Tuple, Union
+from typing import Dict, Iterable, Iterator, Union
 
 
 from .global_var import config
@@ -8,7 +8,6 @@
 from .parameter import DistributedParameter, OpAllGather
 from .checkpointing import ScopedTensorInspectorContext
 from . import debug
-from torch.nn.modules.module import _addindent
 import copy
 
 def round_up(x, d):
@@ -331,7 +330,8 @@ def __init__(self, inner_module : torch.nn.Module):
 
         # calc total number of parameters
         for name, param in ordered_parameters:
-            assert isinstance(param, DistributedParameter), "All parameters in checkpoint block must be DistributedParameter."
+            if not isinstance(param, DistributedParameter):
+                raise ValueError("All parameters in checkpoint block must be DistributedParameter.")
 
             storage_type = storage_type_cuda(param.storage_type())
             kw_name = _get_param_kw(param)
@@ -464,7 +464,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
        # gather here
        with torch.no_grad():
            with CheckpointBlockContext(self):
-                return self._module.state_dict(destination, prefix, keep_vars)
+                return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
 
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
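
The switch to keyword arguments above tracks PyTorch's deprecation of positional destination/prefix/keep_vars arguments to Module.state_dict. A minimal sketch of the keyword-only calling convention on a plain torch module (illustrative only, not BMTrain-specific):

import torch

# Newer PyTorch releases warn on positional arguments to state_dict,
# so the keyword form used in this commit stays compatible.
linear = torch.nn.Linear(4, 4)
sd = linear.state_dict(prefix="linear.", keep_vars=False)
print(list(sd.keys()))  # ['linear.weight', 'linear.bias']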

bmtrain/distributed/ops.py

Lines changed: 7 additions & 1 deletion
@@ -32,13 +32,16 @@ def all_gather(x : torch.Tensor):
     Returns:
         torch.Tensor: The gathered tensor of shape (world_size, ...).
     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     assert x.is_cuda
     return OpAllGather.apply(x)
 
 class OpAllReduce(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input : torch.Tensor, op : str):
-        if not input.contiguous():
+        if not input.is_contiguous():
             input = input.contiguous()
         if input.storage_offset() != 0 or input.storage().size() != input.numel():
             input = input.clone()
@@ -82,6 +85,9 @@ def all_reduce(x : torch.Tensor, op : str = "sum"):
         torch.Tensor: The reduced tensor of shape (...).
 
     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     assert x.is_cuda
     return OpAllReduce.apply(x, op)
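
A short sketch of the new fail-fast behaviour (hypothetical usage; assumes a CUDA device and a normal distributed launch for the post-init calls):

import torch
import bmtrain as bmt

# Before init_distributed(), the collectives now raise a readable
# RuntimeError instead of failing inside NCCL.
try:
    bmt.distributed.all_reduce(torch.ones(2, device="cuda"), "sum")
except RuntimeError as e:
    print(e)  # BMTrain is not initialized

bmt.init_distributed()
x = torch.ones(2, device="cuda")
summed = bmt.distributed.all_reduce(x, "sum")  # shape (2,)
stacked = bmt.distributed.all_gather(x)        # shape (world_size, 2)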

bmtrain/global_var.py

Lines changed: 2 additions & 1 deletion
@@ -14,10 +14,11 @@ class ConfigMap(TypedDict):
     loss_scale_steps : int
 
     gradient_inspect : bool
+    initialized : bool
 
     comm : 'NCCLCommunicator'
 
-config = ConfigMap()
+config = ConfigMap(rank=0, local_rank=0, world_size=1, initialized=False)
 
 def rank():
     """

bmtrain/init.py

Lines changed: 4 additions & 1 deletion
@@ -6,7 +6,6 @@
 from .utils import print_dict
 from .global_var import config
 from . import nccl
-import time
 from .synchronize import synchronize
 def init_distributed(
         init_method : str = "env://",
@@ -57,6 +56,7 @@ def init_distributed(
     store = dist.PrefixStore("bmtrain", store)
     torch.cuda.set_device(local_rank)
 
+    config["initialized"] = True
     config["local_rank"] = local_rank
     config["local_size"] = local_size
     config["rank"] = rank
@@ -110,3 +110,6 @@ def init_distributed(
         "cpus": cpus_this_worker
     })
     synchronize()
+
+def is_initialized() -> bool:
+    return config["initialized"]
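
The new flag makes setup idempotent. A minimal sketch built on the is_initialized() helper added above (imported directly from bmtrain.init in case it is not re-exported at the package level):

import bmtrain as bmt
from bmtrain.init import is_initialized

def ensure_bmtrain():
    # Safe to call from multiple entry points: only the first call
    # actually initializes NCCL and fills in the config.
    if not is_initialized():
        bmt.init_distributed()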

bmtrain/parameter.py

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ def __new__(cls,
         init_method : Optional[Callable[['DistributedParameter'], None]] = None,
         group : Optional[str] = None
     ):
+        if not config["initialized"]:
+            raise RuntimeError("BMTrain is not initialized")
+
         num_of_elements = data.numel()
 
         cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda")
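
With this guard, constructing a parameter too early fails loudly. A hypothetical repro (assuming DistributedParameter is exported at the package level):

import torch
import bmtrain as bmt

# Without init_distributed(), the constructor now raises immediately
# instead of failing while partitioning storage across ranks.
try:
    w = bmt.DistributedParameter(torch.empty(1024, 1024))
except RuntimeError as e:
    print(e)  # BMTrain is not initialized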

bmtrain/store.py

Lines changed: 8 additions & 3 deletions
@@ -5,13 +5,14 @@
 from .block_layer import CheckpointBlock
 from . import nccl
 import io, pickle
+from typing import Mapping
 
 def _save_to_state_dict(model : torch.nn.Module, destination, prefix):
     if isinstance(model, CheckpointBlock):
         if config['rank'] != 0:
             destination = OrderedDict() # creates a temporary ordered dict
             destination._metadata = OrderedDict()
-        model.state_dict(destination, prefix, False)
+        model.state_dict(destination=destination, prefix=prefix, keep_vars=False)
     else:
         if config['rank'] != 0:
             destination = OrderedDict() # creates a temporary ordered dict
@@ -109,8 +110,8 @@ def broadcast_object(obj):
         obj = _unpickler(io.BytesIO(buf)).load()
     return obj
 
-
-class DistributedStateDictWrapper:
+# Must be a Mapping after pytorch 1.12.0
+class DistributedStateDictWrapper(Mapping):
     def __init__(self, state_dict : Dict) -> None:
         self._state_dict = state_dict
         self._metadata = broadcast_object(getattr(state_dict, "_metadata", None))
@@ -176,6 +177,10 @@ def __contains__(self, key : str):
     def keys(self):
         return broadcast_object(list(self._state_dict.keys()))
 
+    def __iter__(self):
+        # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`.
+        return iter(self.keys())
+
 def load(model : torch.nn.Module, file_name : str, strict : bool = True):
     """Loads the model from the file.
 
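
Why __iter__ is the missing piece: typing.Mapping is an alias of collections.abc.Mapping, whose abstract interface is __getitem__, __len__ and __iter__; a subclass that omits any of them cannot be instantiated. A self-contained sketch (the real wrapper presumably supplies __getitem__ and __len__ in code not shown in this diff):

from collections import OrderedDict
from typing import Mapping

class TinyWrapper(Mapping):
    def __init__(self, d):
        self._d = d

    def __getitem__(self, key):
        return self._d[key]

    def __len__(self):
        return len(self._d)

    def __iter__(self):  # the counterpart of the method this commit adds
        return iter(self._d)

w = TinyWrapper(OrderedDict(weight=1))
print(isinstance(w, Mapping), list(w))  # True ['weight']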

bmtrain/synchronize.py

Lines changed: 18 additions & 10 deletions
@@ -1,17 +1,24 @@
 import torch
-from . import nccl
+from . import distributed, nccl
 from .global_var import config
+import warnings
 
 def synchronize():
     """
     Synchronize all the workers across all nodes. (both CPU and GPU are synchronized)
     """
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     with torch.cuda.stream(config['barrier_stream']):
         barrier = torch.cuda.FloatTensor([1])
         nccl.allReduce(barrier.storage(), barrier.storage(), 'sum', config['comm'])
     config['barrier_stream'].synchronize()
 
 def wait_loader():
+    if not config["initialized"]:
+        raise RuntimeError("BMTrain is not initialized")
+
     # wait for the latest loader event, and set a new one
     config['load_event'].synchronize()
     config['calc_stream'].record_event(config['load_event'])
@@ -23,22 +30,23 @@ def sum_loss(loss : torch.Tensor):
 
     This is a helper function to reduce the loss across all workers.
     """
-    ret = torch.empty_like(loss)
-    nccl.allReduce(
-        loss.storage(),
-        ret.storage(),
-        'avg',
-        config['comm']
-    )
-    return ret
+    warnings.warn("bmtrain.sum_loss is deprecated and will be removed in a later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning)
+    return distributed.all_reduce(loss, "avg")
 
 def gather_result(result: torch.Tensor):
+    warnings.warn("bmtrain.gather_result is deprecated and will be removed in a later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)
+
+    output_cuda = True
     if not result.is_cuda:
         result = result.cuda()
+        output_cuda = False
     ret = torch.empty((result.shape[0]*config['world_size'], *list(result.shape[1:])), device=result.device, dtype=result.dtype)
     nccl.allGather(
         result.storage(),
         ret.storage(),
         config['comm']
     )
-    return ret
+    if output_cuda:
+        return ret
+    else:
+        return ret.cpu()
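
A migration sketch for the deprecated helpers (hypothetical usage under a normal distributed launch):

import torch
import bmtrain as bmt

bmt.init_distributed()
loss = torch.tensor(2.0, device="cuda")

# Old helpers still work but now emit DeprecationWarning...
avg = bmt.sum_loss(loss)

# ...the replacements live under bmtrain.distributed. Note the shape
# difference: all_gather stacks a new leading world_size dimension,
# while the old gather_result concatenated along dim 0.
avg = bmt.distributed.all_reduce(loss, "avg")
gathered = bmt.distributed.all_gather(loss.unsqueeze(0))  # (world_size, 1)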
