@@ -299,6 +299,32 @@ def naive_fuse_split_tp(
     """
     axis = -1 if is_column else 0
+    if "PySafeSlice" in str(type(weight)):
+        size = weight.get_shape()[axis]
+        block_size = size // (fuse_tensor_parts * tensor_parallel_degree)
+
+        splited = []
+        if tensor_parallel_rank is None:
+            begin, end, step = 0, fuse_tensor_parts * tensor_parallel_degree, 1
+        else:
+            begin, end, step = tensor_parallel_rank, fuse_tensor_parts * tensor_parallel_degree, tensor_parallel_degree
+        for rank in range(begin, end, step):
+            start = rank * block_size
+            stop = (rank + 1) * block_size
+            if axis == 0 or len(weight.get_shape()) == 1:
+                tensor = weight[start:stop]
+            else:
+                tensor = weight[:, start:stop]
+            splited.append(tensor)
+
+        if tensor_parallel_rank is None:
+            ret = []
+            for tensor_parallel_rank in range(tensor_parallel_degree):
+                ret.append(np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
+            return ret
+
+        return np.concatenate(splited, axis=axis)
+
     splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)

     if tensor_parallel_rank is None:
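
The added `PySafeSlice` branch mirrors the dense NumPy path below it, but reads only the required row/column blocks from a lazily loaded safetensors slice instead of materializing the full array first. A minimal NumPy-only sketch of the same interleaved fuse-split semantics follows; the helper name `fuse_split_sketch` and the toy 1x16 weight are illustrative and not part of the patch:

```python
import numpy as np

def fuse_split_sketch(weight, tensor_parallel_degree, tensor_parallel_rank, is_column=True, fuse_tensor_parts=2):
    # Cut the fused weight into fuse_tensor_parts * tensor_parallel_degree equal blocks
    # along the split axis, then give rank r every tensor_parallel_degree-th block
    # starting at r, so each fused sub-tensor stays evenly distributed across ranks.
    axis = -1 if is_column else 0
    blocks = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)
    return np.concatenate(blocks[tensor_parallel_rank::tensor_parallel_degree], axis=axis)

w = np.arange(16).reshape(1, 16)  # two fused parts of width 8, split column-wise
print(fuse_split_sketch(w, tensor_parallel_degree=2, tensor_parallel_rank=0))
# [[ 0  1  2  3  8  9 10 11]]  -> first half of each fused part
print(fuse_split_sketch(w, tensor_parallel_degree=2, tensor_parallel_rank=1))
# [[ 4  5  6  7 12 13 14 15]]  -> second half of each fused part
```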