Commit 9929e71

add gate/up bundled update works

1 parent d272f9f

2 files changed: +144 −529 lines

tests/unittest/llmapi/test_llm_update_weights.py

Lines changed: 144 additions & 41 deletions
@@ -16,7 +16,7 @@
 )
 #from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
+from torch.distributed.tensor import DTensor
 
 def init_distributed():
     """Initialize distributed training"""
@@ -29,7 +29,7 @@ def init_distributed():
     if "MASTER_PORT" not in os.environ:
         os.environ["MASTER_PORT"] = "29500"
 
-    dist.init_process_group(backend="nccl")
+    dist.init_process_group(backend="cpu:gloo,cuda:nccl")
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     torch.cuda.set_device(rank)
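Why the backend string changes: the new refit path gathers pickled IPC-handle metadata across ranks with dist.gather_object, a CPU-side object collective, while tensor collectives stay on NCCL. The combined spec routes each collective to the right backend by device. A toy sketch, assuming two ranks launched via torchrun (not part of the test):

import torch
import torch.distributed as dist

dist.init_process_group(backend="cpu:gloo,cuda:nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# CPU object collective -> routed to gloo
bucket = [None] * dist.get_world_size() if rank == 0 else None
dist.gather_object({"rank": rank}, object_gather_list=bucket, dst=0)

# GPU tensor collective -> routed to nccl
t = torch.ones(1, device="cuda")
dist.all_reduce(t)

dist.destroy_process_group()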
@@ -39,6 +39,18 @@ def exit_distributed():
     """Exit distributed training"""
     if dist.is_initialized():
         dist.destroy_process_group()
+
+def report_device_id() -> str:
+    """Report the UUID of the current CUDA device using NVML.
+    Returns:
+        str: UUID of the device in the format "GPU-xxxxx"
+    """
+    from tensorrt_llm._torch.utils import get_device_uuid
+    # Get current device index from torch
+    device_idx = torch.cuda.current_device()
+    # Get device UUID using NVML
+    return get_device_uuid(device_idx)
+
 class fsdp_interface:
     def __init__(self, model_dir):
         self.model_dir = model_dir
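The now module-level report_device_id (moved out of fsdp_interface; see the next hunk) leans on TensorRT-LLM's internal get_device_uuid. For readers without that helper handy, a rough pynvml equivalent of what it is assumed to do (caveat: NVML indexes physical GPUs, so a CUDA_VISIBLE_DEVICES remapping would need an extra translation step):

import pynvml  # pip install nvidia-ml-py
import torch

def device_uuid_via_nvml(device_idx: int) -> str:
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
        uuid = pynvml.nvmlDeviceGetUUID(handle)  # e.g. "GPU-8f6b..."
        return uuid.decode() if isinstance(uuid, bytes) else uuid
    finally:
        pynvml.nvmlShutdown()

print(device_uuid_via_nvml(torch.cuda.current_device()))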
@@ -96,17 +108,23 @@ def load_fsdp_model(self, model_dir):
         return fsdp_model
 
 
-    def report_device_id(self) -> str:
-        """Report the UUID of the current CUDA device using NVML.
 
-        Returns:
-            str: UUID of the device in the format "GPU-xxxxx"
-        """
-        from tensorrt_llm._torch.utils import get_device_uuid
-        # Get current device index from torch
-        device_idx = torch.cuda.current_device()
-        # Get device UUID using NVML
-        return get_device_uuid(device_idx)
+    def per_tensor_generator(self):
+        # If the model is not FSDP, then we need to manually move it to the GPU
+        # For an FSDP model, model.state_dict() will move the params to the GPU
+        if not isinstance(self.model, FSDP):
+            self.model = self.manual_load_to_gpu(self.model)
+            self._held_sharded_state_dict_reference = self.model.state_dict()
+        else:
+            # Gather a full (unsharded) state dict for FSDP1
+            with FSDP.state_dict_type(
+                self.model,
+                state_dict_type=StateDictType.FULL_STATE_DICT,
+                state_dict_config=FullStateDictConfig()
+            ):
+                self._held_sharded_state_dict_reference = self.model.state_dict()
+        for name, param in self._held_sharded_state_dict_reference.items():
+            yield name, param
 
     @torch.no_grad()
     def prepare_weights_for_ipc(self) -> tuple[list[tuple[str, int]], float]:
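per_tensor_generator is the new producer side of the refit: for a plain module it moves weights to the GPU manually, while under FSDP1 it gathers a FULL_STATE_DICT, then yields one (name, param) pair at a time so the consumer can batch under a byte budget. A usage sketch, where fsdp stands for the fsdp_interface instance that main() constructs:

for name, param in fsdp.per_tensor_generator():
    # FSDP params arrive already gathered to full shape; DTensor shards
    # are widened later in the consumer via param.full_tensor().
    nbytes = param.element_size() * param.numel()
    print(f"{name}: shape={tuple(param.shape)}, {nbytes:,} bytes")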
@@ -182,7 +200,7 @@ def get_weights_ipc_handles(self, keys: list[str]) -> dict[str, Any]:
         self._held_streamed_param_reference = converted_params
 
         # Get device UUID for IPC
-        device_uuid = self.report_device_id()
+        device_uuid = report_device_id()
         # Create handles for the tensors
         all_handles = []
         for key, p in converted_params.items():
@@ -231,6 +249,25 @@ def prepare_weights_for_ipc_refit(
 
         return grouped_param_keys
 
+class NamedParam:
+    def __init__(self, name, size, param):
+        self.name = name
+        self.size = size
+        self.param = param
+
+class GateAndUp:
+    def __init__(self):
+        self.gate = None
+        self.up = None
+    def set_gate(self, gate):
+        self.gate = gate
+    def set_up(self, up):
+        self.up = up
+    def get_size(self):
+        return self.gate.size + self.up.size
+    def is_complete(self):
+        return self.gate is not None and self.up is not None
+
 class trtllm_interface:
     def __init__(self, model_dir, tensor_parallel_size):
         self.world_size = dist.get_world_size()
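NamedParam and GateAndUp buffer the gate_proj/up_proj halves of each MLP so both land in the same IPC batch, presumably because the engine stores them as one fused gate/up weight and must receive the pair together. A toy walk-through with hypothetical Llama-style names (the third NamedParam field would normally be the tensor itself):

gate = NamedParam("model.layers.0.mlp.gate_proj.weight", 4096, None)
up = NamedParam("model.layers.0.mlp.up_proj.weight", 4096, None)

pair = GateAndUp()
pair.set_gate(gate)
assert not pair.is_complete()   # up half not seen yet -> keep buffering

pair.set_up(up)
assert pair.is_complete()       # both halves present -> flush as one unit
assert pair.get_size() == 8192  # the budget check uses the combined size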
@@ -257,13 +294,104 @@ def load_trtllm_model(self, model_dir, tensor_parallel_size):
         else:
             return None
 
+    def update_weights_from_ipc_handles(self, rank, device_handles):
+        if rank == 0:
+            gathered_handles = [None for _ in range(dist.get_world_size())]
+        else:
+            gathered_handles = None
+        dist.gather_object(
+            obj=device_handles,
+            object_gather_list=gathered_handles,
+            dst=0
+        )
+        if rank == 0:
+            all_handles = {k: v for d in gathered_handles for k, v in d.items()}
+            self.llm.update_weights_from_ipc_handles(all_handles)
+
+    def update_weights_from_tensor_generator(self, tensor_generator):
+        device_uuid = report_device_id()
+        rank = dist.get_rank()
+        from torch.multiprocessing.reductions import reduce_tensor
+        total_available_bytes = 0.7 * (1024**3)
+        cur_available_bytes = total_available_bytes
+        converted_params = {}
+        cur_handles = []
+        gate_up = {}
+        for name, param in tensor_generator:
+            size_in_bytes = param.element_size() * param.numel()
+            if isinstance(param, DTensor):
+                param = param.full_tensor()
+            gate_up_name = None
+            gate_up_pair = None
+            if "gate_proj" in name:
+                gate_up_name = name.replace("gate_proj", "")
+                if gate_up_name not in gate_up:
+                    gate_up[gate_up_name] = GateAndUp()
+                assert gate_up[gate_up_name].gate is None
+                gate_up[gate_up_name].set_gate(NamedParam(name, size_in_bytes, param))
+            elif "up_proj" in name:
+                gate_up_name = name.replace("up_proj", "")
+                if gate_up_name not in gate_up:
+                    gate_up[gate_up_name] = GateAndUp()
+                assert gate_up[gate_up_name].up is None
+                gate_up[gate_up_name].set_up(NamedParam(name, size_in_bytes, param))
+            if gate_up_name is not None:
+                if gate_up[gate_up_name].is_complete():
+                    gate_up_pair = gate_up.pop(gate_up_name)
+                    size_in_bytes = gate_up_pair.get_size()
+                else:
+                    continue
+
+            if size_in_bytes > cur_available_bytes:
+                device_handles = {device_uuid: cur_handles}
+                self.update_weights_from_ipc_handles(rank, device_handles)
+                cur_available_bytes = total_available_bytes
+                del converted_params
+                converted_params = {}
+                cur_handles = []
+
+            assert cur_available_bytes >= size_in_bytes
+            cur_available_bytes -= size_in_bytes
+            if gate_up_pair is not None:
+                converted_params[gate_up_pair.gate.name] = gate_up_pair.gate.param
+                converted_params[gate_up_pair.up.name] = gate_up_pair.up.param
+                handle = reduce_tensor(gate_up_pair.gate.param.detach())
+                cur_handles.append((gate_up_pair.gate.name, handle))
+                handle = reduce_tensor(gate_up_pair.up.param.detach())
+                cur_handles.append((gate_up_pair.up.name, handle))
+                gate_up_pair = None
+            else:
+                converted_params[name] = param
+                handle = reduce_tensor(param.detach())
+                cur_handles.append((name, handle))
+
+        assert len(gate_up) == 0
+
+        if cur_handles:
+            device_handles = {device_uuid: cur_handles}
+            self.update_weights_from_ipc_handles(rank, device_handles)
+            cur_available_bytes = total_available_bytes
+            del converted_params
+            converted_params = {}
+            cur_handles = []
+
+def get_total_available_bytes(pg: dist.ProcessGroup, message: str = "") -> float:
+    mem_allocated = torch.cuda.memory_allocated()
+    mem_reserved = torch.cuda.memory_reserved()
+    mem_free, mem_total = torch.cuda.mem_get_info()
+    print(f"{message} mem_free: {mem_free:,}, mem_total: {mem_total:,}, mem_allocated: {mem_allocated:,}, mem_reserved: {mem_reserved:,}")
+    mem_free = torch.tensor(mem_free)
+    dist.all_reduce(mem_free, op=dist.ReduceOp.MIN, group=pg)
+    mem_free = mem_free.item()
+    print(f"{message} gathered_mem_free: {mem_free:,}")
+    return mem_free * 0.2
+
 def cleanup():
     """Cleanup function to destroy process group"""
     if dist.is_initialized():
         print(f"Cleaning up process group on rank {dist.get_rank()}")
         dist.destroy_process_group()
 
-
 def main():
     parser = argparse.ArgumentParser(
         description="LLM models with the PyTorch workflow.")
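Two notes on the machinery above. reduce_tensor is what turns each CUDA tensor into a shippable IPC handle: it returns a (rebuild_fn, args) pair in pickle's reduce format, where args embed a CUDA IPC memory handle plus dtype/shape/stride metadata, and the receiving process calls rebuild_fn(*args) to map the same device allocation without a copy. A minimal sketch of the exporting side (illustrative, not part of the test):

import torch
from torch.multiprocessing.reductions import reduce_tensor

param = torch.randn(4096, 4096, device="cuda")
rebuild_fn, args = reduce_tensor(param.detach())
# args carry the IPC memory handle plus metadata; a peer process
# reconstructs a zero-copy view of this allocation via rebuild_fn(*args).
print(rebuild_fn.__name__, len(args))

Second, get_total_available_bytes (not yet called anywhere in this file) MIN-reduces free memory across ranks because every rank must stage the same batches, so the smallest free-memory figure bounds the shared budget; update_weights_from_tensor_generator currently uses a fixed 0.7 GiB budget instead.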
@@ -306,7 +434,6 @@ def main():
     # For FSDP mode, we would need additional logic to integrate with TensorRT-LLM
     # This is a placeholder for now
     if rank == 0:
-
         outputs = trtllm.llm.generate(prompts, sampling_params)
         for i, output in enumerate(outputs):
             prompt = output.prompt
@@ -321,33 +448,9 @@
     result = trtllm.llm.wakeup()
     print(f"wakeup result: {result}")
 
-    dict_info, total_available_bytes = fsdp.prepare_weights_for_ipc()
-
-    grouped_param_keys = fsdp.prepare_weights_for_ipc_refit(0.5)
-    total_num_keys = sum(len(k) for k in grouped_param_keys)
-    print(
-        f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
-    )
-
-    from tensorrt_llm._torch.utils import get_free_memory_bytes
-    for keys in grouped_param_keys:
-        handles = fsdp.get_weights_ipc_handles(keys)
-        #print(f"handles: {handles}")
-
-        # Collect handles from all ranks
-        all_handles = [None for _ in range(world_size)]
-        dist.all_gather_object(all_handles, handles)
-        all_handles = {k: v for d in all_handles for k, v in d.items()}
-        #print(f"all_handles: {all_handles.keys()}")
-
-        device_idx = torch.cuda.current_device()
-        total_available_bytes = get_free_memory_bytes(device_idx)
-        print(f"total_available_bytes: {total_available_bytes}")
-
-    if rank == 0:
-        result = trtllm.llm.update_weights_from_ipc_handles(all_handles)
-        print(f"update weights result: {result}")
+    trtllm.update_weights_from_tensor_generator(fsdp.per_tensor_generator())
 
+    # generate the output again
     if rank == 0:
         outputs = trtllm.llm.generate(prompts, sampling_params)
         for i, output in enumerate(outputs):
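Net effect of the commit: the manual prepare/group/all-gather refit loop is gone, and the refit path collapses into one call pairing the two interfaces. Schematically (argument names are illustrative; main() parses the real ones):

fsdp = fsdp_interface(model_dir)                              # training-side weights
trtllm = trtllm_interface(model_dir, tensor_parallel_size)    # inference engine

# Stream (name, param) pairs, bundle gate/up halves, batch under the byte
# budget, gather IPC handles to rank 0, and apply each batch to the engine.
trtllm.update_weights_from_tensor_generator(fsdp.per_tensor_generator())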
