
Commit d178858

stream update weights verified

1 parent af0a668 commit d178858

File tree: 2 files changed, +98 −25 lines

- tensorrt_llm/_torch/utils.py
- tests/unittest/llmapi/test_llm_update_weights.py


tensorrt_llm/_torch/utils.py

Lines changed: 12 additions & 0 deletions
@@ -318,3 +318,15 @@ def get_device_uuid(device_idx: int) -> str:
             raise RuntimeError(
                 f"Failed to get device UUID for device {device_idx} (global index: {global_device_idx}): {e}"
             )
+
+def get_free_memory_bytes(device_idx: int) -> float:
+    """Get the free memory of a CUDA device in bytes using NVML."""
+    global_device_idx = device_id_to_physical_device_id(device_idx)
+    with nvml_context():
+        try:
+            handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
+            return pynvml.nvmlDeviceGetMemoryInfo(handle).free
+        except pynvml.NVMLError as e:
+            raise RuntimeError(
+                f"Failed to get free memory for device {device_idx} (global index: {global_device_idx}): {e}"
+            )
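
Note (not part of the commit): the same free-memory query can be reproduced with pynvml directly as a sanity check. This is an illustrative sketch; device index 0 and the 0.8 ratio are assumptions that mirror the defaults used in the test file below.

# Standalone sketch: query free device memory via NVML, as get_free_memory_bytes does.
import pynvml

pynvml.nvmlInit()
try:
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # assumes device 0 is visible
    free_bytes = pynvml.nvmlDeviceGetMemoryInfo(handle).free
    # The test below keeps only ~80% of this as the refit streaming budget.
    print(f"free: {free_bytes / 1024**3:.2f} GiB, budget: {free_bytes * 0.8 / 1024**3:.2f} GiB")
finally:
    pynvml.nvmlShutdown()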

tests/unittest/llmapi/test_llm_update_weights.py

Lines changed: 86 additions & 25 deletions
@@ -3,7 +3,7 @@
 import torch.distributed as dist
 import atexit
 import os
-from typing import Any
+from typing import Any, Optional
 from tensorrt_llm import SamplingParams
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi.llm_args import KvCacheConfig
@@ -17,10 +17,6 @@
 #from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
 from transformers import AutoModelForCausalLM, AutoTokenizer

-import contextlib
-from typing import Generator
-import pynvml
-

 def init_distributed():
     """Initialize distributed training"""
@@ -129,15 +125,29 @@ def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
         self._held_sharded_state_dict_reference = self.model.state_dict()

         # Collect info for streaming multiple tensors
-        state_dict_info = []
+        ### state_dict_info = []
+        ### for name, tensor in self._held_sharded_state_dict_reference.items():
+        ###     # dtensor's numel will return complete tensor instead of only local tensor
+        ###     size_in_bytes = tensor.element_size() * tensor.numel()
+        ###     state_dict_info.append((name, size_in_bytes))
+        self.refit_param_info = []
         for name, tensor in self._held_sharded_state_dict_reference.items():
             # dtensor's numel will return complete tensor instead of only local tensor
             size_in_bytes = tensor.element_size() * tensor.numel()
-            state_dict_info.append((name, size_in_bytes))
+            self.refit_param_info.append((name, size_in_bytes))

+        from tensorrt_llm._torch.utils import get_free_memory_bytes
         #print(f"State dict info: {state_dict_info}")
+        # Collect current available memory for refit
+        ## Get current device index from torch
+        device_idx = torch.cuda.current_device()
+        ## Get device free memory using NVML
+        total_available_bytes = get_free_memory_bytes(device_idx)
+        ## Use 80% of the free memory for safety
+        memory_ratio = os.getenv("NRL_REFIT_BUFFER_MEMORY_RATIO", "0.8")
+        total_available_bytes *= float(memory_ratio)

-        return state_dict_info
+        return self.refit_param_info, total_available_bytes

     @torch.no_grad()
     def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
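
For intuition on the sizes accumulated above (not part of the commit): element_size() * numel() is bytes per element times element count. A toy illustration with made-up shapes, not the model's real parameters:

import torch

# A toy bf16 weight: 4096 x 4096 elements, 2 bytes each -> 32 MiB.
w = torch.empty(4096, 4096, dtype=torch.bfloat16)
size_in_bytes = w.element_size() * w.numel()
print(size_in_bytes / 1024**2)  # 32.0

# With the default NRL_REFIT_BUFFER_MEMORY_RATIO of 0.8, a device reporting
# 20 GiB free would yield a ~16 GiB budget for grouping such tensors.
print(20 * 1024**3 * 0.8 / 1024**3)  # 16.0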
@@ -183,6 +193,44 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
         print(f"device_uuid: {device_uuid}")
         return {device_uuid: all_handles}

+    @torch.no_grad()
+    def prepare_weights_for_ipc_refit(
+        self, _refit_buffer_size_gb: Optional[int] = None
+    ) -> list[list[str]]:
+        """Prepare the weights for IPC.
+
+        Returns:
+            list: A list containing the keys of the parameters, which is grouped by size.
+        """
+        # Get the state_dict_info and available memory from all workers
+        state_dict_info = self.refit_param_info
+
+        if _refit_buffer_size_gb is not None:
+            total_available_bytes = _refit_buffer_size_gb * (1024**3)
+        else:
+            # Get the minimum available memory from all workers
+            total_available_bytes = min(result[1] for result in state_dict_info)
+
+        # Group tensors by size
+        cur_available_bytes = total_available_bytes
+        grouped_param_keys: list[list[str]] = []
+        keys: list[str] = []
+
+        for key, size_in_bytes in state_dict_info:
+            if size_in_bytes > cur_available_bytes:
+                if keys:
+                    grouped_param_keys.append(keys)
+                    keys = []
+                cur_available_bytes = total_available_bytes
+
+            keys.append(key)
+            cur_available_bytes -= size_in_bytes
+
+        if keys:
+            grouped_param_keys.append(keys)
+
+        return grouped_param_keys
+
 class trtllm_interface:
     def __init__(self, model_dir, tensor_parallel_size):
         self.world_size = dist.get_world_size()
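
To make the grouping concrete (not part of the commit), here is the same chunking loop run standalone on made-up (name, size) pairs with a 1 GiB budget; names and sizes are illustrative only:

# Toy reproduction of the grouping loop in prepare_weights_for_ipc_refit.
state_dict_info = [
    ("embed", 600 * 1024**2),   # 600 MiB
    ("layer0", 500 * 1024**2),  # 500 MiB
    ("layer1", 500 * 1024**2),  # 500 MiB
    ("head", 300 * 1024**2),    # 300 MiB
]
total_available_bytes = 1 * 1024**3  # 1 GiB budget

cur_available_bytes = total_available_bytes
grouped_param_keys, keys = [], []
for key, size_in_bytes in state_dict_info:
    if size_in_bytes > cur_available_bytes:
        if keys:
            grouped_param_keys.append(keys)
            keys = []
        cur_available_bytes = total_available_bytes
    keys.append(key)
    cur_available_bytes -= size_in_bytes
if keys:
    grouped_param_keys.append(keys)

print(grouped_param_keys)  # [['embed'], ['layer0', 'layer1'], ['head']]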
@@ -202,6 +250,7 @@ def load_trtllm_model(self, model_dir, tensor_parallel_size):
             #load_format='auto'
             load_format='dummy',
             kv_cache_config=KvCacheConfig(
+                free_gpu_memory_fraction=0.85,
                 enable_block_reuse=False
             )
         )
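
For reference (not part of the commit), the added cap can be exercised on its own; a minimal sketch, assuming a local HF checkpoint path (illustrative) and enough GPU memory to build the engine. Both kwargs come straight from the hunk above; free_gpu_memory_fraction limits how much of the remaining free GPU memory the KV-cache pool may claim, leaving headroom for the refit buffers used below.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import KvCacheConfig

# Cap the KV-cache pool at 85% of free GPU memory and disable block reuse,
# mirroring the configuration added in this commit. Model path is illustrative.
llm = LLM(
    model="/model/Qwen2.5-0.5B-Instruct",
    kv_cache_config=KvCacheConfig(
        free_gpu_memory_fraction=0.85,
        enable_block_reuse=False,
    ),
)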
@@ -251,23 +300,9 @@ def main():
     fsdp = fsdp_interface(args.model_dir)
     trtllm = trtllm_interface(args.model_dir, args.tensor_parallel_size)

-    grouped_param_keys = [key for key,size in fsdp.prepare_weights_for_ipc()]
-    handles = fsdp.get_weights_ipc_handles(grouped_param_keys)
-    #print(f"handles: {handles}")
-
-    # Collect handles from all ranks
-    all_handles = [None for _ in range(world_size)]
-    dist.all_gather_object(all_handles, handles)
-    all_handles = {k: v for d in all_handles for k, v in d.items()}
-    print(f"all_handles: {all_handles.keys()}")
-
     if rank == 0:
         print(f"Collected handles from all {world_size} ranks:")

-    # Now all_handles contains the handles from each rank
-    # all_handles[0] = handles from rank 0
-    # all_handles[1] = handles from rank 1, etc.
-
     # For FSDP mode, we would need additional logic to integrate withTensorRT-LLM
     # This is a placeholder for now
     if rank == 0:
@@ -286,9 +321,34 @@ def main():
     result = trtllm.llm.wakeup()
     print(f"wakeup result: {result}")

-    result = trtllm.llm.update_weights_from_ipc_handles(all_handles)
-    print(f"update weights result: {result}")
+    dict_info, total_available_bytes = fsdp.prepare_weights_for_ipc()
+
+    grouped_param_keys = fsdp.prepare_weights_for_ipc_refit(0.5)
+    total_num_keys = sum(len(k) for k in grouped_param_keys)
+    print(
+        f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+    )

+    from tensorrt_llm._torch.utils import get_free_memory_bytes
+    for keys in grouped_param_keys:
+        handles = fsdp.get_weights_ipc_handles(keys)
+        #print(f"handles: {handles}")
+
+        # Collect handles from all ranks
+        all_handles = [None for _ in range(world_size)]
+        dist.all_gather_object(all_handles, handles)
+        all_handles = {k: v for d in all_handles for k, v in d.items()}
+        #print(f"all_handles: {all_handles.keys()}")
+
+        device_idx = torch.cuda.current_device()
+        total_available_bytes = get_free_memory_bytes(device_idx)
+        print(f"total_available_bytes: {total_available_bytes}")
+
+        if rank == 0:
+            result = trtllm.llm.update_weights_from_ipc_handles(all_handles)
+            print(f"update weights result: {result}")
+
+    if rank == 0:
         outputs = trtllm.llm.generate(prompts, sampling_params)
         for i, output in enumerate(outputs):
             prompt = output.prompt
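
A side note on the handle-collection step inside the loop (not part of the commit): dist.all_gather_object leaves each rank with a list of per-rank {device_uuid: handles} dicts, and the comprehension flattens them into one mapping before it is handed to update_weights_from_ipc_handles. A minimal sketch of that merge with dummy data, no process group needed:

# Dummy per-rank results of get_weights_ipc_handles: one dict per rank,
# keyed by device UUID. Values stand in for the real IPC handle lists.
all_handles = [
    {"GPU-aaaa": ["handle_rank0_tensor0", "handle_rank0_tensor1"]},
    {"GPU-bbbb": ["handle_rank1_tensor0", "handle_rank1_tensor1"]},
]

# Same flattening as in the test: merge the per-rank dicts into a single
# {device_uuid: handles} mapping.
merged = {k: v for d in all_handles for k, v in d.items()}
print(merged.keys())  # dict_keys(['GPU-aaaa', 'GPU-bbbb'])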
@@ -299,4 +359,5 @@ def main():
 if __name__ == '__main__':
     main()

-# torchrun --nproc_per_node=2 generate.py --model_dir /model/Qwen2.5-0.5B-Instruct --tensor_parallel_size 2
+# torchrun --nproc_per_node=2 tests/unittest/llmapi/test_llm_update_weights.py --model_dir /model/Qwen2.5-0.5B-Instruct --tensor_parallel_size 2
+# torchrun --nproc_per_node=2 tests/unittest/llmapi/test_llm_update_weights.py --model_dir /model/Qwen2.5-3B-Instruct/ --tensor_parallel_size 2
