 import torch
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 
-dp_factory = {
-    'cuda': MMDataParallel,
-    'cpu': MMDataParallel
-}
+dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}
+
+ddp_factory = {'cuda': MMDistributedDataParallel}
 
-ddp_factory = {
-    'cuda': MMDistributedDataParallel
-}
 
 def build_dp(model, device='cuda', dim=0, *args, **kwargs):
     """build DataParallel module by device type.
-
+
     if device is cuda, return a MMDataParallel model; if device is mlu,
     return a MLUDataParallel model.
 
     Args:
-        model(:class:`nn.Module`): model to be parallelized.
-        device(str): device type, cuda, cpu or mlu. Defaults to cuda.
-        dim(int): Dimension used to scatter the data. Defaults to 0.
+        model (:class:`nn.Module`): model to be parallelized.
+        device (str): device type, cuda, cpu or mlu. Defaults to cuda.
+        dim (int): Dimension used to scatter the data. Defaults to 0.
 
     Returns:
-        model(nn.Module): the model to be parallelized.
+        nn.Module: the model to be parallelized.
     """
     if device == 'cuda':
         model = model.cuda()
@@ -38,15 +34,15 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs):
 def build_ddp(model, device='cuda', *args, **kwargs):
     """Build DistributedDataParallel module by device type.
 
-    If device is cuda, return a MMDistributedDataParallel model; if device is mlu,
-    return a MLUDistributedDataParallel model.
+    If device is cuda, return a MMDistributedDataParallel model;
+    if device is mlu, return a MLUDistributedDataParallel model.
 
     Args:
-        model(:class:`nn.Module`): module to be parallelized.
-        device(str): device type, mlu or cuda.
+        model (:class:`nn.Module`): module to be parallelized.
+        device (str): device type, mlu or cuda.
 
     Returns:
-        model(:class:`nn.Module`): the module to be parallelized
+        :class:`nn.Module`: the module to be parallelized
 
     References:
         .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
@@ -56,19 +52,23 @@ def build_ddp(model, device='cuda', *args, **kwargs):
     if device == 'cuda':
         model = model.cuda()
     elif device == 'mlu':
-        from mmcv.device.mlu import MLUDistributedDataParallel
+        from mmcv.device.mlu import MLUDistributedDataParallel
         ddp_factory['mlu'] = MLUDistributedDataParallel
         model = model.mlu()
 
     return ddp_factory[device](model, *args, **kwargs)
 
+
 def is_mlu_available():
-    """ Returns a bool indicating if MLU is currently available. """
+    """Returns a bool indicating if MLU is currently available."""
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
 
+
 def get_device():
-    """ Returns an available device, cpu, cuda or mlu. """
-    is_device_available = {'cuda': torch.cuda.is_available(),
-                           'mlu': is_mlu_available()}
-    device_list = [k for k, v in is_device_available.items() if v]
+    """Returns an available device, cpu, cuda or mlu."""
+    is_device_available = {
+        'cuda': torch.cuda.is_available(),
+        'mlu': is_mlu_available()
+    }
+    device_list = [k for k, v in is_device_available.items() if v]
     return device_list[0] if len(device_list) == 1 else 'cpu'
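For orientation, here is a minimal sketch of how the single-process helpers above are typically combined in a training script. It assumes the functions are importable from this module (the `mmdet.utils` path is an assumption) and uses a toy `nn.Conv2d` as a stand-in for a real detector:

import torch.nn as nn

from mmdet.utils import build_dp, get_device  # assumed import path

# Stand-in model; any nn.Module would do here.
model = nn.Conv2d(3, 8, kernel_size=3)

device = get_device()  # 'cuda' or 'mlu' when exactly one is available, else 'cpu'
model = build_dp(model, device=device, dim=0)  # wraps the model for the detected device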
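And a similarly hedged sketch of the distributed path. It assumes the process group has already been initialised (for example via mmcv.runner.init_dist) and that the launcher sets LOCAL_RANK, as torchrun and torch.distributed.launch do; extra keyword arguments are simply forwarded to the wrapper class:

import os

from mmdet.utils import build_ddp, get_device  # assumed import path

local_rank = int(os.environ.get('LOCAL_RANK', 0))
model = build_ddp(
    model,                    # an nn.Module, e.g. from the previous sketch
    device=get_device(),      # expected to be 'cuda' or 'mlu' in a distributed run
    device_ids=[local_rank],  # forwarded to (MM|MLU)DistributedDataParallel
    broadcast_buffers=False)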