Commit d22f2a6

Fix meta tensor loading

Signed-off-by: Jiayu Chang <jiayuc@nvidia.com>
1 parent 1d61b14 commit d22f2a6

File tree

4 files changed: +126, -85 lines

tensorrt_llm/_torch/peft/lora/cuda_graph_lora_params.py

Lines changed: 22 additions & 1 deletion
@@ -26,6 +26,11 @@ class LoraLayerParams:
     h_b_ptrs: torch.Tensor  # Lora_in weight pointers in host
     h_b_prime_ptrs: torch.Tensor  # Lora_out weight pointers in host
 
+    d_output_sizes: torch.Tensor
+    d_output_sizes_offset: torch.Tensor
+    h_output_sizes: torch.Tensor
+    h_output_sizes_offset: torch.Tensor
+
 
 class CudaGraphLoraParams:
     """
@@ -37,6 +42,10 @@ class CudaGraphLoraParams:
 
     LoraLayerKey = namedtuple('LoraLayerKey', ['layer_idx', 'module_ids'])
 
+    PTR_DTYPE = torch.int64
+    LD_DTYPE = torch.int64
+    SIZES_DTYPE = torch.int32
+
     @dataclass
     class LoraLayerInfo:
         module_num: int = 0
@@ -158,6 +167,14 @@ def _create_layer_params(
         # Base model requests are handled separately and don't participate in GEMM operations
         shape_2d = (layer_module_num, self.max_lora_size)
 
+        output_hidden_sizes = torch.tensor(module_output_sizes,
+                                           dtype=self.SIZES_DTYPE)
+        output_hidden_sizes_device = output_hidden_sizes.to(device='cuda')
+
+        output_sizes_offset = self.get_offset_from_counts(
+            output_hidden_sizes).to(dtype=self.PTR_DTYPE)  # [num_layer_modules]
+        output_sizes_offset_device = output_sizes_offset.to(device='cuda')
+
         return LoraLayerParams(
             # Weight pointers - managed by PEFT cache manager
             d_b_ptrs=torch.zeros(shape_2d,
@@ -169,7 +186,11 @@ def _create_layer_params(
             h_b_ptrs=torch.zeros(shape_2d, dtype=torch.int64, pin_memory=True),
             h_b_prime_ptrs=torch.zeros(shape_2d,
                                        dtype=torch.int64,
-                                       pin_memory=True))
+                                       pin_memory=True),
+            d_output_sizes=output_hidden_sizes_device,
+            d_output_sizes_offset=output_sizes_offset_device,
+            h_output_sizes=output_hidden_sizes,
+            h_output_sizes_offset=output_sizes_offset)
 
     @staticmethod
     def get_sorted_indices(slot_ids: List[int]) -> torch.Tensor:
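
Note: a minimal standalone sketch of what the new offset fields appear to encode, assuming get_offset_from_counts is an exclusive prefix sum over the per-module output sizes (the layout the d_prime_offset arithmetic in layer.py relies on). The module sizes below are illustrative only, not taken from the commit.

    import torch

    # Assumed semantics of CudaGraphLoraParams.get_offset_from_counts:
    # exclusive prefix sum, i.e. offsets[i] = sum(counts[:i]).
    def get_offset_from_counts(counts: torch.Tensor) -> torch.Tensor:
        offsets = torch.zeros_like(counts)
        offsets[1:] = torch.cumsum(counts, dim=0)[:-1]
        return offsets

    # Hypothetical layer with q/k/v LoRA modules of output sizes 4096, 1024, 1024.
    h_output_sizes = torch.tensor([4096, 1024, 1024], dtype=torch.int32)
    h_output_sizes_offset = get_offset_from_counts(h_output_sizes).to(torch.int64)
    print(h_output_sizes_offset.tolist())  # [0, 4096, 5120] -> start column of each module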

tensorrt_llm/_torch/peft/lora/layer.py

Lines changed: 54 additions & 52 deletions
@@ -80,6 +80,8 @@ class GroupedGemmParamsInput:
     b_ptrs: torch.Tensor
     b_prime_ptrs: torch.Tensor
     sorted_ids: torch.Tensor
+    output_hidden_sizes: torch.Tensor
+    output_sizes_offset: torch.Tensor
 
     @property
     def slot_offsets(self):
@@ -242,27 +244,14 @@ def is_moe(self) -> bool:
 
 
 class LoraLayer(torch.nn.Module):
-    PTR_DTYPE = torch.int64
-    LD_DTYPE = torch.int64
-    SIZES_DTYPE = torch.int32
 
     def __init__(self, lora_module_types: List[LoraModuleType],
                  output_hidden_sizes: List[int]):
         super().__init__()
 
         self.lora_module_types = lora_module_types
-        self.output_hidden_sizes = torch.tensor(output_hidden_sizes,
-                                                dtype=self.SIZES_DTYPE)
-        self.output_hidden_sizes_list = output_hidden_sizes
+        self.output_hidden_sizes = output_hidden_sizes
         assert len(lora_module_types) == len(output_hidden_sizes)
-        self.output_sizes_offset = CudaGraphLoraParams.get_offset_from_counts(
-            self.output_hidden_sizes).to(
-                dtype=self.PTR_DTYPE)  # [num_layer_modules]
-        if PARAM_PREP:
-            self.output_sizes_offset_device = self.output_sizes_offset.to(
-                device='cuda')
-            self.output_hidden_size_device = self.output_hidden_sizes.to(
-                device='cuda')
 
     def forward(
         self,
@@ -307,7 +296,7 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
         # a [bs, hidden]
         lda = torch.full(shape_2d,
                          input_hidden_size,
-                         dtype=self.LD_DTYPE,
+                         dtype=CudaGraphLoraParams.LD_DTYPE,
                          device=device)
 
         # b [input_hidden_size, lora_rank]
@@ -316,17 +305,17 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
         # a_prime / d [num_layer_modules, bs, max_rank]
         ldd = torch.full(shape_2d,
                          input.max_rank,
-                         dtype=self.LD_DTYPE,
+                         dtype=CudaGraphLoraParams.LD_DTYPE,
                          device=device)
 
         # b_prime [lora_rank, module_output_size]
         ldb_prime = input.slot_ranks.unsqueeze(0).to(
-            dtype=self.LD_DTYPE).repeat(shape_2d[0], 1)
+            dtype=CudaGraphLoraParams.LD_DTYPE).repeat(shape_2d[0], 1)
 
         # d_prime [bs, sum_of_each_module_output_sizes]
         ldd_prime = torch.full(shape_2d,
                                sum_out_sizes,
-                               dtype=self.LD_DTYPE,
+                               dtype=CudaGraphLoraParams.LD_DTYPE,
                                device=device)
 
         # reordered a [bs, hidden], each module has the same offset
@@ -335,13 +324,13 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
 
         # d [num_layer_modules, bs, max_rank]
         d_offset = (input.slot_offsets.unsqueeze(0) + torch.arange(
-            shape_2d[0], device=device, dtype=self.PTR_DTYPE).unsqueeze(1) *
-                    bs) * input.max_rank
+            shape_2d[0], device=device, dtype=CudaGraphLoraParams.PTR_DTYPE).
+                    unsqueeze(1) * bs) * input.max_rank
 
         # d' [bs, sum_of_each_module_output_sizes]
         bs_offset = input.slot_offsets.unsqueeze(0)  # [1, max_lora_size]
         bs_offset = bs_offset * sum_out_sizes
-        out_offset = self.output_sizes_offset_device.unsqueeze(
+        out_offset = input.output_sizes_offset.unsqueeze(
             1)  # [num_layer_modules, 1]
         d_prime_offset = bs_offset + out_offset
         '''
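
Note: a toy-number sketch of the d_prime_offset computation above, assuming these are element offsets into the flat d' buffer of shape [bs, sum_of_each_module_output_sizes] (conversion to device addresses is assumed to happen in the fill op, outside this hunk). All values are illustrative.

    import torch

    sum_out_sizes = 5120                           # e.g. 4096 + 1024
    slot_offsets = torch.tensor([0, 3, 7])         # first row of each LoRA slot
    output_sizes_offset = torch.tensor([0, 4096])  # per-module start column

    bs_offset = slot_offsets.unsqueeze(0) * sum_out_sizes  # [1, max_lora_size]
    out_offset = output_sizes_offset.unsqueeze(1)          # [num_layer_modules, 1]
    d_prime_offset = bs_offset + out_offset                # [num_layer_modules, max_lora_size]
    print(d_prime_offset)
    # tensor([[    0, 15360, 35840],
    #         [ 4096, 19456, 39936]])
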
@@ -350,12 +339,14 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
         '''
 
         # sizes
-        in_sizes = torch.empty(shape_3d, dtype=self.SIZES_DTYPE, device=device)
+        in_sizes = torch.empty(shape_3d,
+                               dtype=CudaGraphLoraParams.SIZES_DTYPE,
+                               device=device)
         out_sizes = torch.empty_like(in_sizes)
 
         slot_counts = input.slot_counts.unsqueeze(0)  # [1, max_lora_size]
         ranks = input.slot_ranks.unsqueeze(0)  # [1, max_lora_size]
-        output_hidden_sizes = self.output_hidden_size_device.unsqueeze(
+        output_hidden_sizes = input.output_hidden_sizes.unsqueeze(
             1)  # [num_layer_modules, 1]
 
         in_sizes[:, :, 0] = slot_counts
@@ -373,7 +364,7 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
         # splitk_offsets: [num_layer_modules, max_lora_size]
         # splitk offtsets (m * n) for the first grouped gemm with (m, n, k) = (slot_counts, slot_ranks, input_hidden_size)
         splitk_offsets = torch.zeros(shape_2d,
-                                     dtype=self.LD_DTYPE,
+                                     dtype=CudaGraphLoraParams.LD_DTYPE,
                                      device=device)  # (layer_problem_count,)
 
         splitk_offsets.view(-1)[1:] = in_sizes.view(-1, 3)[:-1, 0]  # = M
@@ -413,18 +404,24 @@ def _prepare_grouped_gemm_buffers_fused(self,
         shape_3d = shape_2d + (3, )
         sum_out_sizes = sum(self.output_hidden_sizes)
 
-        in_sizes = torch.empty(shape_3d, dtype=self.SIZES_DTYPE, device=device)
+        in_sizes = torch.empty(shape_3d,
+                               dtype=CudaGraphLoraParams.SIZES_DTYPE,
+                               device=device)
         out_sizes = torch.empty_like(in_sizes)
-        a_offset = torch.empty(shape_2d, dtype=self.PTR_DTYPE, device=device)
+        a_offset = torch.empty(shape_2d,
+                               dtype=CudaGraphLoraParams.PTR_DTYPE,
+                               device=device)
         d_offset = torch.empty_like(a_offset)
         d_prime_offset = torch.empty_like(a_offset)
-        lda = torch.empty(shape_2d, dtype=self.LD_DTYPE, device=device)
+        lda = torch.empty(shape_2d,
+                          dtype=CudaGraphLoraParams.LD_DTYPE,
+                          device=device)
         ldb = lda
         ldd = torch.empty_like(lda)
         ldb_prime = torch.empty_like(lda)
         ldd_prime = torch.empty_like(lda)
         splitk_offsets = torch.empty(shape_2d,
-                                     dtype=self.LD_DTYPE,
+                                     dtype=CudaGraphLoraParams.LD_DTYPE,
                                      device=device)  # (layer_problem_count,)
         reordered_input = torch.empty_like(input.x)
         torch.ops.trtllm.lora_group_gemm_param_fill_row_reorder_fusion(
@@ -450,8 +447,8 @@
             input.slot_counts,
             input.slot_ranks,
             input.slot_offsets,
-            self.output_hidden_size_device,
-            self.output_sizes_offset_device,
+            input.output_hidden_sizes,
+            input.output_sizes_offset,
             input.b_ptrs,
             input.b_prime_ptrs,
             input.x,
@@ -475,14 +472,16 @@
 
     def _prepare_max_sizes_cpu(self,
                                cuda_graph_lora_params: CudaGraphLoraParams,
+                               layer_key: CudaGraphLoraParams.LoraLayerKey,
                                bs: int, input_hidden_size: int):
+        layer_params = cuda_graph_lora_params.get_layer_params(layer_key)
         shape_2d = (len(self.lora_module_types),
                     cuda_graph_lora_params.max_lora_size
                     )  # [num_layer_modules, max_lora_size]
         shape_3d = shape_2d + (3, )
         # dummy max sizes, on CPU
         host_max_in_sizes = torch.empty(
-            shape_3d, dtype=self.SIZES_DTYPE
+            shape_3d, dtype=CudaGraphLoraParams.SIZES_DTYPE
         )  # m: batch_size, n: max_lora_rank, k: input_hidden_size
         host_max_out_sizes = torch.empty_like(
             host_max_in_sizes
@@ -492,7 +491,7 @@ def _prepare_max_sizes_cpu(self,
         host_max_in_sizes[:, :, 2] = input_hidden_size
 
         host_max_out_sizes[:, :, 0] = bs
-        host_max_out_sizes[:, :, 1] = self.output_hidden_sizes.unsqueeze(1)
+        host_max_out_sizes[:, :, 1] = layer_params.h_output_sizes.unsqueeze(1)
         host_max_out_sizes[:, :, 2] = cuda_graph_lora_params.max_rank
 
         return host_max_in_sizes, host_max_out_sizes
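
Note: a small standalone sketch of the dummy max-size buffers that _prepare_max_sizes_cpu builds, with the per-module output sizes now read from the layer's h_output_sizes rather than a tensor stored on LoraLayer. Shapes and numbers are illustrative; the [:, :, 0] and [:, :, 1] fills follow the "m: batch_size, n: max_lora_rank" comment in the hunk above and are assumptions beyond the lines shown.

    import torch

    num_layer_modules, max_lora_size = 2, 3
    bs, input_hidden_size, max_rank = 8, 2048, 64
    h_output_sizes = torch.tensor([4096, 1024], dtype=torch.int32)

    shape_3d = (num_layer_modules, max_lora_size, 3)
    host_max_in_sizes = torch.empty(shape_3d, dtype=torch.int32)
    host_max_in_sizes[:, :, 0] = bs                  # m
    host_max_in_sizes[:, :, 1] = max_rank            # n
    host_max_in_sizes[:, :, 2] = input_hidden_size   # k

    host_max_out_sizes = torch.empty_like(host_max_in_sizes)
    host_max_out_sizes[:, :, 0] = bs
    host_max_out_sizes[:, :, 1] = h_output_sizes.unsqueeze(1)  # per-module n
    host_max_out_sizes[:, :, 2] = max_rank
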
@@ -546,7 +545,7 @@ def _forward_cuda_graph_mode(
                                      device=x.device)
 
         host_max_in_sizes, host_max_out_sizes = self._prepare_max_sizes_cpu(
-            cuda_graph_params, batch_size, hidden_size)
+            cuda_graph_params, layer_key, batch_size, hidden_size)
 
         if RETURN_0_DIRECTLY:
             return output_buffer
@@ -569,7 +568,9 @@ def _forward_cuda_graph_mode(
             slot_offsets_full=cuda_graph_params.slot_offsets_full,
             b_ptrs=layer_params.d_b_ptrs,
             b_prime_ptrs=layer_params.d_b_prime_ptrs,
-            sorted_ids=cuda_graph_params.sorted_ids)
+            sorted_ids=cuda_graph_params.sorted_ids,
+            output_hidden_sizes=layer_params.d_output_sizes,
+            output_sizes_offset=layer_params.d_output_sizes_offset)
         grouped_gemm_params = self._prepare_grouped_gemm_buffers_fused(
             params_fill_input)
 
@@ -692,17 +693,18 @@ def _forward_cuda_graph_mode(
         if PRINT_AND_ASSERT:
             assert output_buffer.is_contiguous()
             out_splitted = [
-                output_buffer[:, s:s + le] for s, le in zip(
-                    self.output_sizes_offset, self.output_hidden_sizes)
+                output_buffer[:, s:s + le]
+                for s, le in zip(layer_params.h_output_sizes_offset,
+                                 layer_params.h_output_sizes)
             ]
             # assert not any(out.is_contiguous() for out in out_splitted)
             pyt_strides = torch.tensor([out.stride(0) for out in out_splitted],
-                                       dtype=self.LD_DTYPE,
+                                       dtype=CudaGraphLoraParams.LD_DTYPE,
                                        device=x.device)  # nModules,
             assert torch.all(
                 grouped_gemm_params.ldd_prime == pyt_strides.unsqueeze(1))
             pyt_addr = torch.tensor([out.data_ptr() for out in out_splitted],
-                                    dtype=self.PTR_DTYPE,
+                                    dtype=CudaGraphLoraParams.PTR_DTYPE,
                                     device=x.device)
             assert torch.all(pyt_addr == grouped_gemm_params.d_prime_offset[:,
                                                                             0])
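
Note: a self-contained sketch of the debug split in the PRINT_AND_ASSERT block above, using toy host-side sizes and offsets; dtypes and shapes are illustrative only.

    import torch

    h_output_sizes = torch.tensor([4, 2], dtype=torch.int32)
    h_output_sizes_offset = torch.tensor([0, 4], dtype=torch.int64)
    output_buffer = torch.arange(3 * 6, dtype=torch.float16).reshape(3, 6)  # [bs, sum_out_sizes]

    out_splitted = [
        output_buffer[:, s:s + le]
        for s, le in zip(h_output_sizes_offset.tolist(), h_output_sizes.tolist())
    ]
    print([tuple(t.shape) for t in out_splitted])  # [(3, 4), (3, 2)]
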
@@ -908,59 +910,59 @@ def tall(x: torch.Tensor):
         # problem_sizes1 = torch.tensor([[24, 32, 16], [24, 32, 16]], dtype=self.SIZES_DTYPE, device=x.device)
         problem_sizes1 = torch.tensor(
             [[m00, n00, k00], [0, 0, 0], [0, 0, 0], [m02, n02, k02]],
-            dtype=self.SIZES_DTYPE,
+            dtype=CudaGraphLoraParams.SIZES_DTYPE,
             device=x.device)
         lda = torch.tensor([k00, 17, 16, k02],
-                           dtype=self.LD_DTYPE,
+                           dtype=CudaGraphLoraParams.LD_DTYPE,
                            device=x.device) + ld_offset
         ldb = torch.tensor([k00, 16, 16, k02],
-                           dtype=self.LD_DTYPE,
+                           dtype=CudaGraphLoraParams.LD_DTYPE,
                            device=x.device) + ld_offset
         ldd = torch.tensor([n00, 32, 32, n02],
-                           dtype=self.LD_DTYPE,
+                           dtype=CudaGraphLoraParams.LD_DTYPE,
                            device=x.device) + ld_offset
 
         problem_sizes2 = torch.tensor(
             [[m10, n10, k10], [0, 0, 0], [0, 0, 0], [m12, n12, k12]],
-            dtype=self.SIZES_DTYPE,
+            dtype=CudaGraphLoraParams.SIZES_DTYPE,
             device=x.device)
         ldb1 = torch.tensor([k10, 32, 32, k12],
-                            dtype=self.LD_DTYPE,
+                            dtype=CudaGraphLoraParams.LD_DTYPE,
                             device=x.device) + ld_offset
         ldd1 = torch.tensor([n10, 48, 48, n12],
-                            dtype=self.LD_DTYPE,
+                            dtype=CudaGraphLoraParams.LD_DTYPE,
                             device=x.device) + ld_offset
 
         a0_ptr = torch.tensor(
             [a0.data_ptr(),
              a0.data_ptr(),
              a0.data_ptr(),
              a02.data_ptr()],
-            dtype=self.PTR_DTYPE,
+            dtype=CudaGraphLoraParams.PTR_DTYPE,
            device=x.device)
         b0_ptr = torch.tensor(
             [b0.data_ptr(), 0, 0, b02.data_ptr()],
-            dtype=self.PTR_DTYPE,
+            dtype=CudaGraphLoraParams.PTR_DTYPE,
             device=x.device)
         d0_ptr = torch.tensor([
             d00.data_ptr(),
             d01.data_ptr(),
             d01.data_ptr(),
             d02.data_ptr()
         ],
-                              dtype=self.PTR_DTYPE,
+                              dtype=CudaGraphLoraParams.PTR_DTYPE,
                               device=x.device)
         b1_ptr = torch.tensor(
             [b1.data_ptr(), 0, 0, b12.data_ptr()],
-            dtype=self.PTR_DTYPE,
+            dtype=CudaGraphLoraParams.PTR_DTYPE,
             device=x.device)
         d1_ptr = torch.tensor([
             d10.data_ptr(),
             d11.data_ptr(),
             d11.data_ptr(),
             d12.data_ptr()
         ],
-                              dtype=self.PTR_DTYPE,
+                              dtype=CudaGraphLoraParams.PTR_DTYPE,
                               device=x.device)
 
         torch.ops.trtllm.lora_grouped_gemm_cuda_graph(
@@ -1081,7 +1083,7 @@ def _forward_legacy_mode(
             lora_ranks,
             lora_weight_pointers,
             lora_params['prompt_lens_cpu'][:num_seqs],
-            self.output_hidden_sizes_list,
+            self.output_hidden_sizes,
             False,  # transA
             True,  # transB
             max([r.max() for r in lora_ranks]),
@@ -1101,7 +1103,7 @@ def _forward_legacy_mode(
             else:
                 lora_output.append(
                     torch.zeros(list(x.shape[:-1]) + [
-                        self.output_hidden_sizes_list[
+                        self.output_hidden_sizes[
                             self.lora_module_types.index(module_idx)]
                     ],
                                 dtype=x.dtype,
