Skip to content

Commit 90d9f87

Browse files
xingmingyyjNKNaN
andauthored
【FlexCheckpoint】Adapter Transpose and add macros (#74966)
* adapt transpose to load_static_dict * add unittest * add macros and fix * fix * fix * fix * fix * fix test * fix * fix * fix macro * fix * fix --------- Co-authored-by: AyaseNana <[email protected]>
1 parent 0a80351 commit 90d9f87

File tree

13 files changed

+1021
-185
lines changed

13 files changed

+1021
-185
lines changed

python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Lines changed: 62 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import ast
1617
import re
1718
from collections.abc import Iterable
1819
from dataclasses import dataclass
@@ -67,7 +68,7 @@ def __init__(
6768
self.destination_state_shard_info = destination_state_shard_info
6869
self.optim_state_name = [
6970
".w_0",
70-
".moment1_0 ",
71+
".moment1_0",
7172
".moment2_0",
7273
".beta1_pow_acc_0",
7374
".beta2_pow_acc_0",
@@ -114,11 +115,13 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
114115
raise KeyError(
115116
f"dst_state_key '{dst_state_key}' not in destination_state_shard_info"
116117
)
118+
119+
new_state_key = dst_state_key
117120
for state_name in self.optim_state_name:
118121
if state_name in dst_state_key:
119122
new_state_key = dst_state_key.replace(state_name, "")
120123
break
121-
new_state_key = dst_state_key
124+
122125
shard_infos = self.destination_state_shard_info[new_state_key]
123126
global_offset_set = set()
124127
for shard_info in shard_infos:
@@ -148,9 +151,7 @@ def __init__(
148151
self.input_vars = self.build_input_vars()
149152
self.output_vars = {}
150153
self.need_remove_input_vars = set()
151-
self.need_remove_output_vars = set()
152-
self.need_transpose_output_vars = set()
153-
self.need_transpose_input_vars = {}
154+
self.need_add_output_vars = set()
154155

155156
self.shape_propagation()
156157

@@ -176,7 +177,7 @@ def split(
176177
sub_slices = []
177178
for aidx, src_sl, dst_sl, pp_list in tensor.slices:
178179
if pp_list is not None:
179-
src_sl = self.postprocess_transpose(list(src_sl), pp_list)
180+
src_sl = postprocess_transpose(list(src_sl), pp_list)
180181

181182
dst_start = (
182183
dst_sl[axis].start if dst_sl[axis].start is not None else 0
@@ -206,7 +207,7 @@ def split(
206207
inter_begin - start, inter_begin - start + length
207208
)
208209
if pp_list is not None:
209-
sub_src_sl = self.postprocess_transpose(
210+
sub_src_sl = postprocess_transpose(
210211
list(sub_src_sl), pp_list, reverse=True
211212
)
212213
sub_slices.append(
@@ -256,17 +257,19 @@ def concat(self, tensors: list[TensorDesc], axis: int) -> TensorDesc:
256257
curr += t.shape[axis]
257258
return TensorDesc(slices, tuple(shape))
258259

259-
def transpose(self, tensor: TensorDesc, transpose: str) -> TensorDesc:
260+
def transpose(self, tensor: TensorDesc, permutation: str) -> TensorDesc:
260261
slices = []
261-
tensor_shape = transpose_list(tensor.shape, eval(transpose))
262+
tensor_shape = transpose_list(
263+
tensor.shape, ast.literal_eval(permutation)
264+
)
262265
for aidx, src_sl, dst_sl, pp_list in tensor.slices:
263-
trans_dst_sl = transpose_list(dst_sl, eval(transpose))
266+
trans_dst_sl = transpose_list(dst_sl, ast.literal_eval(permutation))
264267
if pp_list is not None:
265268
new_pp_list = pp_list.copy()
266-
new_pp_list.append(transpose)
269+
new_pp_list.append(permutation)
267270
slices.append((aidx, src_sl, trans_dst_sl, new_pp_list))
268271
else:
269-
slices.append((aidx, src_sl, trans_dst_sl, [transpose]))
272+
slices.append((aidx, src_sl, trans_dst_sl, [permutation]))
270273
return TensorDesc(slices, tensor_shape)
271274

272275
def cast(self, tensor: TensorDesc, dtype: str) -> TensorDesc:
@@ -295,7 +298,6 @@ def _get_var_ref(var):
295298
left_vars = stmt.left_vars
296299
right_vars = stmt.right_vars
297300
attrs = stmt.attrs
298-
299301
if len(left_vars) > 1 or len(right_vars) > 1:
300302
if not (len(attrs) == 1 and attrs[0].key == "axis"):
301303
raise ValueError(
@@ -338,47 +340,49 @@ def _get_var_ref(var):
338340
if rvar.name == "_":
339341
self.need_remove_input_vars.add(lvar.name)
340342
elif lvar.name == "_":
341-
self.need_remove_output_vars.add(rvar.name)
343+
self.need_add_output_vars.add(rvar.name)
342344
else:
343-
if attrs:
345+
if len(attrs) > 0:
344346
for attr in attrs:
345347
in_ref = _get_var_ref(lvar)
346-
if attr.key == "transpose":
348+
if attr.key == "permute":
347349
if attr.value == "[]":
348350
ndim = len(in_ref.shape)
349-
transpose = str(
350-
list(range(ndim - 1, -1, -1))
351-
)
351+
perm = str(list(range(ndim - 1, -1, -1)))
352352
else:
353-
transpose = attr.value
354-
result = self.transpose(in_ref, transpose)
353+
perm = attr.value
354+
result = self.transpose(in_ref, perm)
355355
elif attr.key == "dtype":
356356
result = self.cast(in_ref, attr.value)
357+
elif attr.key == "axis":
358+
pass
357359
else:
358360
raise ValueError(
359361
f"Unsupported attribute: {attr}"
360362
)
361363

362-
out_name = rvar.name
363-
intermediate_vars[out_name] = result
364+
intermediate_vars[rvar.name] = result
364365
if (
365-
out_name
366+
rvar.name
366367
in self.context.get_all_dst_state_keys()
367368
):
368-
self.output_vars[out_name] = result
369+
self.output_vars[rvar.name] = result
369370
else:
370-
intermediate_vars[rvar.name] = _get_var_ref(lvar)
371+
in_ref = _get_var_ref(lvar)
372+
intermediate_vars[rvar.name] = in_ref
371373
if rvar.name in self.context.get_all_dst_state_keys():
372-
self.output_vars[rvar.name] = intermediate_vars[
373-
rvar.name
374-
]
374+
self.output_vars[rvar.name] = in_ref
375+
375376
else:
376377
raise SyntaxError(f'Unexpected statement: {stmt}')
377378

378379
for name in self.destination_state_shard_info.keys():
379380
if name not in self.output_vars:
380-
assert name in self.input_vars
381-
self.output_vars[name] = self.input_vars[name]
381+
if name in self.need_add_output_vars:
382+
self.output_vars[name] = None
383+
else:
384+
assert name in self.input_vars
385+
self.output_vars[name] = self.input_vars[name]
382386

383387
def find_source_slices(
384388
self, key: str, local_slice: tuple[slice, ...]
@@ -406,7 +410,7 @@ def slice_intersect(a: slice, b: slice):
406410
else:
407411
# Compute corresponding src_slice for the intersection
408412
if pp_list is not None:
409-
sl_src = self.postprocess_transpose(list(sl_src), pp_list)
413+
sl_src = postprocess_transpose(list(sl_src), pp_list)
410414
src_slice = []
411415
for i in range(ndim):
412416
dst = sl_dst[i]
@@ -424,7 +428,7 @@ def slice_intersect(a: slice, b: slice):
424428
)
425429
src_slice.append(slice(src_inter_start, src_inter_stop, 1))
426430
if pp_list is not None:
427-
src_slice = self.postprocess_transpose(
431+
src_slice = postprocess_transpose(
428432
list(src_slice), pp_list, reverse=True
429433
)
430434
results.append(
@@ -484,6 +488,14 @@ def find_shard_sources(
484488
tgt_global_offset,
485489
)
486490

491+
if source_sharded_weight.key in self.need_remove_input_vars:
492+
mapping_entry = ShardMappingEntry(
493+
target_sharded_weight,
494+
source_sharded_weight,
495+
[],
496+
)
497+
continue
498+
487499
shard_mappings.append(
488500
ShardMappingEntry(
489501
target_sharded_weight,
@@ -493,23 +505,23 @@ def find_shard_sources(
493505
)
494506
return shard_mappings
495507

496-
def postprocess_transpose(
497-
self,
498-
li: list[tuple[slice, ...]] | tuple[tuple[slice, ...]],
499-
postprocess_list: list[str],
500-
reverse: bool = False,
501-
) -> list[tuple[slice, ...]] | tuple[tuple[slice, ...]]:
502-
result = li
503-
if reverse:
504-
for pp in list(reversed(postprocess_list)):
505-
if pp.startswith("["):
506-
reversed_transpose = np.argsort(eval(pp)).tolist()
507-
result = transpose_list(result, reversed_transpose)
508-
else:
509-
for pp in postprocess_list:
510-
if pp.startswith("["):
511-
result = transpose_list(result, eval(pp))
512-
return result
508+
509+
def postprocess_transpose(
510+
li: list[tuple[slice, ...]] | tuple[tuple[slice, ...]],
511+
postprocess_list: list[str],
512+
reverse: bool = False,
513+
) -> list[tuple[slice, ...]] | tuple[tuple[slice, ...]]:
514+
result = li
515+
if reverse:
516+
for pp in list(reversed(postprocess_list)):
517+
if pp.startswith("["):
518+
reversed_transpose = np.argsort(ast.literal_eval(pp)).tolist()
519+
result = transpose_list(result, reversed_transpose)
520+
else:
521+
for pp in postprocess_list:
522+
if pp.startswith("["):
523+
result = transpose_list(result, ast.literal_eval(pp))
524+
return result
513525

514526

515527
def transpose_list(

python/paddle/distributed/flex_checkpoint/aoa/lexer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class Lexer:
4949
('COMMA', r','),
5050
('NUMBER', r'\d+'),
5151
('STRING', r'"[^"]*"|\'[^\']*\''),
52-
('IDENTIFIER', r'[A-Za-z][A-Za-z\.\$\_\*\d\^T]*'),
52+
('IDENTIFIER', r'[A-Za-z_][A-Za-z\.\$\_\*\d\^T]*'),
5353
('SKIP', r'[ \t]+'),
5454
('NEWLINE', r'[\r\n]+'),
5555
('MISMATCH', r'.'),
@@ -71,7 +71,8 @@ def tokenize(self, text):
7171
pos = 0
7272
mo = self.get_token(text, pos)
7373
tokens = []
74-
text += '\n'
74+
if not text.endswith('\n'):
75+
text += '\n'
7576
while mo is not None:
7677
kind = mo.lastgroup
7778
value = mo.group()

0 commit comments

Comments
 (0)