13
13
# limitations under the License.
14
14
from __future__ import annotations
15
15
16
+ import ast
16
17
import re
17
18
from collections .abc import Iterable
18
19
from dataclasses import dataclass
@@ -67,7 +68,7 @@ def __init__(
67
68
self .destination_state_shard_info = destination_state_shard_info
68
69
self .optim_state_name = [
69
70
".w_0" ,
70
- ".moment1_0 " ,
71
+ ".moment1_0" ,
71
72
".moment2_0" ,
72
73
".beta1_pow_acc_0" ,
73
74
".beta2_pow_acc_0" ,
@@ -114,11 +115,13 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
114
115
raise KeyError (
115
116
f"dst_state_key '{ dst_state_key } ' not in destination_state_shard_info"
116
117
)
118
+
119
+ new_state_key = dst_state_key
117
120
for state_name in self .optim_state_name :
118
121
if state_name in dst_state_key :
119
122
new_state_key = dst_state_key .replace (state_name , "" )
120
123
break
121
- new_state_key = dst_state_key
124
+
122
125
shard_infos = self .destination_state_shard_info [new_state_key ]
123
126
global_offset_set = set ()
124
127
for shard_info in shard_infos :
@@ -148,9 +151,7 @@ def __init__(
148
151
self .input_vars = self .build_input_vars ()
149
152
self .output_vars = {}
150
153
self .need_remove_input_vars = set ()
151
- self .need_remove_output_vars = set ()
152
- self .need_transpose_output_vars = set ()
153
- self .need_transpose_input_vars = {}
154
+ self .need_add_output_vars = set ()
154
155
155
156
self .shape_propagation ()
156
157
@@ -176,7 +177,7 @@ def split(
176
177
sub_slices = []
177
178
for aidx , src_sl , dst_sl , pp_list in tensor .slices :
178
179
if pp_list is not None :
179
- src_sl = self . postprocess_transpose (list (src_sl ), pp_list )
180
+ src_sl = postprocess_transpose (list (src_sl ), pp_list )
180
181
181
182
dst_start = (
182
183
dst_sl [axis ].start if dst_sl [axis ].start is not None else 0
@@ -206,7 +207,7 @@ def split(
206
207
inter_begin - start , inter_begin - start + length
207
208
)
208
209
if pp_list is not None :
209
- sub_src_sl = self . postprocess_transpose (
210
+ sub_src_sl = postprocess_transpose (
210
211
list (sub_src_sl ), pp_list , reverse = True
211
212
)
212
213
sub_slices .append (
@@ -256,17 +257,19 @@ def concat(self, tensors: list[TensorDesc], axis: int) -> TensorDesc:
256
257
curr += t .shape [axis ]
257
258
return TensorDesc (slices , tuple (shape ))
258
259
259
- def transpose (self , tensor : TensorDesc , transpose : str ) -> TensorDesc :
260
+ def transpose (self , tensor : TensorDesc , permutation : str ) -> TensorDesc :
260
261
slices = []
261
- tensor_shape = transpose_list (tensor .shape , eval (transpose ))
262
+ tensor_shape = transpose_list (
263
+ tensor .shape , ast .literal_eval (permutation )
264
+ )
262
265
for aidx , src_sl , dst_sl , pp_list in tensor .slices :
263
- trans_dst_sl = transpose_list (dst_sl , eval ( transpose ))
266
+ trans_dst_sl = transpose_list (dst_sl , ast . literal_eval ( permutation ))
264
267
if pp_list is not None :
265
268
new_pp_list = pp_list .copy ()
266
- new_pp_list .append (transpose )
269
+ new_pp_list .append (permutation )
267
270
slices .append ((aidx , src_sl , trans_dst_sl , new_pp_list ))
268
271
else :
269
- slices .append ((aidx , src_sl , trans_dst_sl , [transpose ]))
272
+ slices .append ((aidx , src_sl , trans_dst_sl , [permutation ]))
270
273
return TensorDesc (slices , tensor_shape )
271
274
272
275
def cast (self , tensor : TensorDesc , dtype : str ) -> TensorDesc :
@@ -295,7 +298,6 @@ def _get_var_ref(var):
295
298
left_vars = stmt .left_vars
296
299
right_vars = stmt .right_vars
297
300
attrs = stmt .attrs
298
-
299
301
if len (left_vars ) > 1 or len (right_vars ) > 1 :
300
302
if not (len (attrs ) == 1 and attrs [0 ].key == "axis" ):
301
303
raise ValueError (
@@ -338,47 +340,49 @@ def _get_var_ref(var):
338
340
if rvar .name == "_" :
339
341
self .need_remove_input_vars .add (lvar .name )
340
342
elif lvar .name == "_" :
341
- self .need_remove_output_vars .add (rvar .name )
343
+ self .need_add_output_vars .add (rvar .name )
342
344
else :
343
- if attrs :
345
+ if len ( attrs ) > 0 :
344
346
for attr in attrs :
345
347
in_ref = _get_var_ref (lvar )
346
- if attr .key == "transpose " :
348
+ if attr .key == "permute " :
347
349
if attr .value == "[]" :
348
350
ndim = len (in_ref .shape )
349
- transpose = str (
350
- list (range (ndim - 1 , - 1 , - 1 ))
351
- )
351
+ perm = str (list (range (ndim - 1 , - 1 , - 1 )))
352
352
else :
353
- transpose = attr .value
354
- result = self .transpose (in_ref , transpose )
353
+ perm = attr .value
354
+ result = self .transpose (in_ref , perm )
355
355
elif attr .key == "dtype" :
356
356
result = self .cast (in_ref , attr .value )
357
+ elif attr .key == "axis" :
358
+ pass
357
359
else :
358
360
raise ValueError (
359
361
f"Unsupported attribute: { attr } "
360
362
)
361
363
362
- out_name = rvar .name
363
- intermediate_vars [out_name ] = result
364
+ intermediate_vars [rvar .name ] = result
364
365
if (
365
- out_name
366
+ rvar . name
366
367
in self .context .get_all_dst_state_keys ()
367
368
):
368
- self .output_vars [out_name ] = result
369
+ self .output_vars [rvar . name ] = result
369
370
else :
370
- intermediate_vars [rvar .name ] = _get_var_ref (lvar )
371
+ in_ref = _get_var_ref (lvar )
372
+ intermediate_vars [rvar .name ] = in_ref
371
373
if rvar .name in self .context .get_all_dst_state_keys ():
372
- self .output_vars [rvar .name ] = intermediate_vars [
373
- rvar .name
374
- ]
374
+ self .output_vars [rvar .name ] = in_ref
375
+
375
376
else :
376
377
raise SyntaxError (f'Unexpected statement: { stmt } ' )
377
378
378
379
for name in self .destination_state_shard_info .keys ():
379
380
if name not in self .output_vars :
380
- assert name in self .input_vars
381
- self .output_vars [name ] = self .input_vars [name ]
381
+ if name in self .need_add_output_vars :
382
+ self .output_vars [name ] = None
383
+ else :
384
+ assert name in self .input_vars
385
+ self .output_vars [name ] = self .input_vars [name ]
382
386
383
387
def find_source_slices (
384
388
self , key : str , local_slice : tuple [slice , ...]
@@ -406,7 +410,7 @@ def slice_intersect(a: slice, b: slice):
406
410
else :
407
411
# Compute corresponding src_slice for the intersection
408
412
if pp_list is not None :
409
- sl_src = self . postprocess_transpose (list (sl_src ), pp_list )
413
+ sl_src = postprocess_transpose (list (sl_src ), pp_list )
410
414
src_slice = []
411
415
for i in range (ndim ):
412
416
dst = sl_dst [i ]
@@ -424,7 +428,7 @@ def slice_intersect(a: slice, b: slice):
424
428
)
425
429
src_slice .append (slice (src_inter_start , src_inter_stop , 1 ))
426
430
if pp_list is not None :
427
- src_slice = self . postprocess_transpose (
431
+ src_slice = postprocess_transpose (
428
432
list (src_slice ), pp_list , reverse = True
429
433
)
430
434
results .append (
@@ -484,6 +488,14 @@ def find_shard_sources(
484
488
tgt_global_offset ,
485
489
)
486
490
491
+ if source_sharded_weight .key in self .need_remove_input_vars :
492
+ mapping_entry = ShardMappingEntry (
493
+ target_sharded_weight ,
494
+ source_sharded_weight ,
495
+ [],
496
+ )
497
+ continue
498
+
487
499
shard_mappings .append (
488
500
ShardMappingEntry (
489
501
target_sharded_weight ,
@@ -493,23 +505,23 @@ def find_shard_sources(
493
505
)
494
506
return shard_mappings
495
507
496
- def postprocess_transpose (
497
- self ,
498
- li : list [tuple [slice , ...]] | tuple [tuple [slice , ...]],
499
- postprocess_list : list [str ],
500
- reverse : bool = False ,
501
- ) -> list [tuple [slice , ...]] | tuple [tuple [slice , ...]]:
502
- result = li
503
- if reverse :
504
- for pp in list (reversed (postprocess_list )):
505
- if pp .startswith ("[" ):
506
- reversed_transpose = np .argsort (eval (pp )).tolist ()
507
- result = transpose_list (result , reversed_transpose )
508
- else :
509
- for pp in postprocess_list :
510
- if pp .startswith ("[" ):
511
- result = transpose_list (result , eval (pp ))
512
- return result
508
+
509
+ def postprocess_transpose (
510
+ li : list [tuple [slice , ...]] | tuple [tuple [slice , ...]],
511
+ postprocess_list : list [str ],
512
+ reverse : bool = False ,
513
+ ) -> list [tuple [slice , ...]] | tuple [tuple [slice , ...]]:
514
+ result = li
515
+ if reverse :
516
+ for pp in list (reversed (postprocess_list )):
517
+ if pp .startswith ("[" ):
518
+ reversed_transpose = np .argsort (ast . literal_eval (pp )).tolist ()
519
+ result = transpose_list (result , reversed_transpose )
520
+ else :
521
+ for pp in postprocess_list :
522
+ if pp .startswith ("[" ):
523
+ result = transpose_list (result , ast . literal_eval (pp ))
524
+ return result
513
525
514
526
515
527
def transpose_list (
0 commit comments