
Commit 8c88d94

Better than PL auto strategy selection (#3)

* Better than PL auto strategy
* Update README

1 parent fbd7ee6

File tree: 3 files changed (+86, −23 lines)

README.md

Lines changed: 0 additions & 9 deletions

```diff
@@ -5,15 +5,6 @@ A collection of neural vocoders suitable for singing voice synthesis tasks.
 ## If you have any questions, please open an issue.
 
 
-
-# Using DDP
-```
-pl_trainer_strategy:
-  name: ddp
-  process_group_backend: nccl
-  find_unused_parameters: true
-```
-
 # Preprocessing
 python [process.py](process.py) --config <config file> --num_cpu <number of parallel workers> --strx <1 = force absolute paths, 0 = relative paths>
```
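The removed README section documented how to enable DDP manually. With this commit, strategy selection happens automatically, and judging from the new `get_strategy` signature (see `utils/training_utils.py` below), the same overrides can presumably still be supplied through the `pl_trainer_strategy` config entry. A hedged sketch of an equivalent direct call; the import path and the device/precision values are assumptions, not part of this diff:

```python
# Sketch only: the README's old DDP options routed through the new
# get_strategy() from this commit. Values marked "assumed" are illustrative.
from utils.training_utils import get_strategy  # assumed import path

strategy = get_strategy(
    devices=4,                # assumed: 4 GPUs on a single node
    num_nodes=1,
    accelerator="gpu",
    strategy={
        "name": "ddp",
        "process_group_backend": "nccl",
        "find_unused_parameters": True,
    },
    precision="16-mixed",     # assumed precision flag
)
```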

train.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -61,7 +61,13 @@ def train(config, exp_name, work_dir):
         accelerator=config['pl_trainer_accelerator'],
         devices=config['pl_trainer_devices'],
         num_nodes=config['pl_trainer_num_nodes'],
-        strategy=get_strategy(config['pl_trainer_strategy']),
+        strategy=get_strategy(
+            config['pl_trainer_devices'],
+            config['pl_trainer_num_nodes'],
+            config['pl_trainer_accelerator'],
+            config['pl_trainer_strategy'],
+            config['pl_trainer_precision'],
+        ),
         precision=config['pl_trainer_precision'],
         callbacks=[
             DsModelCheckpoint(
```
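The call site now forwards the device, node, accelerator, and precision settings alongside the strategy dict, so `get_strategy` can reproduce Lightning's own auto-selection logic. A hypothetical config dict covering the five `pl_trainer_*` keys read above; only the key names come from the diff, the values are illustrative defaults:

```python
# Hypothetical config values; only the key names are taken from the diff.
config = {
    'pl_trainer_accelerator': 'auto',
    'pl_trainer_devices': 'auto',
    'pl_trainer_num_nodes': 1,
    'pl_trainer_strategy': {'name': 'auto'},  # extra keys become strategy kwargs
    'pl_trainer_precision': '32-true',        # assumed Lightning precision flag
}
```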

utils/training_utils.py

Lines changed: 79 additions & 13 deletions

```diff
@@ -304,16 +304,82 @@ def get_metrics(self, trainer, model):
         return items
 
 
-def get_strategy(strategy):
-    if strategy['name'] == 'auto':
-        return 'auto'
-
-    from lightning.pytorch.strategies import StrategyRegistry
-    if strategy['name'] not in StrategyRegistry:
-        available_names = ", ".join(sorted(StrategyRegistry.keys())) or "none"
-        raise ValueError(f"Invalid strategy name {strategy['name']}. Available names: {available_names}")
-
-    data = StrategyRegistry[strategy['name']]
-    params = data['init_params']
-    params.update({k: v for k, v in strategy.items() if k != 'name'})
-    return data['strategy'](**utils.filter_kwargs(params, data['strategy']))
+def get_strategy(
+        devices="auto",
+        num_nodes=1,
+        accelerator="auto",
+        strategy={"name": "auto"},
+        precision=None,
+):
+    from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
+    from lightning.pytorch.accelerators import AcceleratorRegistry
+    from lightning.pytorch.accelerators.cuda import CUDAAccelerator
+    from lightning.pytorch.accelerators.mps import MPSAccelerator
+    from lightning.pytorch.strategies import Strategy, SingleDeviceStrategy, StrategyRegistry
+    from lightning.pytorch.trainer.connectors import accelerator_connector
+    from lightning.pytorch.utilities.rank_zero import rank_zero_warn
+    class _DsAcceleratorConnector(accelerator_connector._AcceleratorConnector):
+        def __init__(self) -> None:
+            accelerator_connector._register_external_accelerators_and_strategies()
+            self._registered_strategies = StrategyRegistry.available_strategies()
+            self._accelerator_types = AcceleratorRegistry.available_accelerators()
+            self._parallel_devices = []
+            self._check_config_and_set_final_flags(
+                strategy=strategy["name"],
+                accelerator=accelerator,
+                precision=precision,
+                plugins=[],
+                sync_batchnorm=False,
+            )
+            if self._accelerator_flag == "auto":
+                self._accelerator_flag = self._choose_auto_accelerator()
+            elif self._accelerator_flag == "gpu":
+                self._accelerator_flag = self._choose_gpu_accelerator_backend()
+            self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
+            self._set_parallel_devices_and_init_accelerator()
+            if self._strategy_flag == "auto":
+                self._strategy_flag = self._choose_strategy()
+            self._check_strategy_and_fallback()
+            self._init_strategy()
+            for k in ["colossalai", "bagua", "hpu", "hpu_parallel", "hpu_single", "ipu", "ipu_strategy"]:
+                if k in StrategyRegistry:
+                    StrategyRegistry.remove(k)
+
+        def _init_strategy(self) -> None:
+            assert isinstance(self._strategy_flag, (str, Strategy))
+            if isinstance(self._strategy_flag, str):
+                if self._strategy_flag not in StrategyRegistry:
+                    available_names = ", ".join(sorted(StrategyRegistry.available_strategies())) or "none"
+                    raise KeyError(f"Invalid strategy name {strategy['name']}. Available names: {available_names}")
+                data = StrategyRegistry[self._strategy_flag]
+                params = {}
+                # Replicate additional logic from _choose_strategy when dealing with single-device strategies
+                if issubclass(data["strategy"], SingleDeviceStrategy):
+                    if self._accelerator_flag == "hpu":
+                        params = {"device": torch.device("hpu")}
+                    elif self._accelerator_flag == "tpu":
+                        params = {"device": self._parallel_devices[0]}
+                    elif data["strategy"] is SingleDeviceStrategy:
+                        if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
+                                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
+                        ):
+                            params = {"device": _determine_root_gpu_device(self._parallel_devices)}
+                        else:
+                            params = {"device": "cpu"}
+                    else:
+                        raise NotImplementedError
+                params.update(data["init_params"])
+                params.update({k: v for k, v in strategy.items() if k != "name"})
+                self.strategy = data["strategy"](**utils.filter_kwargs(params, data["strategy"]))
+            elif isinstance(self._strategy_flag, SingleDeviceStrategy):
+                params = {"device": self._strategy_flag.root_device}
+                params.update({k: v for k, v in strategy.items() if k != "name"})
+                self.strategy = self._strategy_flag.__class__(**utils.filter_kwargs(params, self._strategy_flag.__class__))
+            else:
+                rank_zero_warn(
+                    f"Inferred strategy {self._strategy_flag.__class__.__name__} cannot take custom configurations. "
+                    f"To use custom configurations, please specify the strategy name explicitly."
+                )
+                self.strategy = self._strategy_flag
+
+    return _DsAcceleratorConnector().strategy
```
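The new implementation subclasses Lightning's internal `_AcceleratorConnector` to reuse its accelerator and strategy auto-selection, then overrides `_init_strategy` so that user-supplied keys from the strategy dict reach the chosen strategy's constructor. `utils.filter_kwargs` is referenced in both the old and new code but is not part of this diff; from the call sites it presumably keeps only the keyword arguments the target class's constructor accepts, so registry defaults and user overrides a given strategy does not understand are dropped instead of raising `TypeError`. A minimal sketch of that assumed behavior (not the repository's actual implementation):

```python
import inspect

def filter_kwargs(kwargs, cls):
    # Assumed behavior: keep only the keys that cls.__init__ accepts,
    # so unknown config overrides never reach the constructor.
    accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return {k: v for k, v in kwargs.items() if k in accepted}
```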
