@@ -326,7 +326,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,

         global_tp_size = get_tensor_model_parallel_world_size()
         global_tp_rank = get_tensor_model_parallel_rank()
-
+        check_match = (lambda weight_name, module_name: weight_name.
+                       removesuffix(".weight") == module_name)
         for (
                 org_weight_name,
                 mapped_weight_name,
@@ -347,12 +348,12 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                    ) and mapped_weight_name.endswith(".weight"):
                 # Without sharding
                 if any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
@@ -362,14 +363,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                 # Weights have fused on disk. In this case, we assume that the
                 # weight and module use same name.
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.maybe_fused_weights_modules):
                     # special case for fused weights
                     # get the size of each shard weight tensor
                     total_shard_sizes = next(
                         (sizes for module, sizes in
                          self.maybe_fused_weights_modules.items()
-                         if mapped_weight_name.startswith(module)))
+                         if check_match(mapped_weight_name, module)))
                     total_size = weight_tensor.size(0)
                     assert total_size == sum(total_shard_sizes)
                     # get the start/end index of each shard weight tensor
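For context, a minimal standalone sketch of the matching change above (the module and weight names below are made up for illustration, not taken from the diff): the old prefix check can match a module whose name happens to be a prefix of another module's name, while the new check_match lambda only matches when the weight name, minus its ".weight" suffix, equals the module name exactly.

# Minimal sketch, not part of the diff; names are hypothetical.
# Requires Python 3.9+ for str.removesuffix.
check_match = (lambda weight_name, module_name: weight_name.
               removesuffix(".weight") == module_name)

modules = ["model.vision.fc", "model.vision.fc_norm"]
weight_name = "model.vision.fc_norm.weight"

# Prefix matching: "model.vision.fc" also matches, so the weight could be
# picked up by the wrong sharding branch.
assert [m for m in modules if weight_name.startswith(m)] == modules

# Exact matching: only the module whose name equals the weight name without
# its ".weight" suffix matches.
assert [m for m in modules
        if check_match(weight_name, m)] == ["model.vision.fc_norm"]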