@@ -369,6 +369,7 @@ def __init__(
             bias=getattr(config, "mlp_bias", False),
             dtype=config.torch_dtype,
             config=model_config,
+            layer_idx=layer_idx,
         )

         # self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
@@ -519,6 +520,7 @@ def __init__(
             bias=config.mlp_bias,
             dtype=config.torch_dtype,
             config=model_config,
+            layer_idx=layer_idx,
         )
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -555,7 +557,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, **kwargs)
         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
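
The `**kwargs` forwarded into `self.mlp(...)` above is how `lora_params` reaches the MLP without changing every intermediate call site. Below is a minimal, hedged sketch of what a LoRA-aware gated MLP forward could look like; the class name, the `lora_params` dict layout keyed by `layer_idx`, and the `"mlp_down"` entry are assumptions for illustration only, not the actual TensorRT-LLM `GatedMLP`.

```python
# Hedged sketch, not the actual tensorrt_llm GatedMLP: shows why the MLP now
# takes layer_idx at construction and lora_params at call time. The dict layout
# lora_params[layer_idx]["mlp_down"] = (A, B, scale) is an assumption.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMLPSketch(nn.Module):

    def __init__(self, hidden_size: int, intermediate_size: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.gate_up = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self,
                hidden_states: torch.Tensor,
                lora_params: Optional[dict] = None,
                **kwargs) -> torch.Tensor:
        gate, up = self.gate_up(hidden_states).chunk(2, dim=-1)
        inter = F.silu(gate) * up
        out = self.down(inter)
        if lora_params is not None and self.layer_idx in lora_params:
            # Low-rank correction on the down projection:
            # out += ((inter @ A^T) @ B^T) * scale, with A: (r, inter), B: (hidden, r).
            lora_a, lora_b, scale = lora_params[self.layer_idx]["mlp_down"]
            out = out + (inter @ lora_a.t()) @ lora_b.t() * scale
        return out
```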
@@ -689,6 +691,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pipeline_interface: Optional[PipelineInterface] = None,
         spec_metadata: Optional[SpecMetadata] = None,
+        lora_params=None,
     ) -> torch.Tensor:
         if self.model_config.mapping.is_first_pp_rank():
             if (input_ids is None) ^ (inputs_embeds is not None):
@@ -716,6 +719,7 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 spec_metadata=spec_metadata,
+                lora_params=lora_params,
             )

         if self.model_config.mapping.is_last_pp_rank():
@@ -732,14 +736,29 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
         config = self.model_config.pretrained_config
         self.padding_idx = config.pad_token_id

+        vocab_size = config.vocab_size
+        # TODO smor- hack
+        if hasattr(model_config,
+                   'lora_config') and model_config.lora_config is not None:
+            from tensorrt_llm.lora_manager import HfLoraLoader
+            lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
+            weight = lora_loader.embed_tokens
+            # TODO smor - need to split tp matrix here
+            vocab_size = lora_loader.vocab_size
+
         self.embed_tokens = Embedding(
-            config.vocab_size,
+            vocab_size,
             config.hidden_size,
             dtype=config.torch_dtype,
             mapping=model_config.mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             gather_output=True,
         )
+
+        if hasattr(model_config,
+                   'lora_config') and model_config.lora_config is not None:
+            self.embed_tokens.weight.value = weight.to(self.embed_tokens.dtype)
+
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 model_config,
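
The hack above rebuilds `embed_tokens` with the LoRA checkpoint's vocabulary size and then overwrites its weight, since an adapter may ship an expanded `embed_tokens` table (e.g., extra special tokens added during fine-tuning). Below is a plain-PyTorch sketch of that resize-and-copy idea; `nn.Embedding` stands in for tensorrt_llm's tensor-parallel `Embedding`, and `lora_embed_tokens` stands in for `lora_loader.embed_tokens`, so treat the names and signature as assumptions.

```python
# Hedged sketch of the resize-and-copy above in plain PyTorch. The TP split
# flagged by the "need to split tp matrix" TODO in the diff is deliberately
# not handled here either.
from typing import Optional

import torch
import torch.nn as nn


def build_embed_tokens(config_vocab_size: int,
                       hidden_size: int,
                       dtype: torch.dtype,
                       lora_embed_tokens: Optional[torch.Tensor] = None
                       ) -> nn.Embedding:
    # Prefer the LoRA checkpoint's (possibly larger) vocabulary if it ships one.
    vocab_size = config_vocab_size
    if lora_embed_tokens is not None:
        vocab_size = lora_embed_tokens.shape[0]

    embed_tokens = nn.Embedding(vocab_size, hidden_size, dtype=dtype)
    if lora_embed_tokens is not None:
        with torch.no_grad():
            # Overwrite the freshly created table with the checkpoint's weights.
            embed_tokens.weight.copy_(lora_embed_tokens.to(dtype))
    return embed_tokens
```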
@@ -758,6 +777,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pipeline_interface: Optional[PipelineInterface] = None,
         spec_metadata: Optional[SpecMetadata] = None,
+        lora_params=None,
     ) -> torch.Tensor:
         if self.model_config.mapping.is_first_pp_rank():
             if (input_ids is None) ^ (inputs_embeds is not None):
@@ -783,6 +803,7 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 spec_metadata=spec_metadata,
+                lora_params=lora_params,
             )

         if self.model_config.mapping.is_last_pp_rank():
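
Both model classes thread the same new `lora_params` keyword from `forward` into every decoder layer, and the layer forwards it into its MLP via `**kwargs`. The following self-contained sketch shows just that plumbing, with placeholder layer bodies (`nn.Linear`) that are not the real TensorRT-LLM attention/MLP modules.

```python
# Hedged plumbing sketch: the same keyword threading as the diff, with
# placeholder layer bodies instead of the real TensorRT-LLM modules.
import torch
import torch.nn as nn


class DecoderLayerSketch(nn.Module):

    def __init__(self, hidden_size: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.mlp = nn.Linear(hidden_size, hidden_size)  # placeholder

    def forward(self, hidden_states: torch.Tensor, lora_params=None, **kwargs):
        # In the real layer, lora_params rides along in **kwargs and is handed
        # to self.mlp(hidden_states, **kwargs); the placeholder ignores it.
        return self.mlp(hidden_states)


class ModelSketch(nn.Module):

    def __init__(self, hidden_size: int = 16, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayerSketch(hidden_size, idx) for idx in range(num_layers))

    def forward(self, hidden_states: torch.Tensor, lora_params=None):
        for layer in self.layers:
            # Every layer sees the per-request LoRA weights and can select its
            # own slice by layer_idx.
            hidden_states = layer(hidden_states, lora_params=lora_params)
        return hidden_states


hidden = ModelSketch()(torch.randn(2, 16), lora_params={0: {}, 1: {}})
```

Keeping the adapter weights in a call-time argument rather than in module state lines up with per-request LoRA, where different requests in a batch may use different adapters.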