Skip to content

Commit 60a0951

Browse files
authored
[Bugfix] Fix BNB name match (vllm-project#24735)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 64d90c3 commit 60a0951

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
326326

327327
global_tp_size = get_tensor_model_parallel_world_size()
328328
global_tp_rank = get_tensor_model_parallel_rank()
329-
329+
check_match = (lambda weight_name, module_name: weight_name.
330+
removesuffix(".weight") == module_name)
330331
for (
331332
org_weight_name,
332333
mapped_weight_name,
@@ -347,12 +348,12 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
347348
) and mapped_weight_name.endswith(".weight"):
348349
# Without sharding
349350
if any(
350-
mapped_weight_name.startswith(module)
351+
check_match(mapped_weight_name, module)
351352
for module in self.unsharded_weights_modules):
352353
weight_sub_tensor = weight_tensor
353354
# Shard by column
354355
elif any(
355-
mapped_weight_name.startswith(module)
356+
check_match(mapped_weight_name, module)
356357
for module in self.column_sharded_weights_modules):
357358
total_size = weight_tensor.size(-1)
358359
start_index = total_size // tp_size * tp_rank
@@ -362,14 +363,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
362363
# Weights have fused on disk. In this case, we assume that the
363364
# weight and module use same name.
364365
elif any(
365-
mapped_weight_name.startswith(module)
366+
check_match(mapped_weight_name, module)
366367
for module in self.maybe_fused_weights_modules):
367368
# special case for fused weights
368369
# get the size of each shard weight tensor
369370
total_shard_sizes = next(
370371
(sizes for module, sizes in
371372
self.maybe_fused_weights_modules.items()
372-
if mapped_weight_name.startswith(module)))
373+
if check_match(mapped_weight_name, module)))
373374
total_size = weight_tensor.size(0)
374375
assert total_size == sum(total_shard_sizes)
375376
# get the start/end index of each shard weight tensor

0 commit comments

Comments
 (0)