Commit f85e352

skydoorkai and 乙划 authored
sync to internal 1.6.0.dev3 (#4)
Co-authored-by: 乙划 <[email protected]>
1 parent 4af4425 commit f85e352

File tree

197 files changed: +17163 -3211 lines


.isort.cfg

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [settings]
 multi_line_output=3
 line_length=120
-known_third_party = GPy,accelerate,agd,apex,data,datasets,deepspeed,distutils,dlrover,einops,evaluate,example_utils,fairscale,flash_attn,google,grpc,instruction_dataset_utils,matplotlib,model,modeling,networkx,numpy,packaging,pandas,peft,psutil,pymoo,pyomo,pytest,redis,safetensors,scipy,sklearn,tiktoken,torch,torch_npu,torchvision,tqdm,transformers,triton,typing_extensions,utils,yaml
+known_third_party = accelerate,agd,apex,datasets,deepspeed,distutils,dlrover,einops,evaluate,example_utils,fairscale,flash_attn,google,grpc,instruction_dataset_utils,matplotlib,megatron,model,model_registry,moe_modules,networkx,numpy,packaging,pandas,peft,psutil,pytest,redis,safetensors,scipy,seaborn,sklearn,tiktoken,torch,torch_npu,torchvision,tqdm,transformers,triton,typing_extensions,utils,wrapt,yaml
 include_trailing_comma=True

atorch/auto/accelerate.py

Lines changed: 27 additions & 6 deletions
@@ -43,25 +43,46 @@ def model_transform(
     assert not strategy.is_tunable()
     record_user_defined_half_precision_dtype(strategy)
     cpu_offload = False
+    has_fsdp2 = False
+    param_init_by_user_fn_already = False
     for opt in strategy:
         opt_name = opt[0]
         opt_config = opt[1]
         model_context = opt_lib[opt_name].transform(model_context, opt_config)
         if opt_name == "fsdp" and opt_config is not None and opt_config.get("cpu_offload", False) is True:
             cpu_offload = True
+        if opt_name == "fsdp2":
+            has_fsdp2 = True
+            model_device = next(model_context.model.parameters()).device
+            if "param_init_fn" in opt_config and model_device is torch.device("cuda"):
+                param_init_by_user_fn_already = True
+
     model_context.adjust_wrappers()
     if apply_wrapper:
         model_context.apply_wrappers(is_pre_wrapper=True)
     if create_dataloader:
         model_context.update_dataloader()
-    if create_optim:
-        model_context.update_optim()
     if use_sample_batch:
         model_context.update_sample_batch()
-    if apply_wrapper:
-        model_context.apply_wrappers(is_pre_wrapper=False)
-    if torch.cuda.is_available() and not model_context.gpu_used and not cpu_offload:
-        reload_meta_module(model_context.model, torch.device(type="cuda", index=local_rank()), False)
+    if not has_fsdp2:
+        if create_optim:
+            model_context.update_optim()
+        if apply_wrapper:
+            model_context.apply_wrappers(is_pre_wrapper=False)
+
+    if (
+        torch.cuda.is_available()
+        and not model_context.gpu_used
+        and not cpu_offload
+        and not param_init_by_user_fn_already
+    ):
+        reload_meta_module(model_context.model, torch.device(type="cuda", index=local_rank()), False, True, has_fsdp2)
+
+    if has_fsdp2:
+        if create_optim:
+            model_context.update_optim()
+        if apply_wrapper:
+            model_context.apply_wrappers(is_pre_wrapper=False)
     return model_context
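Note: the reordering above defers update_optim until after wrapping whenever "fsdp2" is in the strategy. A minimal sketch of why that ordering matters follows; it is not ATorch code, and fake_shard_ is only a stand-in for an FSDP2-style wrapper that swaps module parameters for new (sharded) tensors, which would leave an optimizer built beforehand pointing at stale parameters.

    # Illustrative only: a stand-in "wrapper" that replaces every parameter object.
    import torch
    import torch.nn as nn

    def fake_shard_(module: nn.Module) -> None:
        # Re-register each parameter as a brand-new tensor, as sharded wrapping would.
        for name, param in list(module.named_parameters(recurse=False)):
            setattr(module, name, nn.Parameter(param.detach().clone()))
        for child in module.children():
            fake_shard_(child)

    model = nn.Linear(4, 2)

    # Wrong order: the optimizer keeps references the wrapper then discards.
    stale_optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    fake_shard_(model)
    assert all(p is not q for p in stale_optim.param_groups[0]["params"] for q in model.parameters())

    # Order enforced by the diff for fsdp2: wrap first, then create the optimizer.
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)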
atorch/auto/clip_grad_norm.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def clip_grad_norm(model, max_norm, norm_type=2, optimizer=None, process_group_n
     Returns:
         Total norm of the parameters (viewed as a single vector) or None if using ds zero optimizer.
     """
-    if isinstance(optimizer, DeepSpeedZeroOptimizer):
+    if DeepSpeedZeroOptimizer is not None and isinstance(optimizer, DeepSpeedZeroOptimizer):
         assert norm_type == 2, "deep speed zero optimizer only supports L2 norm"
         optimizer.clip_grad = max_norm
         return None
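Note: the added None guard matters because DeepSpeedZeroOptimizer is imported optionally and bound to None when DeepSpeed is absent, and isinstance(x, None) raises TypeError. A sketch of the pattern; the exact DeepSpeed import path shown is an assumption and varies between DeepSpeed releases.

    # Optional-import guard pattern (import path is an assumption).
    try:
        from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
    except (ImportError, ModuleNotFoundError):
        DeepSpeedZeroOptimizer = None  # isinstance(x, None) would raise TypeError, hence the guard

    def is_ds_zero_optimizer(optimizer) -> bool:
        return DeepSpeedZeroOptimizer is not None and isinstance(optimizer, DeepSpeedZeroOptimizer)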

atorch/auto/engine/sg_algo/hebo/acq_optimizers/evolution_optimizer.py

Lines changed: 14 additions & 7 deletions
@@ -1,16 +1,20 @@
 import numpy as np
 import pandas as pd
-from pymoo.config import Config
-from pymoo.core.problem import Problem
-from pymoo.factory import get_algorithm, get_crossover, get_mutation
-from pymoo.operators.mixed_variable_operator import MixedVariableCrossover, MixedVariableMutation
-from pymoo.optimize import minimize
+
+try:
+    from pymoo.config import Config
+    from pymoo.core.problem import Problem
+    from pymoo.factory import get_algorithm, get_crossover, get_mutation
+    from pymoo.operators.mixed_variable_operator import MixedVariableCrossover, MixedVariableMutation
+    from pymoo.optimize import minimize
+
+    Config.show_compile_hint = False
+except (ImportError, ModuleNotFoundError):
+    Problem = object

 from atorch.auto.engine.sg_algo.hebo.acquisitions.acq import Acquisition
 from atorch.auto.engine.sg_algo.hebo.design_space.design_space import DesignSpace

-Config.show_compile_hint = False
-

 class BOProblem(Problem):
     def __init__(
@@ -26,6 +30,9 @@ def __init__(
         self.space = space
         self.fix = fix  # NOTE: use self.fix to enable contextual BO

+        if Problem == object:
+            print("Install pymoo==0.5.0 to support evolution optimizer.")
+
         super().__init__(len(lb), xl=lb, xu=ub, n_obj=acq.num_obj, n_constr=acq.num_constr)

     def _evaluate(self, x: np.ndarray, out: dict, *args, **kwargs):
atorch/auto/engine/sg_algo/hebo/models/gauss_process/gpy_wgp.py

Lines changed: 8 additions & 3 deletions
@@ -1,10 +1,15 @@
 import logging
 import warnings

-import GPy
 import numpy as np
-from GPy.models import InputWarpedGP
-from GPy.util.input_warping_functions import KumarWarping
+
+try:
+    import GPy
+    from GPy.models import InputWarpedGP
+    from GPy.util.input_warping_functions import KumarWarping
+except (ImportError, ModuleNotFoundError):
+    print("Install GPy package to support auto training optimization.")
+
 from sklearn.preprocessing import MinMaxScaler, StandardScaler

 from atorch.auto.engine.sg_algo.hebo.models.base_model import BaseModel
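Note: unlike the pymoo hunk, this fallback only prints a hint and binds no substitute symbols, so a missing GPy surfaces later as a NameError when InputWarpedGP is first referenced rather than at import time. A small sketch of that behaviour; fit_warped_gp is illustrative, not the repo's API.

    # Hint-only fallback leaves the optional names unbound.
    try:
        from GPy.models import InputWarpedGP
    except (ImportError, ModuleNotFoundError):
        print("Install GPy package to support auto training optimization.")

    def fit_warped_gp(x, y):
        # Without GPy installed this raises NameError at call time, not at import time.
        return InputWarpedGP(x, y)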

atorch/auto/model_context.py

Lines changed: 40 additions & 12 deletions
@@ -5,7 +5,7 @@
 import types

 try:
-    from collections import abs as collections_abc  # type: ignore[attr-defined]
+    from collections import abc as collections_abc  # type: ignore[attr-defined]
 except ImportError:
     import collections as collections_abc  # type: ignore[no-redef]

@@ -28,6 +28,7 @@
 from atorch.distributed.distributed import (
     get_data_partition_rank_and_size,
     local_rank,
+    parallel_group,
     parallel_group_and_ranks,
     parallel_group_size,
     rank,
@@ -396,6 +397,9 @@ def create_optim(self):
             src = ranks[0]
             torch.distributed._broadcast_coalesced(process_group, module_states, int(250 * 1024 * 1024), src)

+        if "fsdp2" in self.pre_wrappers and parallel_group("expert") is not None:
+            self.optim_args["foreach"] = False
+
         if not self.check_pipe_model():
             if not self.optim_param_func:
                 optim = self.optim_func(self.model.parameters(), **self.optim_args)
@@ -416,6 +420,7 @@ def create_optim(self):
             and "ds_zero" not in self.post_wrappers
             and "zero2" not in self.post_wrappers
             and "fsdp" not in self.pre_wrappers
+            and "fsdp2" not in self.pre_wrappers
             and "ds_3d_parallel" not in self.post_wrappers
         ):
             is_cuda = next(self.model.parameters()).is_cuda
@@ -497,6 +502,8 @@ def adjust_wrappers(self):
                 self.pre_wrappers.pop("zero2")
             if "fsdp" in self.pre_wrappers:
                 self.pre_wrappers.pop("fsdp")
+            if "fsdp2" in self.pre_wrappers:
+                self.pre_wrappers.pop("fsdp2")

             # DDP is supported and handled internally by PiPPy.
             if "ddp" in self.post_wrappers:
@@ -572,13 +579,18 @@ def adjust_wrappers(self):
         ds_3d_parallel_wrapper_exist = "ds_3d_parallel" in self.post_wrappers
         fairscale_zero2_wrapper_exist = "zero2" in self.post_wrappers
         fsdp_wrapper_exist = "fsdp" in self.pre_wrappers or "zero2" in self.pre_wrappers
+        fsdp2_wrapper_exist = "fsdp2" in self.pre_wrappers
         tensor_parallel_wrapper_exist = "tp" in self.pre_wrappers
         ckpt_wrapper_exist = "checkpoint" in self.post_wrappers
         native_dynamo_wrapper_exist = "native_dynamo" in self.pre_wrappers

         # remove ddp wrapper when using zero2
         if ddp_wrapper_exist and (
-            fairscale_zero2_wrapper_exist or fsdp_wrapper_exist or ds_zero_wrapper_exist or ds_3d_parallel_wrapper_exist
+            fairscale_zero2_wrapper_exist
+            or fsdp_wrapper_exist
+            or ds_zero_wrapper_exist
+            or ds_3d_parallel_wrapper_exist
+            or fsdp2_wrapper_exist
         ):
             logger.info("Found Zero, ds_3d_parallel, or pipe wrapper, remove ddp wrapper.")
             self.post_wrappers.pop("ddp")
@@ -587,21 +599,28 @@ def adjust_wrappers(self):
             logger.info("Found fsdp and amp_native wrapper, turn on mixed_precision in FSDP")
             _, amp_native_config = self.post_wrappers["amp_native"]
             fp16_dtype = amp_native_config.get("dtype", torch.float16)
-            mixed_precision_param = (
-                MixedPrecision(param_dtype=fp16_dtype, reduce_dtype=fp16_dtype, buffer_dtype=fp16_dtype)
-                if MixedPrecision
-                else True
-            )
+            mixed_precision_param = {"param_dtype": fp16_dtype, "reduce_dtype": fp16_dtype, "buffer_dtype": fp16_dtype}
             config = self.pre_wrappers["fsdp"][1] or {}
             config["mixed_precision"] = mixed_precision_param
             self.pre_wrappers["fsdp"] = (
                 self.pre_wrappers["fsdp"][0],
                 config,
             )
+        elif fsdp2_wrapper_exist and "amp_native" in self.post_wrappers:
+            logger.info("Found fsdp2 and amp_native wrapper, turn on mixed_precision in FSDP")
+            _, amp_native_config = self.post_wrappers["amp_native"]
+            fp16_dtype = amp_native_config.get("dtype", torch.float16)
+            mixed_precision_param = {"param_dtype": fp16_dtype, "reduce_dtype": fp16_dtype, "buffer_dtype": fp16_dtype}
+            config = self.pre_wrappers["fsdp2"][1] or {}
+            config["mixed_precision"] = mixed_precision_param
+            self.pre_wrappers["fsdp2"] = (
+                self.pre_wrappers["fsdp2"][0],
+                config,
+            )

         # move dynamo_native wrapper behind ddp or fsdp (fsdp will adjusted later)
         # Note that dynamo_native wrapper and fsdp wrapper are pre-wrappers while ddp wrapper is a post-wrapper.
-        if native_dynamo_wrapper_exist and ddp_wrapper_exist and not fsdp_wrapper_exist:
+        if native_dynamo_wrapper_exist and ddp_wrapper_exist and not fsdp_wrapper_exist and not fsdp2_wrapper_exist:
             # ddp wrapper is a post-wrapper. Popping dynamo_native wrapper from pre-wrappers
             # then insert it after ddp wrapper.
             post_wrappers_list = []
@@ -616,8 +635,13 @@ def adjust_wrappers(self):

         if tensor_parallel_wrapper_exist:
             wrap_cls = None
+            fsdp_wrapper = None
             if fsdp_wrapper_exist and torch_version() >= (1, 12, 0):
                 fsdp_wrapper = self.pre_wrappers["fsdp"]
+            elif fsdp2_wrapper_exist and torch_version() >= (1, 12, 0):
+                fsdp_wrapper = self.pre_wrappers["fsdp2"]
+
+            if fsdp_wrapper is not None:
                 fsdp_wrapper = list(fsdp_wrapper)
                 if fsdp_wrapper[1] is None:
                     fsdp_wrapper[1] = dict()
@@ -644,15 +668,19 @@ def adjust_wrappers(self):
             leaf_modules = _propose_leaf_modules(wrap_cls)
             auto_wrap_cls = _propose_wrap_cls(leaf_modules)

-            if fsdp_wrapper_exist and torch_version() >= (1, 12, 0):
+            if (fsdp_wrapper_exist or fsdp2_wrapper_exist) and torch_version() >= (1, 12, 0):
                 if "atorch_wrap_cls" in fsdp_config:
                     if auto_wrap_cls is not None:
                         fsdp_config["atorch_wrap_cls"] = auto_wrap_cls
                     else:
                         fsdp_config.pop("atorch_wrap_cls")

                 fsdp_wrapper[1] = fsdp_config
-                self.pre_wrappers["fsdp"] = tuple(fsdp_wrapper)
+
+                if fsdp_wrapper_exist:
+                    self.pre_wrappers["fsdp"] = tuple(fsdp_wrapper)
+                elif fsdp2_wrapper_exist:
+                    self.pre_wrappers["fsdp2"] = tuple(fsdp_wrapper)

             if ckpt_wrapper_exist:
                 if auto_wrap_cls is not None:
@@ -671,7 +699,7 @@ def adjust_wrappers(self):
             tensor_parallel_wrapper_item = list(tensor_parallel_wrapper_item)
             tensor_parallel_wrapper_item[1] = list(tensor_parallel_wrapper_item[1])
             tensor_parallel_wrapper_item[1][1]["leaf_modules"] = leaf_modules
-            if fsdp_wrapper_exist or pipe_wrapper_exist:
+            if fsdp_wrapper_exist or fsdp2_wrapper_exist or pipe_wrapper_exist:
                 tensor_parallel_wrapper_item[1][1]["defer_init"] = True
             tensor_parallel_wrapper_item[1] = tuple(tensor_parallel_wrapper_item[1])
             tensor_parallel_wrapper_item = tuple(tensor_parallel_wrapper_item)
@@ -687,7 +715,7 @@ def adjust_wrappers(self):
             _insert_amp_config_for_tp_ckpt(amp_config)

         # adjust pre_wrapper order
-        order_wrapper_name = ["half", "module_replace", "sequence_parallel", "fp8", "fsdp", "native_dynamo"]
+        order_wrapper_name = ["half", "module_replace", "sequence_parallel", "fp8", "fsdp", "fsdp2", "native_dynamo"]
         match_names = []
         for name in self.pre_wrappers:
             if name in order_wrapper_name:

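Note: the fsdp/fsdp2 amp_native hunks above now store a plain dict of dtypes in config["mixed_precision"] instead of building a MixedPrecision object up front. A hedged sketch of how such a dict can be materialized into torch's FSDP policy at wrap time; this uses the standard torch.distributed.fsdp API, not ATorch's wrapper code, and the fsdp2 path presumably converts the same dict into its own policy type.

    # Materializing a dtype dict into torch's FSDP mixed-precision policy.
    import torch
    from torch.distributed.fsdp import MixedPrecision

    mixed_precision_param = {
        "param_dtype": torch.bfloat16,
        "reduce_dtype": torch.bfloat16,
        "buffer_dtype": torch.bfloat16,
    }
    policy = MixedPrecision(**mixed_precision_param)  # passed as FullyShardedDataParallel(mixed_precision=policy)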