
Commit 11fff57

【Flexcheckpoint】add_get_var_mapping_chain_macro (PaddlePaddle#76013)
* add_get_var_mapping_chain_macro
* add note
* fix the bug in input_vars and resolve_mapping_chain
* fix the code style
* fix the dtype assert bug
* fix the bug
* fix the merge_sharded_state_dict bug
1 parent 8036231 commit 11fff57

File tree

6 files changed (+381, -56 lines)


python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Lines changed: 83 additions & 20 deletions
@@ -90,6 +90,8 @@ def __init__(
     ) -> None:
         self.source_state_shard_info = source_state_shard_info
         self.destination_state_shard_info = destination_state_shard_info
+        self.left_var_to_right_var_mapping = {}
+        self.right_var_from_left_var_mapping = {}

     def get_all_dst_state_keys(self):
         dst_state_keys = set()
@@ -135,11 +137,16 @@ def get_src_state_shard_num(self, src_state_key: str) -> int:
                 "AOA notions apply only to the model state, but are automatically propagated to the optimizer state."
             )

+        # Only the model state key needs to be resolved for the optimizer state shard num, because the optimizer state slice info is completely consistent with the model state slice info.
+        resolved_model_state_key = self.resolve_mapping_chain(
+            model_state_key, reverse=True
+        )
+
         state_keys = [
-            model_state_key,
-            f"{model_state_key}.w_0",
-            f"{model_state_key}.moment1_0",
-            f"{model_state_key}.moment2_0",
+            resolved_model_state_key,
+            f"{resolved_model_state_key}.w_0",
+            f"{resolved_model_state_key}.moment1_0",
+            f"{resolved_model_state_key}.moment2_0",
         ]

         shard_nums = {
@@ -168,7 +175,6 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
         if self.destination_state_shard_info is None:
             # Default `dst_state_shard_num=1` if `destination_state_shard_info` is missing.
             return 1
-
         model_state_key, opt_state_name = split_optimizer_state_key(
             dst_state_key
         )
@@ -177,11 +183,16 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
                 "AOA notions apply only to the model state, but are automatically propagated to the optimizer state."
             )

+        # Only the model state key needs to be resolved for the optimizer state shard num, because the optimizer state slice info is completely consistent with the model state slice info.
+        resolved_model_state_key = self.resolve_mapping_chain(
+            model_state_key, reverse=False
+        )
+
         state_keys = [
-            model_state_key,
-            f"{model_state_key}.w_0",
-            f"{model_state_key}.moment1_0",
-            f"{model_state_key}.moment2_0",
+            resolved_model_state_key,
+            f"{resolved_model_state_key}.w_0",
+            f"{resolved_model_state_key}.moment1_0",
+            f"{resolved_model_state_key}.moment2_0",
         ]

         shard_nums = {
@@ -206,6 +217,44 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
             )
         return shard_nums.pop()

+    def resolve_mapping_chain(self, key: str, reverse: bool = False) -> str:
+        """
+        Recursively resolve the mapping chain and find the final leaf node.
+
+        Args:
+            key: The key to be resolved.
+            reverse: If False, use left_var_to_right_var_mapping; if True, use right_var_from_left_var_mapping.
+
+        For example:
+            - reverse=False: temp_var -> dst_key
+            - reverse=True: temp_var -> src_key
+        """
+        visited = set()  # avoid infinite loops
+        current_key = key
+
+        if reverse:
+            mapping_dict = self.right_var_from_left_var_mapping
+        else:
+            mapping_dict = self.left_var_to_right_var_mapping
+
+        while current_key in mapping_dict:
+            assert current_key not in visited, (
+                "Infinite loop detected in resolve_mapping_chain, which means the start key is not a src_key or the end key is not a dst_key; the aoa_config is invalid."
+            )
+            visited.add(current_key)
+            if reverse and current_key in self.get_all_src_state_keys():
+                break
+            elif not reverse and current_key in self.get_all_dst_state_keys():
+                break
+
+            mapped_vars = mapping_dict[current_key]
+            if mapped_vars and len(mapped_vars) > 0:
+                current_key = mapped_vars[0]
+            else:
+                break
+
+        return current_key
+

 class AOAEngine:
     def __init__(
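To see what `resolve_mapping_chain` buys, here is a minimal standalone sketch of the same walk. The mapping-table layout and the stop condition mirror the diff, but the AOA config (a rename routed through a temporary variable) and the plain set standing in for `get_all_dst_state_keys()` are illustrative assumptions:

```python
# Assumed config: "layers.0.qkv" is renamed through a temporary "tmp_qkv".
left_to_right = {"layers.0.qkv": ["tmp_qkv"], "tmp_qkv": ["dst.layers.0.qkv"]}
dst_state_keys = {"dst.layers.0.qkv"}  # stand-in for get_all_dst_state_keys()

def resolve(key: str) -> str:
    visited = set()
    current = key
    while current in left_to_right:
        assert current not in visited, "cycle in mapping chain"
        visited.add(current)
        if current in dst_state_keys:  # stop once a real dst key is reached
            break
        nxt = left_to_right[current]
        if not nxt:
            break
        current = nxt[0]
    return current

print(resolve("layers.0.qkv"))  # -> "dst.layers.0.qkv"
```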
@@ -248,14 +297,20 @@ def make_input_tensor(

     def build_input_vars(self):
         input_vars = {}
-        for key, shards in self.source_state_shard_info.items():
+        dtype = None
+        for key, shards in sorted(self.source_state_shard_info.items()):
             global_shape = shards[0].global_shape
-            dtype = shards[0].dtype
             model_state_key, opt_state_name = split_optimizer_state_key(key)
-            if opt_state_name in [".w_0", ".moment1_0", ".moment2_0", None]:
-                input_vars[model_state_key] = self.make_input_tensor(
-                    model_state_key, global_shape, dtype
-                )
+            if opt_state_name is None:
+                dtype = shards[0].dtype
+            if model_state_key in input_vars.keys() or opt_state_name in [
+                ".beta1_pow_acc_0",
+                ".beta2_pow_acc_0",
+            ]:
+                continue
+            input_vars[model_state_key] = self.make_input_tensor(
+                model_state_key, global_shape, dtype
+            )
         return input_vars

     def split(
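The reworked `build_input_vars` no longer takes the dtype from whichever shard happens to be visited first: iteration is sorted so the bare parameter key (where `opt_state_name is None`) is seen before its optimizer-state keys, the parameter shard alone fixes the dtype, and the scalar `beta*_pow_acc_0` accumulators never define an input var. A runnable sketch of just that selection logic; the shard metadata and `split_optimizer_state_key` below are stand-ins, not the real flex_checkpoint helpers:

```python
OPT_SUFFIXES = (".w_0", ".moment1_0", ".moment2_0",
                ".beta1_pow_acc_0", ".beta2_pow_acc_0")

def split_optimizer_state_key(key):  # stub for illustration only
    for sfx in OPT_SUFFIXES:
        if key.endswith(sfx):
            return key[: -len(sfx)], sfx
    return key, None

# key -> (global_shape, dtype) of the first shard, standing in for shards[0]
shard_info = {
    "linear.weight": ((4, 4), "bfloat16"),
    "linear.weight.moment1_0": ((4, 4), "float32"),
    "linear.weight.beta1_pow_acc_0": ((1,), "float32"),
}

input_vars, dtype = {}, None
for key, (shape, shard_dtype) in sorted(shard_info.items()):
    model_key, opt_name = split_optimizer_state_key(key)
    if opt_name is None:
        # Only the parameter shard itself fixes the dtype; sorting guarantees
        # it is visited before its optimizer-state keys.
        dtype = shard_dtype
    if model_key in input_vars or opt_name in (".beta1_pow_acc_0",
                                               ".beta2_pow_acc_0"):
        continue  # scalar accumulators never define an input var
    input_vars[model_key] = (model_key, shape, dtype)

print(input_vars)  # {'linear.weight': ('linear.weight', (4, 4), 'bfloat16')}
```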
@@ -652,11 +707,19 @@ def find_shard_sources(

         for src_key, src_slices, local_slices, pp_list in results:
             src_var = self.input_vars[src_key]
-            if src_var.dtype != target.dtype:
-                assert pp_list is not None and target.dtype in str(pp_list), (
-                    "Direct assignment of Tensors with different types is prohibited in AOA. "
-                    "If you want to achieve this functionality, please use the cast semantics provided by AOA."
-                )
+            target_model_state_key, target_opt_state_name = (
+                split_optimizer_state_key(target.key)
+            )
+            if target_opt_state_name is None:
+                if src_var.dtype != target.dtype:
+                    assert pp_list is not None and target.dtype in str(
+                        pp_list
+                    ), (
+                        "Direct assignment of Tensors with different types is prohibited in AOA. "
+                        "If you want to achieve this functionality, please use the cast semantics provided by AOA."
+                    )
+            else:
+                src_var.dtype = target.dtype

             src_global_shape = src_var.shape

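The relaxed check above only enforces AOA's no-implicit-cast rule when the target is model state; for optimizer-state targets the source var's dtype is simply aligned with the target's (e.g. fp32 moments or master weights paired with bf16 parameters). A small runnable sketch of that rule, with `SimpleNamespace` standing in for the real var/target objects:

```python
from types import SimpleNamespace

def check_or_sync_dtype(src_var, target, pp_list, target_opt_state_name):
    if target_opt_state_name is None:  # model state: no implicit casts
        if src_var.dtype != target.dtype:
            assert pp_list is not None and target.dtype in str(pp_list), (
                "Direct assignment of Tensors with different types is "
                "prohibited in AOA; use the cast semantics instead."
            )
    else:  # optimizer state: follow the target dtype
        src_var.dtype = target.dtype

src = SimpleNamespace(dtype="bfloat16")
tgt = SimpleNamespace(dtype="float32")
check_or_sync_dtype(src, tgt, pp_list=None, target_opt_state_name=".moment1_0")
print(src.dtype)  # float32
```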
python/paddle/distributed/flex_checkpoint/aoa/lexer.py

Lines changed: 18 additions & 10 deletions
@@ -88,12 +88,6 @@ def tokenize(self, text):
             mo = self.get_token(text, pos)
         return tokens

-    def apply_macros(self, expression):
-        expressions = [expression]
-        for macro in self.macros:
-            expressions = self.apply_macro(expressions, macro)
-        return expressions
-
     def apply_macro(self, expression, macro):
         if isinstance(expression, str):
             expression = [expression]
@@ -106,10 +100,24 @@ def apply_macro(self, expression, macro):
             new_expression.extend(results)
         return new_expression

+    def apply_single_macro_to_all(self, expressions, macro):
+        new_expressions = []
+        for expr in expressions:
+            results = macro(self.tokenize(expr), expr, self.context)
+            if isinstance(results, str):
+                new_expressions.append(results)
+            else:
+                new_expressions.extend(results)
+        return new_expressions
+
     def all_tokens(self, expressions):
+        current_expressions = expressions
+        for macro in self.macros:
+            current_expressions = self.apply_single_macro_to_all(
+                current_expressions, macro
+            )
+
         tokens = []
-        for expr in expressions:
-            expanded_expressions = self.apply_macros(expr)
-            for e in expanded_expressions:
-                tokens.extend(self.tokenize(e))
+        for expr in current_expressions:
+            tokens.extend(self.tokenize(expr))
         return tokens

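The net effect in `all_tokens` is a change of iteration order: previously the whole macro pipeline ran on one expression before the next expression was touched; now each macro pass runs over all expressions before the next macro fires, so a bookkeeping macro such as `get_var_mapping_chain_macro` has recorded every statement before later-stage macros run on any of them. A toy, runnable illustration using plain string "macros" instead of token-based ones:

```python
mapping = {}

def record(expr):                     # earlier-stage macro: fills a table
    left, _, right = expr.partition("->")
    mapping[left.strip()] = right.strip()
    return [expr]

def consume(expr):                    # later-stage macro: reads the table
    print(f"{expr!r} sees mapping={mapping}")
    return [expr]

def apply_single_macro_to_all(expressions, macro):
    out = []
    for expr in expressions:
        out.extend(macro(expr))
    return out

exprs = ["a -> tmp", "tmp -> b"]
for macro in (record, consume):       # stage-wise, as in the new all_tokens
    exprs = apply_single_macro_to_all(exprs, macro)
# With the old per-expression pipeline, consume("a -> tmp") would have run
# before record("tmp -> b"), so the table would still have been incomplete.
```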
python/paddle/distributed/flex_checkpoint/aoa/macros.py

Lines changed: 47 additions & 4 deletions
@@ -56,6 +56,8 @@ def register_macro(self, name, func, priority):
     'num_heads',
     'num_key_value_groups',
     'permute',
+    'dtype',
+    'fused_qkv',
 ]

 EXTRA_SUFFIX = [
@@ -211,7 +213,7 @@ def array_macro(tokens, expression, context):
     return new_expression


-@macro(name='fused_qkv_old_macro', priority=4)
+@macro(name='fused_qkv_old_macro', priority=6)
 def fused_qkv_old_macro(tokens, expression, context):
     FUSED_QKV_OLD_TAG = "fused_qkv_old"
     if not any(tkn.value == FUSED_QKV_OLD_TAG for tkn in tokens):
@@ -381,7 +383,7 @@ def gen_expr(tp_degree, num_heads, tp_rank, comp):
         return results


-@macro(name='fused_ffn_macro', priority=4)
+@macro(name='fused_ffn_macro', priority=6)
 def fused_ffn_macro(tokens, expression, context):
     FUSED_FFN_TAG = "fused_ffn"
     if not any(tkn.value == FUSED_FFN_TAG for tkn in tokens):
@@ -505,7 +507,7 @@ def gen_expr(tp_degree, splited_num, tp_rank, comp):
         return results


-@macro(name='transpose_macro', priority=3)
+@macro(name='transpose_macro', priority=5)
 def transpose_macro(tokens, expression, context):
     TRANSPOSE_TAG = "^T"

@@ -551,7 +553,7 @@ def transpose_macro(tokens, expression, context):
     return results


-@macro(name='fused_qkv_macro', priority=4)
+@macro(name='fused_qkv_macro', priority=6)
 def fused_qkv_macro(tokens, expression, context):
     FUSED_QKV_TAG = "fused_qkv"
     if not any(tkn.value == FUSED_QKV_TAG for tkn in tokens):
@@ -711,6 +713,7 @@ def find_matches(self, pattern: str) -> dict[str, list[int]]:
 _REGISTERED_PLACEHOLDERS = ['$EXPERT_ID', '$LAYER_ID']


+# TODO: adapt to cases like temp_layers.\$LAYER_ID.weight -> dst_layers.\$LAYER_ID.weight
 @macro(name='id_macro', priority=1)
 def id(tokens, expression, context):
     allowed_placeholders = _REGISTERED_PLACEHOLDERS
@@ -783,3 +786,43 @@ def dict_cartesian_tuples(d: dict[str, list[int]]):
         results.append(cur_statement)

     return results
+
+
+# This macro processes variable mappings between source and destination states,
+# but it requires that all expansion macros (layer_id_macro, expert_id_macro,
+# star_macro, array_macro, etc.) have already been executed to expand template
+# variables into concrete variable names.
+@macro(name='get_var_mapping_chain_macro', priority=4)
+def get_var_mapping_chain_macro(tokens, expression, context):
+    flag_left_var = True
+    left_var_list = []
+    right_var_list = []
+    for tkn in tokens:
+        if tkn.value in GLOBAL_ATTRIBUTE_KEYWORDS:
+            break
+        if tkn.type == TokenType.RARROW:
+            flag_left_var = False
+        if tkn.type == TokenType.IDENTIFIER:
+            extra_suffix_removed_value = tkn.value
+            for sfx in EXTRA_SUFFIX:
+                extra_suffix_removed_value = (
+                    extra_suffix_removed_value.removesuffix(sfx)
+                )
+            if flag_left_var:
+                left_var_list.append(extra_suffix_removed_value)
+            else:
+                right_var_list.append(extra_suffix_removed_value)
+    assert len(left_var_list) == 1 or len(right_var_list) == 1, (
+        "Either the left or the right side must contain exactly one variable"
+    )
+    if len(left_var_list) == 1:
+        context.left_var_to_right_var_mapping[left_var_list[0]] = right_var_list
+        for right_var in right_var_list:
+            context.right_var_from_left_var_mapping[right_var] = left_var_list
+    else:
+        context.right_var_from_left_var_mapping[right_var_list[0]] = (
+            left_var_list
+        )
+        for left_var in left_var_list:
+            context.left_var_to_right_var_mapping[left_var] = right_var_list
+    return expression

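Together with the priority bumps above (assuming lower priority numbers run earlier, the fused/transpose macros now fire after this priority-4 macro), the macro leaves each expression unchanged and only records which left-hand variables feed which right-hand variables. A runnable distillation of the bookkeeping, with tokenization and suffix stripping elided:

```python
def record_mapping(left_vars, right_vars, left_to_right, right_from_left):
    # Mirrors the one-to-many / many-to-one bookkeeping in the diff.
    assert len(left_vars) == 1 or len(right_vars) == 1
    if len(left_vars) == 1:
        left_to_right[left_vars[0]] = right_vars
        for rv in right_vars:
            right_from_left[rv] = left_vars
    else:
        right_from_left[right_vars[0]] = left_vars
        for lv in left_vars:
            left_to_right[lv] = right_vars

l2r, r2l = {}, {}
record_mapping(["a"], ["b", "c"], l2r, r2l)   # AOA statement: a -> b, c
print(l2r)  # {'a': ['b', 'c']}
print(r2l)  # {'b': ['a'], 'c': ['a']}
```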
python/paddle/distributed/flex_checkpoint/aoa/parser.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def consume(self, expected_type=None):
         tok = self.peek()
         if expected_type and tok.type != expected_type:
             raise SyntaxError(
-                f'Expected {expected_type}, got {tok.type} at pos {tok.pos}'
+                f'Expected {expected_type}, got {tok.type} at pos {self.pos}'
            )
         self.pos += 1
         return tok

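For context on the one-line fix above: `self.pos` is the parser's token cursor, which is the position the error actually occurred at; `tok.pos`, whatever the lexer stored there, is not the consume position. A self-contained sketch (the `Token` and `Parser` shapes are assumptions; only `consume` mirrors the diff):

```python
from dataclasses import dataclass

@dataclass
class Token:
    type: str
    value: str

class Parser:
    def __init__(self, tokens):
        self.tokens, self.pos = tokens, 0

    def peek(self):
        return self.tokens[self.pos]

    def consume(self, expected_type=None):
        tok = self.peek()
        if expected_type and tok.type != expected_type:
            raise SyntaxError(
                f'Expected {expected_type}, got {tok.type} at pos {self.pos}'
            )
        self.pos += 1
        return tok

p = Parser([Token('IDENTIFIER', 'a'), Token('RARROW', '->')])
p.consume('IDENTIFIER')
try:
    p.consume('IDENTIFIER')
except SyntaxError as e:
    print(e)  # Expected IDENTIFIER, got RARROW at pos 1
```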
python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py

Lines changed: 12 additions & 21 deletions
@@ -888,27 +888,17 @@ def load_state_dict(
                     f"{key} is not replicated!"
                 )
                 load_dict[key] = val
-
-        load_state_dict_impl(
-            state_dict=load_dict,
-            path=path,
-            process_group=process_group,
-            coordinator_rank=coordinator_rank,
-            unique_id=unique_id,
-            offload=offload,
-            mw_name_compatibility=mw_name_compatibility,
-            safetensors=safetensors,
-            worker_groups=worker_groups,
+            destination_state_shard_info = defaultdict(list)
+            for key, val in load_dict.items():
+                desc = build_shard_desc(val)
+                destination_state_shard_info[key].append(desc)
+    else:
+        flat_shards, nonflat_shards = _split_flat_shards(state_dict)
+        load_dict, padding_info = _unflatten_shards(flat_shards)
+        load_dict.update(nonflat_shards)
+        destination_state_shard_info = build_global_state_shard_info(
+            state_dict, process_group
         )
-        return
-
-    destination_state_shard_info = build_global_state_shard_info(
-        state_dict, process_group
-    )
-
-    flat_shards, nonflat_shards = _split_flat_shards(state_dict)
-    load_dict, padding_info = _unflatten_shards(flat_shards)
-    load_dict.update(nonflat_shards)

     if aoa_config is not None:
         _handle_aoa(
@@ -935,7 +925,8 @@ def load_state_dict(
             safetensors=safetensors,
             worker_groups=worker_groups,
         )
-    _finish_unflatten(flat_shards, padding_info)
+    if use_dist:
+        _finish_unflatten(flat_shards, padding_info)

     global _metadata_manager
     _metadata_manager.clear()

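The restructuring folds the old early-return path into the `if`/`else` split: both the replicated (non-distributed) branch and the sharded branch now build `destination_state_shard_info` and fall through to the shared AOA handling, and `_finish_unflatten` runs only on the path that actually flattened shards. A runnable toy of the control flow; every helper is a stub standing in for the real flex_checkpoint function of the same name:

```python
from collections import defaultdict

def build_shard_desc(val):
    return {"shape": getattr(val, "shape", None)}

def _split_flat_shards(sd):
    return dict(sd), {}

def _unflatten_shards(flat):
    return dict(flat), {"padding": 0}

def build_global_state_shard_info(sd, group):
    return {k: [build_shard_desc(v)] for k, v in sd.items()}

def _finish_unflatten(flat_shards, padding_info):
    print("re-flattening", list(flat_shards), padding_info)

def load_flow(state_dict, use_dist, process_group=None):
    flat_shards = padding_info = None
    if not use_dist:
        load_dict = dict(state_dict)   # replicated values gathered earlier
        destination_state_shard_info = defaultdict(list)
        for key, val in load_dict.items():
            destination_state_shard_info[key].append(build_shard_desc(val))
    else:
        flat_shards, nonflat_shards = _split_flat_shards(state_dict)
        load_dict, padding_info = _unflatten_shards(flat_shards)
        load_dict.update(nonflat_shards)
        destination_state_shard_info = build_global_state_shard_info(
            state_dict, process_group
        )
    # ... AOA handling and load_state_dict_impl would run here on BOTH paths ...
    if use_dist:                       # only this path flattened anything
        _finish_unflatten(flat_shards, padding_info)
    return load_dict, destination_state_shard_info

load_flow({"w": object()}, use_dist=False)
load_flow({"w": object()}, use_dist=True)
```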