@@ -80,6 +80,8 @@ class GroupedGemmParamsInput:
8080 b_ptrs : torch .Tensor
8181 b_prime_ptrs : torch .Tensor
8282 sorted_ids : torch .Tensor
83+ output_hidden_sizes : torch .Tensor
84+ output_sizes_offset : torch .Tensor
8385
8486 @property
8587 def slot_offsets (self ):
@@ -242,27 +244,14 @@ def is_moe(self) -> bool:
242244
243245
244246class LoraLayer (torch .nn .Module ):
245- PTR_DTYPE = torch .int64
246- LD_DTYPE = torch .int64
247- SIZES_DTYPE = torch .int32
248247
249248 def __init__ (self , lora_module_types : List [LoraModuleType ],
250249 output_hidden_sizes : List [int ]):
251250 super ().__init__ ()
252251
253252 self .lora_module_types = lora_module_types
254- self .output_hidden_sizes = torch .tensor (output_hidden_sizes ,
255- dtype = self .SIZES_DTYPE )
256- self .output_hidden_sizes_list = output_hidden_sizes
253+ self .output_hidden_sizes = output_hidden_sizes
257254 assert len (lora_module_types ) == len (output_hidden_sizes )
258- self .output_sizes_offset = CudaGraphLoraParams .get_offset_from_counts (
259- self .output_hidden_sizes ).to (
260- dtype = self .PTR_DTYPE ) # [num_layer_modules]
261- if PARAM_PREP :
262- self .output_sizes_offset_device = self .output_sizes_offset .to (
263- device = 'cuda' )
264- self .output_hidden_size_device = self .output_hidden_sizes .to (
265- device = 'cuda' )
266255
267256 def forward (
268257 self ,
@@ -307,7 +296,7 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
307296 # a [bs, hidden]
308297 lda = torch .full (shape_2d ,
309298 input_hidden_size ,
310- dtype = self .LD_DTYPE ,
299+ dtype = CudaGraphLoraParams .LD_DTYPE ,
311300 device = device )
312301
313302 # b [input_hidden_size, lora_rank]
@@ -316,17 +305,17 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
316305 # a_prime / d [num_layer_modules, bs, max_rank]
317306 ldd = torch .full (shape_2d ,
318307 input .max_rank ,
319- dtype = self .LD_DTYPE ,
308+ dtype = CudaGraphLoraParams .LD_DTYPE ,
320309 device = device )
321310
322311 # b_prime [lora_rank, module_output_size]
323312 ldb_prime = input .slot_ranks .unsqueeze (0 ).to (
324- dtype = self .LD_DTYPE ).repeat (shape_2d [0 ], 1 )
313+ dtype = CudaGraphLoraParams .LD_DTYPE ).repeat (shape_2d [0 ], 1 )
325314
326315 # d_prime [bs, sum_of_each_module_output_sizes]
327316 ldd_prime = torch .full (shape_2d ,
328317 sum_out_sizes ,
329- dtype = self .LD_DTYPE ,
318+ dtype = CudaGraphLoraParams .LD_DTYPE ,
330319 device = device )
331320
332321 # reordered a [bs, hidden], each module has the same offset
@@ -335,13 +324,13 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
335324
336325 # d [num_layer_modules, bs, max_rank]
337326 d_offset = (input .slot_offsets .unsqueeze (0 ) + torch .arange (
338- shape_2d [0 ], device = device , dtype = self .PTR_DTYPE ).unsqueeze ( 1 ) *
339- bs ) * input .max_rank
327+ shape_2d [0 ], device = device , dtype = CudaGraphLoraParams .PTR_DTYPE ).
328+ unsqueeze ( 1 ) * bs ) * input .max_rank
340329
341330 # d' [bs, sum_of_each_module_output_sizes]
342331 bs_offset = input .slot_offsets .unsqueeze (0 ) # [1, max_lora_size]
343332 bs_offset = bs_offset * sum_out_sizes
344- out_offset = self . output_sizes_offset_device .unsqueeze (
333+ out_offset = input . output_sizes_offset .unsqueeze (
345334 1 ) # [num_layer_modules, 1]
346335 d_prime_offset = bs_offset + out_offset
347336 '''
@@ -350,12 +339,14 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
350339 '''
351340
352341 # sizes
353- in_sizes = torch .empty (shape_3d , dtype = self .SIZES_DTYPE , device = device )
342+ in_sizes = torch .empty (shape_3d ,
343+ dtype = CudaGraphLoraParams .SIZES_DTYPE ,
344+ device = device )
354345 out_sizes = torch .empty_like (in_sizes )
355346
356347 slot_counts = input .slot_counts .unsqueeze (0 ) # [1, max_lora_size]
357348 ranks = input .slot_ranks .unsqueeze (0 ) # [1, max_lora_size]
358- output_hidden_sizes = self . output_hidden_size_device .unsqueeze (
349+ output_hidden_sizes = input . output_hidden_sizes .unsqueeze (
359350 1 ) # [num_layer_modules, 1]
360351
361352 in_sizes [:, :, 0 ] = slot_counts
@@ -373,7 +364,7 @@ def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
373364 # splitk_offsets: [num_layer_modules, max_lora_size]
374365 # splitk offtsets (m * n) for the first grouped gemm with (m, n, k) = (slot_counts, slot_ranks, input_hidden_size)
375366 splitk_offsets = torch .zeros (shape_2d ,
376- dtype = self .LD_DTYPE ,
367+ dtype = CudaGraphLoraParams .LD_DTYPE ,
377368 device = device ) # (layer_problem_count,)
378369
379370 splitk_offsets .view (- 1 )[1 :] = in_sizes .view (- 1 , 3 )[:- 1 , 0 ] # = M
@@ -413,18 +404,24 @@ def _prepare_grouped_gemm_buffers_fused(self,
413404 shape_3d = shape_2d + (3 , )
414405 sum_out_sizes = sum (self .output_hidden_sizes )
415406
416- in_sizes = torch .empty (shape_3d , dtype = self .SIZES_DTYPE , device = device )
407+ in_sizes = torch .empty (shape_3d ,
408+ dtype = CudaGraphLoraParams .SIZES_DTYPE ,
409+ device = device )
417410 out_sizes = torch .empty_like (in_sizes )
418- a_offset = torch .empty (shape_2d , dtype = self .PTR_DTYPE , device = device )
411+ a_offset = torch .empty (shape_2d ,
412+ dtype = CudaGraphLoraParams .PTR_DTYPE ,
413+ device = device )
419414 d_offset = torch .empty_like (a_offset )
420415 d_prime_offset = torch .empty_like (a_offset )
421- lda = torch .empty (shape_2d , dtype = self .LD_DTYPE , device = device )
416+ lda = torch .empty (shape_2d ,
417+ dtype = CudaGraphLoraParams .LD_DTYPE ,
418+ device = device )
422419 ldb = lda
423420 ldd = torch .empty_like (lda )
424421 ldb_prime = torch .empty_like (lda )
425422 ldd_prime = torch .empty_like (lda )
426423 splitk_offsets = torch .empty (shape_2d ,
427- dtype = self .LD_DTYPE ,
424+ dtype = CudaGraphLoraParams .LD_DTYPE ,
428425 device = device ) # (layer_problem_count,)
429426 reordered_input = torch .empty_like (input .x )
430427 torch .ops .trtllm .lora_group_gemm_param_fill_row_reorder_fusion (
@@ -450,8 +447,8 @@ def _prepare_grouped_gemm_buffers_fused(self,
450447 input .slot_counts ,
451448 input .slot_ranks ,
452449 input .slot_offsets ,
453- self . output_hidden_size_device ,
454- self . output_sizes_offset_device ,
450+ input . output_hidden_sizes ,
451+ input . output_sizes_offset ,
455452 input .b_ptrs ,
456453 input .b_prime_ptrs ,
457454 input .x ,
@@ -475,14 +472,16 @@ def _prepare_grouped_gemm_buffers_fused(self,
475472
476473 def _prepare_max_sizes_cpu (self ,
477474 cuda_graph_lora_params : CudaGraphLoraParams ,
475+ layer_key : CudaGraphLoraParams .LoraLayerKey ,
478476 bs : int , input_hidden_size : int ):
477+ layer_params = cuda_graph_lora_params .get_layer_params (layer_key )
479478 shape_2d = (len (self .lora_module_types ),
480479 cuda_graph_lora_params .max_lora_size
481480 ) # [num_layer_modules, max_lora_size]
482481 shape_3d = shape_2d + (3 , )
483482 # dummy max sizes, on CPU
484483 host_max_in_sizes = torch .empty (
485- shape_3d , dtype = self .SIZES_DTYPE
484+ shape_3d , dtype = CudaGraphLoraParams .SIZES_DTYPE
486485 ) # m: batch_size, n: max_lora_rank, k: input_hidden_size
487486 host_max_out_sizes = torch .empty_like (
488487 host_max_in_sizes
@@ -492,7 +491,7 @@ def _prepare_max_sizes_cpu(self,
492491 host_max_in_sizes [:, :, 2 ] = input_hidden_size
493492
494493 host_max_out_sizes [:, :, 0 ] = bs
495- host_max_out_sizes [:, :, 1 ] = self . output_hidden_sizes .unsqueeze (1 )
494+ host_max_out_sizes [:, :, 1 ] = layer_params . h_output_sizes .unsqueeze (1 )
496495 host_max_out_sizes [:, :, 2 ] = cuda_graph_lora_params .max_rank
497496
498497 return host_max_in_sizes , host_max_out_sizes
@@ -546,7 +545,7 @@ def _forward_cuda_graph_mode(
546545 device = x .device )
547546
548547 host_max_in_sizes , host_max_out_sizes = self ._prepare_max_sizes_cpu (
549- cuda_graph_params , batch_size , hidden_size )
548+ cuda_graph_params , layer_key , batch_size , hidden_size )
550549
551550 if RETURN_0_DIRECTLY :
552551 return output_buffer
@@ -569,7 +568,9 @@ def _forward_cuda_graph_mode(
569568 slot_offsets_full = cuda_graph_params .slot_offsets_full ,
570569 b_ptrs = layer_params .d_b_ptrs ,
571570 b_prime_ptrs = layer_params .d_b_prime_ptrs ,
572- sorted_ids = cuda_graph_params .sorted_ids )
571+ sorted_ids = cuda_graph_params .sorted_ids ,
572+ output_hidden_sizes = layer_params .d_output_sizes ,
573+ output_sizes_offset = layer_params .d_output_sizes_offset )
573574 grouped_gemm_params = self ._prepare_grouped_gemm_buffers_fused (
574575 params_fill_input )
575576
@@ -692,17 +693,18 @@ def _forward_cuda_graph_mode(
692693 if PRINT_AND_ASSERT :
693694 assert output_buffer .is_contiguous ()
694695 out_splitted = [
695- output_buffer [:, s :s + le ] for s , le in zip (
696- self .output_sizes_offset , self .output_hidden_sizes )
696+ output_buffer [:, s :s + le ]
697+ for s , le in zip (layer_params .h_output_sizes_offset ,
698+ layer_params .h_output_sizes )
697699 ]
698700 # assert not any(out.is_contiguous() for out in out_splitted)
699701 pyt_strides = torch .tensor ([out .stride (0 ) for out in out_splitted ],
700- dtype = self .LD_DTYPE ,
702+ dtype = CudaGraphLoraParams .LD_DTYPE ,
701703 device = x .device ) # nModules,
702704 assert torch .all (
703705 grouped_gemm_params .ldd_prime == pyt_strides .unsqueeze (1 ))
704706 pyt_addr = torch .tensor ([out .data_ptr () for out in out_splitted ],
705- dtype = self .PTR_DTYPE ,
707+ dtype = CudaGraphLoraParams .PTR_DTYPE ,
706708 device = x .device )
707709 assert torch .all (pyt_addr == grouped_gemm_params .d_prime_offset [:,
708710 0 ])
@@ -908,59 +910,59 @@ def tall(x: torch.Tensor):
908910 # problem_sizes1 = torch.tensor([[24, 32, 16], [24, 32, 16]], dtype=self.SIZES_DTYPE, device=x.device)
909911 problem_sizes1 = torch .tensor (
910912 [[m00 , n00 , k00 ], [0 , 0 , 0 ], [0 , 0 , 0 ], [m02 , n02 , k02 ]],
911- dtype = self .SIZES_DTYPE ,
913+ dtype = CudaGraphLoraParams .SIZES_DTYPE ,
912914 device = x .device )
913915 lda = torch .tensor ([k00 , 17 , 16 , k02 ],
914- dtype = self .LD_DTYPE ,
916+ dtype = CudaGraphLoraParams .LD_DTYPE ,
915917 device = x .device ) + ld_offset
916918 ldb = torch .tensor ([k00 , 16 , 16 , k02 ],
917- dtype = self .LD_DTYPE ,
919+ dtype = CudaGraphLoraParams .LD_DTYPE ,
918920 device = x .device ) + ld_offset
919921 ldd = torch .tensor ([n00 , 32 , 32 , n02 ],
920- dtype = self .LD_DTYPE ,
922+ dtype = CudaGraphLoraParams .LD_DTYPE ,
921923 device = x .device ) + ld_offset
922924
923925 problem_sizes2 = torch .tensor (
924926 [[m10 , n10 , k10 ], [0 , 0 , 0 ], [0 , 0 , 0 ], [m12 , n12 , k12 ]],
925- dtype = self .SIZES_DTYPE ,
927+ dtype = CudaGraphLoraParams .SIZES_DTYPE ,
926928 device = x .device )
927929 ldb1 = torch .tensor ([k10 , 32 , 32 , k12 ],
928- dtype = self .LD_DTYPE ,
930+ dtype = CudaGraphLoraParams .LD_DTYPE ,
929931 device = x .device ) + ld_offset
930932 ldd1 = torch .tensor ([n10 , 48 , 48 , n12 ],
931- dtype = self .LD_DTYPE ,
933+ dtype = CudaGraphLoraParams .LD_DTYPE ,
932934 device = x .device ) + ld_offset
933935
934936 a0_ptr = torch .tensor (
935937 [a0 .data_ptr (),
936938 a0 .data_ptr (),
937939 a0 .data_ptr (),
938940 a02 .data_ptr ()],
939- dtype = self .PTR_DTYPE ,
941+ dtype = CudaGraphLoraParams .PTR_DTYPE ,
940942 device = x .device )
941943 b0_ptr = torch .tensor (
942944 [b0 .data_ptr (), 0 , 0 , b02 .data_ptr ()],
943- dtype = self .PTR_DTYPE ,
945+ dtype = CudaGraphLoraParams .PTR_DTYPE ,
944946 device = x .device )
945947 d0_ptr = torch .tensor ([
946948 d00 .data_ptr (),
947949 d01 .data_ptr (),
948950 d01 .data_ptr (),
949951 d02 .data_ptr ()
950952 ],
951- dtype = self .PTR_DTYPE ,
953+ dtype = CudaGraphLoraParams .PTR_DTYPE ,
952954 device = x .device )
953955 b1_ptr = torch .tensor (
954956 [b1 .data_ptr (), 0 , 0 , b12 .data_ptr ()],
955- dtype = self .PTR_DTYPE ,
957+ dtype = CudaGraphLoraParams .PTR_DTYPE ,
956958 device = x .device )
957959 d1_ptr = torch .tensor ([
958960 d10 .data_ptr (),
959961 d11 .data_ptr (),
960962 d11 .data_ptr (),
961963 d12 .data_ptr ()
962964 ],
963- dtype = self .PTR_DTYPE ,
965+ dtype = CudaGraphLoraParams .PTR_DTYPE ,
964966 device = x .device )
965967
966968 torch .ops .trtllm .lora_grouped_gemm_cuda_graph (
@@ -1081,7 +1083,7 @@ def _forward_legacy_mode(
10811083 lora_ranks ,
10821084 lora_weight_pointers ,
10831085 lora_params ['prompt_lens_cpu' ][:num_seqs ],
1084- self .output_hidden_sizes_list ,
1086+ self .output_hidden_sizes ,
10851087 False , # transA
10861088 True , # transB
10871089 max ([r .max () for r in lora_ranks ]),
@@ -1101,7 +1103,7 @@ def _forward_legacy_mode(
11011103 else :
11021104 lora_output .append (
11031105 torch .zeros (list (x .shape [:- 1 ]) + [
1104- self .output_hidden_sizes_list [
1106+ self .output_hidden_sizes [
11051107 self .lora_module_types .index (module_idx )]
11061108 ],
11071109 dtype = x .dtype ,