
Commit ec796e4

feat: add heuristics for checkpoint files prefetching. (NVIDIA#4765)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 7ce1e13 commit ec796e4

File tree

1 file changed: +44 -26 lines


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 44 additions & 26 deletions
@@ -13,14 +13,16 @@
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple
 
+import psutil
 import safetensors
 import torch
 import torch._dynamo.config
 
 import tensorrt_llm.bindings.internal.userbuffers as ub
 from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
-from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
+from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
+                                 local_mpi_size, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
 from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.logger import logger
@@ -132,6 +134,14 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig,
     model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant
 
 
+def _prefetch_one_file(file_name, rank):
+    if os.path.exists(file_name):
+        logger.info(f"Rank {rank} prefetching {file_name} to memory...")
+        with open(file_name, 'rb') as f:
+            f.read()
+        logger.info(f"Rank {rank} finished prefetching {file_name}.")
+
+
 def prefetch_files(file_names: List[str], mapping: Mapping):
     """
     Prefetch safetensors files to memory so that the weight loading will be much faster.
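
Note on the hoist above: moving _prefetch_one_file from a closure inside prefetch_files to module scope is likely a requirement rather than a style choice. The next hunk replaces hand-managed multiprocessing.Process workers with a multiprocessing.Pool, and Pool pickles the target callable when dispatching tasks, so the callable must live at module level; a locally defined function would raise a pickling error. A minimal standalone illustration of the constraint (not code from this diff):

import multiprocessing


def module_level(x):
    # Defined at module scope, so Pool can pickle and dispatch it.
    return x * 2


def main():
    def nested(x):
        # Defined inside a function: Pool cannot pickle it.
        return x * 2

    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(module_level, [1, 2, 3]))  # works: [2, 4, 6]
        # pool.map(nested, [1, 2, 3])  # raises: Can't pickle local object

    # multiprocessing.Process(target=nested, ...) would still work on
    # Linux (fork start method), which is why the old code got away
    # with a nested helper.


if __name__ == "__main__":
    main()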
@@ -140,33 +150,35 @@ def prefetch_files(file_names: List[str], mapping: Mapping):
     heuristics about when to prefetch and when not to.
     """
 
-    def _prefetch_one_file(file_name, rank):
-        if os.path.exists(file_name):
-            logger.info(f"Rank {rank} prefetching {file_name} to memory...")
-            with open(file_name, 'rb') as f:
-                f.read()
-            logger.info(f"Rank {rank} finished prefetching {file_name}.")
-
     # Find out the files to prefetch for the current rank.
-    # Each rank loads files with indices rank, rank + world_size, rank + 2*world_size, etc.
-    local_file_names = file_names[mapping.rank::mapping.world_size]
-
-    processes = []
-    for file_name in local_file_names:
-        process = multiprocessing.Process(target=_prefetch_one_file,
-                                          args=(file_name, mapping.rank))
-        process.start()
-        processes.append(process)
-
-    for process in processes:
-        process.join()
+    # Each rank loads files with indices local_rank, local_rank + local_mpi_size, local_rank + 2*local_mpi_size, etc.
+    local_file_names = file_names[local_mpi_rank()::local_mpi_size()]
+
+    max_processes = min(multiprocessing.cpu_count() * 2, 16)
+    with multiprocessing.Pool(processes=max_processes) as pool:
+        pool.starmap(
+            _prefetch_one_file,
+            [(file_name, mapping.rank) for file_name in local_file_names],
+        )
 
 
-def load_weights(checkpoint_dir: str, mapping: Mapping):
+def load_weights(
+    checkpoint_dir: str,
+    mapping: Mapping,
+):
     weights = {}
     weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
     if weight_files:
-        prefetch_files(weight_files, mapping)
+        # Prefetch the weight files to CPU memory if the size is less than 90% of the available memory.
+        # This is a heuristic to avoid prefetching files that are too large and causing file cache thrashing.
+        prefetch_size = sum(os.path.getsize(file) for file in weight_files)
+        enable_prefetch = prefetch_size < psutil.virtual_memory(
+        ).available * 0.9
+        if enable_prefetch:
+            logger.info(
+                f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files."
+            )
+            prefetch_files(weight_files, mapping)
     for file in weight_files:
         logger.info(f"Loading {file}")
         part_weights = safetensors.torch.load_file(file)
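
Taken together, this hunk adds two heuristics: shard the prefetch work across the ranks of one node rather than the whole world (so each file is read from disk once per node, not once per GPU), and skip prefetching entirely when the checkpoint would crowd out the page cache. A condensed standalone sketch of the same logic follows; prefetch_checkpoint is a hypothetical name, and the local_rank / local_size parameters stand in for TensorRT-LLM's local_mpi_rank() / local_mpi_size():

import multiprocessing
import os
from typing import List

import psutil


def _prefetch_one_file(file_name: str, rank: int) -> None:
    # Reading the file once pulls it into the OS page cache, so the
    # subsequent safetensors load is served from memory, not disk.
    if os.path.exists(file_name):
        with open(file_name, 'rb') as f:
            f.read()


def prefetch_checkpoint(file_names: List[str], local_rank: int,
                        local_size: int) -> None:
    # Heuristic 1: skip prefetching when the checkpoint would occupy
    # 90% or more of the currently available memory, to avoid file
    # cache thrashing.
    total_size = sum(os.path.getsize(f) for f in file_names)
    if total_size >= psutil.virtual_memory().available * 0.9:
        return
    # Heuristic 2: shard the file list round-robin across the ranks on
    # this node, so each file is prefetched exactly once per node.
    my_files = file_names[local_rank::local_size]
    max_processes = min(multiprocessing.cpu_count() * 2, 16)
    with multiprocessing.Pool(processes=max_processes) as pool:
        pool.starmap(_prefetch_one_file,
                     [(f, local_rank) for f in my_files])

As a worked example of the 90% threshold: with 256 GiB of memory available, checkpoints up to about 230 GiB (0.9 x 256) are prefetched; anything larger is loaded straight from disk.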
@@ -931,16 +943,22 @@ def init_meta_tensor(t: torch.Tensor):
             model = AutoModelForCausalLM.from_config(config)
 
             model.to("cuda")
+            rank_model_storage = get_rank_model_storage(model)
             logger.info(
-                f"Rank {self.mapping.rank} uses {get_rank_model_storage(model) / (1024**3):.2f} GB for model weights."
+                f"Rank {self.mapping.rank} uses {rank_model_storage / (1024**3):.2f} GB for model weights."
             )
 
             if load_format == LoadFormat.AUTO:
                 if hasattr(model, 'llm_checkpoint_dir'):
-                    weights = load_weights(model.llm_checkpoint_dir,
-                                           self.mapping)
+                    weights = load_weights(
+                        model.llm_checkpoint_dir,
+                        self.mapping,
+                    )
                 else:
-                    weights = load_weights(checkpoint_dir, self.mapping)
+                    weights = load_weights(
+                        checkpoint_dir,
+                        self.mapping,
+                    )
 
             model.load_weights(weights)