Skip to content

Commit 40e3574

Browse files
bbartels and DarkLight1337
authored and committed
Adds parallel model weight loading for runai_streamer (vllm-project#21330)
Signed-off-by: bbartels <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: avigny <[email protected]>
1 parent 134135f commit 40e3574

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -659,7 +659,8 @@ def _read_requirements(filename: str) -> list[str]:
659659
"bench": ["pandas", "datasets"],
660660
"tensorizer": ["tensorizer==2.10.1"],
661661
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
662-
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
662+
"runai":
663+
["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"],
663664
"audio": ["librosa", "soundfile",
664665
"mistral_common[audio]"], # Required for audio processing
665666
"video": [] # Kept for backwards compatibility

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -482,14 +482,20 @@ def runai_safetensors_weights_iterator(
482482
) -> Generator[tuple[str, torch.Tensor], None, None]:
483483
"""Iterate over the weights in the model safetensor files."""
484484
with SafetensorsStreamer() as streamer:
485-
for st_file in tqdm(
486-
hf_weights_files,
487-
desc="Loading safetensors using Runai Model Streamer",
488-
disable=not enable_tqdm(use_tqdm_on_load),
489-
bar_format=_BAR_FORMAT,
490-
):
491-
streamer.stream_file(st_file)
492-
yield from streamer.get_tensors()
485+
streamer.stream_files(hf_weights_files)
486+
total_tensors = sum(
487+
len(tensors_meta)
488+
for tensors_meta in streamer.files_to_tensors_metadata.values())
489+
490+
tensor_iter = tqdm(
491+
streamer.get_tensors(),
492+
total=total_tensors,
493+
desc="Loading safetensors using Runai Model Streamer",
494+
bar_format=_BAR_FORMAT,
495+
disable=not enable_tqdm(use_tqdm_on_load),
496+
)
497+
498+
yield from tensor_iter
493499

494500

495501
def fastsafetensors_weights_iterator(

0 commit comments

Comments
 (0)