 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple
 
+import psutil
 import safetensors
 import torch
 import torch._dynamo.config
 
 import tensorrt_llm.bindings.internal.userbuffers as ub
 from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
-from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
+from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
+                                 local_mpi_size, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
 from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.logger import logger
@@ -132,6 +134,14 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig,
     model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant
 
 
+def _prefetch_one_file(file_name, rank):
+    if os.path.exists(file_name):
+        logger.info(f"Rank {rank} prefetching {file_name} to memory...")
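+        # Reading the whole file once warms the OS page cache, so the
+        # subsequent safetensors load is served from memory rather than disk.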
+        with open(file_name, 'rb') as f:
+            f.read()
+        logger.info(f"Rank {rank} finished prefetching {file_name}.")
+
+
 def prefetch_files(file_names: List[str], mapping: Mapping):
     """
     Prefetch safetensors files to memory so that the weight loading will be much faster.
@@ -140,33 +150,35 @@ def prefetch_files(file_names: List[str], mapping: Mapping):
     heuristics about when to prefetch and when not to.
     """
 
-    def _prefetch_one_file(file_name, rank):
-        if os.path.exists(file_name):
-            logger.info(f"Rank {rank} prefetching {file_name} to memory...")
-            with open(file_name, 'rb') as f:
-                f.read()
-            logger.info(f"Rank {rank} finished prefetching {file_name}.")
-
     # Find out the files to prefetch for the current rank.
-    # Each rank loads files with indices rank, rank + world_size, rank + 2*world_size, etc.
-    local_file_names = file_names[mapping.rank::mapping.world_size]
-
-    processes = []
-    for file_name in local_file_names:
-        process = multiprocessing.Process(target=_prefetch_one_file,
-                                          args=(file_name, mapping.rank))
-        process.start()
-        processes.append(process)
-
-    for process in processes:
-        process.join()
+    # Each rank loads files with indices local_rank, local_rank + local_mpi_size, local_rank + 2*local_mpi_size, etc.
+    local_file_names = file_names[local_mpi_rank()::local_mpi_size()]
+
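+    # Cap the number of reader processes: beyond this, extra readers mostly
+    # contend for disk bandwidth instead of adding throughput.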
+    max_processes = min(multiprocessing.cpu_count() * 2, 16)
+    with multiprocessing.Pool(processes=max_processes) as pool:
+        pool.starmap(
+            _prefetch_one_file,
+            [(file_name, mapping.rank) for file_name in local_file_names],
+        )
 
 
-def load_weights(checkpoint_dir: str, mapping: Mapping):
+def load_weights(
+    checkpoint_dir: str,
+    mapping: Mapping,
+):
     weights = {}
     weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
     if weight_files:
-        prefetch_files(weight_files, mapping)
+        # Prefetch the weight files to CPU memory if their total size is less
+        # than 90% of the available memory. This heuristic avoids prefetching
+        # files so large that they would cause file cache thrashing.
+        prefetch_size = sum(os.path.getsize(file) for file in weight_files)
+        enable_prefetch = prefetch_size < psutil.virtual_memory().available * 0.9
+        if enable_prefetch:
+            logger.info(
+                f"Prefetching {prefetch_size / (1024**3):.2f} GB checkpoint files."
+            )
+            prefetch_files(weight_files, mapping)
         for file in weight_files:
             logger.info(f"Loading {file}")
             part_weights = safetensors.torch.load_file(file)
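In isolation, the new gating heuristic in `load_weights` amounts to the check below (a minimal sketch: the 0.9 safety factor and `psutil.virtual_memory().available` come from the diff above, while the `should_prefetch` name is illustrative):

```python
import os

import psutil


def should_prefetch(weight_files, safety_factor=0.9):
    """Return True when all checkpoint shards fit comfortably in free RAM."""
    total_bytes = sum(os.path.getsize(f) for f in weight_files)
    # If the shards do not fit, prefetching evicts pages as fast as they are
    # read in, so the file cache thrashes and the warm-up buys nothing.
    return total_bytes < psutil.virtual_memory().available * safety_factor
```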
@@ -931,16 +943,22 @@ def init_meta_tensor(t: torch.Tensor):
             model = AutoModelForCausalLM.from_config(config)
 
         model.to("cuda")
+        rank_model_storage = get_rank_model_storage(model)
         logger.info(
-            f"Rank {self.mapping.rank} uses {get_rank_model_storage(model) / (1024**3):.2f} GB for model weights."
+            f"Rank {self.mapping.rank} uses {rank_model_storage / (1024**3):.2f} GB for model weights."
         )
 
         if load_format == LoadFormat.AUTO:
             if hasattr(model, 'llm_checkpoint_dir'):
-                weights = load_weights(model.llm_checkpoint_dir,
-                                       self.mapping)
+                weights = load_weights(
+                    model.llm_checkpoint_dir,
+                    self.mapping,
+                )
             else:
-                weights = load_weights(checkpoint_dir, self.mapping)
+                weights = load_weights(
+                    checkpoint_dir,
+                    self.mapping,
+                )
 
         model.load_weights(weights)
 
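For reference, the interleaved file-to-rank assignment in `prefetch_files` behaves like the toy example below (hypothetical file names and node size; only the `file_names[local_mpi_rank()::local_mpi_size()]` slicing pattern is taken from the diff):

```python
# Toy illustration of the interleaved striding used by prefetch_files.
files = [f"model-{i:05d}.safetensors" for i in range(8)]
local_size = 4  # hypothetical number of ranks on one node

for local_rank in range(local_size):
    # Rank r prefetches files r, r + local_size, r + 2*local_size, ...
    print(local_rank, files[local_rank::local_size])
```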