
Commit 5fff8f0

Add running E2E LoRA flow (NVIDIA#3648)
* add passing E2E LoRA flow
* add experimental feature
* fix llma_args definition
* decreased manually size of max loras to address OOM

Signed-off-by: Shahar Mor <[email protected]>
1 parent c4d86b2 commit 5fff8f0
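The commit threads a per-batch `lora_params` argument from the top-level model forward down through the decoder layers into the attention and MLP modules (see the per-file diffs below). The following is a minimal, self-contained sketch of that plumbing pattern only; the class and method names are illustrative stand-ins, not TensorRT-LLM APIs.

# Toy sketch of the kwargs plumbing this commit adds (illustrative names only).
from typing import Optional


class ToyAttention:
    def forward(self, hidden_states, lora_params: Optional[dict] = None):
        out = hidden_states
        if bool(lora_params):  # empty dict or None -> skip the LoRA path
            out = out + lora_params.get("attn_delta", 0)
        return out


class ToyDecoderLayer:
    def __init__(self):
        self.attn = ToyAttention()

    def forward(self, hidden_states, lora_params: Optional[dict] = None):
        # lora_params is forwarded untouched to every submodule
        return self.attn.forward(hidden_states, lora_params=lora_params)


class ToyModel:
    def __init__(self, num_layers: int = 2):
        self.layers = [ToyDecoderLayer() for _ in range(num_layers)]

    def forward(self, hidden_states, lora_params: Optional[dict] = None):
        for layer in self.layers:
            hidden_states = layer.forward(hidden_states, lora_params=lora_params)
        return hidden_states


model = ToyModel()
print(model.forward(1.0))                                    # no LoRA: 1.0
print(model.forward(1.0, lora_params={"attn_delta": 0.5}))   # with LoRA: 2.0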

18 files changed: +359 −79 lines changed

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 23 additions & 2 deletions
@@ -369,6 +369,7 @@ def __init__(
             bias=getattr(config, "mlp_bias", False),
             dtype=config.torch_dtype,
             config=model_config,
+            layer_idx=layer_idx,
         )

         # self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
@@ -519,6 +520,7 @@ def __init__(
             bias=config.mlp_bias,
             dtype=config.torch_dtype,
             config=model_config,
+            layer_idx=layer_idx,
         )
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -555,7 +557,7 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, **kwargs)
         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
@@ -689,6 +691,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pipeline_interface: Optional[PipelineInterface] = None,
         spec_metadata: Optional[SpecMetadata] = None,
+        lora_params=None,
     ) -> torch.Tensor:
         if self.model_config.mapping.is_first_pp_rank():
             if (input_ids is None) ^ (inputs_embeds is not None):
@@ -716,6 +719,7 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 spec_metadata=spec_metadata,
+                lora_params=lora_params,
             )

         if self.model_config.mapping.is_last_pp_rank():
@@ -732,14 +736,29 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
         config = self.model_config.pretrained_config
         self.padding_idx = config.pad_token_id

+        vocab_size = config.vocab_size
+        # TODO smor- hack
+        if hasattr(model_config,
+                   'lora_config') and model_config.lora_config is not None:
+            from tensorrt_llm.lora_manager import HfLoraLoader
+            lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
+            weight = lora_loader.embed_tokens
+            # TODO smor - need to split tp matrix here
+            vocab_size = lora_loader.vocab_size
+
         self.embed_tokens = Embedding(
-            config.vocab_size,
+            vocab_size,
             config.hidden_size,
             dtype=config.torch_dtype,
             mapping=model_config.mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             gather_output=True,
         )
+
+        if hasattr(model_config,
+                   'lora_config') and model_config.lora_config is not None:
+            self.embed_tokens.weight.value = weight.to(self.embed_tokens.dtype)
+
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 model_config,
@@ -758,6 +777,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pipeline_interface: Optional[PipelineInterface] = None,
         spec_metadata: Optional[SpecMetadata] = None,
+        lora_params=None,
     ) -> torch.Tensor:
         if self.model_config.mapping.is_first_pp_rank():
             if (input_ids is None) ^ (inputs_embeds is not None):
@@ -783,6 +803,7 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 spec_metadata=spec_metadata,
+                lora_params=lora_params,
             )

         if self.model_config.mapping.is_last_pp_rank():
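A minimal sketch of the embedding override pattern used above: when a lora_config is present, the vocab size and the embed_tokens table come from the LoRA checkpoint rather than the base config. ToyLoraLoader is a hypothetical stand-in for HfLoraLoader, and nn.Embedding stands in for the tensor-parallel Embedding module.

from typing import Optional

import torch
import torch.nn as nn


class ToyLoraLoader:
    """Hypothetical stand-in exposing the two fields the diff reads."""

    def __init__(self, vocab_size: int, hidden_size: int):
        self.vocab_size = vocab_size
        self.embed_tokens = torch.randn(vocab_size, hidden_size)


def build_embedding(base_vocab_size: int, hidden_size: int,
                    lora_loader: Optional[ToyLoraLoader] = None) -> nn.Embedding:
    vocab_size = base_vocab_size
    weight = None
    if lora_loader is not None:      # mirrors the lora_config check above
        vocab_size = lora_loader.vocab_size
        weight = lora_loader.embed_tokens
    embed = nn.Embedding(vocab_size, hidden_size)
    if weight is not None:           # overwrite with the checkpoint's table
        with torch.no_grad():
            embed.weight.copy_(weight.to(embed.weight.dtype))
    return embed


# Example: a LoRA checkpoint that extends the vocab from 32000 to 32008 tokens.
embed = build_embedding(32000, 64, ToyLoraLoader(32008, 64))
print(embed.weight.shape)  # torch.Size([32008, 64])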

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 27 additions & 0 deletions
@@ -240,6 +240,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        lora_params: Optional = None,  # TODO smor add type hint
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -257,6 +258,7 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
+            lora_params=lora_params,
         )

         hidden_states = self.norm(hidden_states)
@@ -355,6 +357,15 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
         else:
             # TODO(zhenhuanc): Currently lm_head Linear will not accept QuantConfig
             # will considering per layer QuantConfig in the future.
+
+            # TODO smor- hack
+            if hasattr(config,
+                       'lora_config') and config.lora_config is not None:
+                from tensorrt_llm.lora_manager import HfLoraLoader
+                lora_loader = HfLoraLoader(config.lora_config.lora_dir)
+                weight = lora_loader.lm_head
+                vocab_size = lora_loader.vocab_size
+
             self.lm_head = LMHead(
                 vocab_size,
                 hidden_size,
@@ -364,6 +375,12 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
                 gather_output=True,
             )

+            if hasattr(config,
+                       'lora_config') and config.lora_config is not None:
+                # TODO smor- figure out if it sticks
+                self.lm_head.weight.value = weight.to(
+                    self.lm_head.dtype).cuda()
+
         # use embedding weights in lm_head if tie word embedding is enabled
         if config.pretrained_config.tie_word_embeddings and not isinstance(
                 self.model.embed_tokens, MissingLayer):
@@ -450,6 +467,7 @@ def forward(
         pipeline_interface: Optional[PipelineInterface] = None,
         return_context_logits: bool = False,
         spec_metadata: Optional[SpecMetadata] = None,
+        lora_params: Optional = None,  # TODO smor add type hint
         **kwargs,
     ) -> torch.Tensor:
         if self._supports_pp and self.pp_size > 1:
@@ -466,12 +484,14 @@ def forward(
             if self.pp_rank < self.pp_size - 1:
                 return output
         else:
+
             output = self.model(
                 input_ids=input_ids,
                 attn_metadata=attn_metadata,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
                 spec_metadata=spec_metadata,
+                lora_params=lora_params,
             )

         return self.logits_processor.forward(
@@ -506,6 +526,13 @@ def filter_weights(prefix, weights: Dict):
                     "lm_head"):
                 continue

+            # Skip loading weights for embedding and lm_head if LoRA is enabled
+            if hasattr(self.model_config, 'lora_config'
+                       ) and self.model_config.lora_config is not None and (
+                           name == "model.embed_tokens"
+                           or name == "lm_head"):
+                continue
+
             # Skip if parameter belongs to a missing layer
             if missing_layer_parameter(name, self):
                 continue
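A small sketch of the weight-loading skip added above: when LoRA supplies its own embedding and lm_head tables, the corresponding base-checkpoint entries are filtered out so they do not overwrite them. The loop below is a simplified stand-in for the real loader, not the TensorRT-LLM code path.

from typing import Optional


def modules_to_load(module_names, lora_config: Optional[object] = None):
    """Return the module names whose base weights should still be loaded."""
    skipped_with_lora = {"model.embed_tokens", "lm_head"}
    kept = []
    for name in module_names:
        if lora_config is not None and name in skipped_with_lora:
            continue  # the LoRA checkpoint already provided these weights
        kept.append(name)
    return kept


names = ["model.embed_tokens", "model.layers.0.self_attn.qkv_proj", "lm_head"]
print(modules_to_load(names))                        # all three, no LoRA
print(modules_to_load(names, lora_config=object()))  # only the qkv_proj entry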

tensorrt_llm/_torch/modules/attention.py

Lines changed: 2 additions & 2 deletions
@@ -171,7 +171,7 @@ def forward(
     ) -> torch.Tensor:
         qkv = self.qkv_proj(hidden_states)

-        if lora_params is not None:
+        if bool(lora_params):
             qkv_lora = self.splitted_qkv_lora(hidden_states, lora_params,
                                               self.layer_idx)
             if qkv_lora is not None:
@@ -204,7 +204,7 @@ def forward(

         attn_output = self.o_proj(attn_output,
                                   all_reduce_params=all_reduce_params)
-        if lora_params is not None:
+        if bool(lora_params):
             attn_lora_output = self.o_lora(attn_output, lora_params,
                                            self.layer_idx)
             if attn_lora_output is not None:
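The switch from `lora_params is not None` to `bool(lora_params)` above makes the LoRA branch skip not only when no dict is passed but also when an empty dict is passed, because an empty dict is falsy in Python. A tiny demonstration (the dict content below is a placeholder, not a real lora_params layout):

for lora_params in (None, {}, {"some_key": 1}):
    print(repr(lora_params),
          "| old guard (is not None):", lora_params is not None,
          "| new guard (bool):", bool(lora_params))
# None: both guards False; {}: old guard True but new guard False, so an
# empty dict no longer enters the LoRA path; non-empty dict: both True.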

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 18 additions & 0 deletions
@@ -114,9 +114,27 @@ def forward(

         if self.activation == F.silu:
             h1 = self.gate_up_proj(x)
+            if bool(lora_params):
+                assert self.layer_idx is not None, "layer_idx is required for lora"
+                h1_lora = self.splitted_gate_up_lora(x, lora_params,
+                                                     self.layer_idx)
+                if h1_lora is not None:
+                    h1 = h1 + h1_lora
+
+                h1_lora = self.fused_gate_up_lora(x, lora_params,
+                                                  self.layer_idx)
+
+                if h1_lora is not None:
+                    h1 = h1 + h1_lora
+
             h2 = swiglu(h1)
             output = self.down_proj(h2,
                                     all_reduce_params=final_all_reduce_params)
+            if bool(lora_params):
+                output_lora = self.down_lora(h2, lora_params, self.layer_idx)
+                if output_lora is not None:
+                    output = output + output_lora
+
             return output
         else:
             raise NotImplementedError(
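The `*_lora` calls above add a low-rank correction on top of the projection outputs. A generic sketch of that arithmetic (standard LoRA, y = Wx + B(Ax) * alpha/r), independent of the grouped-GEMM op TensorRT-LLM actually uses; all shapes and values here are illustrative:

import torch

torch.manual_seed(0)
hidden, out_dim, rank, alpha = 16, 32, 4, 8

x = torch.randn(3, hidden)            # 3 tokens
W = torch.randn(out_dim, hidden)      # base projection (e.g. gate_up or down)
A = torch.randn(rank, hidden)         # LoRA "in" weight
B = torch.zeros(out_dim, rank)        # LoRA "out" weight (zero-init at start)

base = x @ W.t()
lora_delta = (x @ A.t()) @ B.t() * (alpha / rank)
output = base + lora_delta            # mirrors `h1 = h1 + h1_lora` above

print(output.shape)                   # torch.Size([3, 32])
print(torch.allclose(output, base))   # True while B is still zero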

tensorrt_llm/_torch/peft/lora/layer.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ def forward(self, x, lora_params: Dict,
             True,  # transB
             max([r.max() for r in lora_ranks]),
             0,
-            lora_params["remove_input_padding"],
+            True,  # TODO smor- should be lora_params["remove_input_padding"], support in loraOp as well
         )
         if isinstance(lora_outputs, torch.Tensor):
             return lora_outputs

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 5 additions & 2 deletions
@@ -381,8 +381,8 @@ def create_py_executor_instance(dist,
         len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)

         # TODO smor- need to figure out how to set these values
-        max_loras = 4
-        max_cpu_loras = 4
+        max_loras = 2
+        max_cpu_loras = 2
         executor_config.peft_cache_config = tllm.executor.PeftCacheConfig(
             num_device_module_layer=max_lora_rank * num_lora_modules *
             max_loras,
@@ -394,6 +394,9 @@ def create_py_executor_instance(dist,
             peft_cache_config=executor_config.peft_cache_config,
             model_config=model_binding_config)
         resources["peft_cache_manager"] = peft_cache_manager
+        model_engine.set_lora_model_config(
+            lora_config.lora_target_modules,
+            lora_config.trtllm_modules_to_hf_modules)

    resource_manager = ResourceManager(resources)
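The OOM fix above halves max_loras/max_cpu_loras from 4 to 2, which halves the PEFT cache sizing derived from them. A quick back-of-the-envelope sketch using the formula in the hunk (max_lora_rank * num_lora_modules * max_loras); the concrete numbers below are illustrative only, not values taken from the repository:

max_lora_rank = 64
num_lora_modules = 7 * 32          # e.g. 7 target modules across 32 layers

for max_loras in (4, 2):           # before vs. after this commit
    num_device_module_layer = max_lora_rank * num_lora_modules * max_loras
    print(f"max_loras={max_loras}: "
          f"num_device_module_layer={num_device_module_layer}")
# max_loras=4: num_device_module_layer=57344
# max_loras=2: num_device_module_layer=28672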

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 7 additions & 3 deletions
@@ -50,6 +50,7 @@ def __init__(self, *args, client_id=None, **kwargs):
         self.py_draft_tokens = self.draft_tokens
         self.py_last_draft_tokens = None
         self.py_decoding_iter = 0
+        self.py_lora_task_layer_module_configs = None


 def convert_wordlist(word_list) -> List[List[int]]:
@@ -121,13 +122,16 @@ def executor_request_to_llm_request(req_id: int,
         is None else executor_request.prompt_tuning_config.embedding_table,
         prompt_vocab_size=None if executor_request.prompt_tuning_config is None
         else executor_request.prompt_tuning_config.embedding_table.shape[0],
+        lora_task_id=executor_request.lora_config.task_id
+        if executor_request.lora_config is not None else None,
+        lora_weights=executor_request.lora_config.weights
+        if executor_request.lora_config is not None else None,
+        lora_config=executor_request.lora_config.config
+        if executor_request.lora_config is not None else None,
         mrope_rotary_cos_sin=None if executor_request.mrope_config is None else
         executor_request.mrope_config.mrope_rotary_cos_sin,
         mrope_position_deltas=None if executor_request.mrope_config is None else
         executor_request.mrope_config.mrope_position_deltas,
-        lora_task_id=None,
-        lora_weights=None,
-        lora_config=None,
         lookahead_config=None,
         return_log_probs=False,
         return_context_logits=executor_request.output_config.
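The request-conversion hunk above replaces the hard-coded None LoRA fields with values pulled from executor_request.lora_config when one is attached. A small sketch of that guarded mapping, with toy dataclasses standing in for the executor bindings:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyLoraConfig:           # stand-in for the executor LoRA config binding
    task_id: int
    weights: Any = None
    config: Any = None


@dataclass
class ToyExecutorRequest:      # stand-in for the executor request binding
    lora_config: Optional[ToyLoraConfig] = None


def lora_fields(req: ToyExecutorRequest) -> dict:
    lc = req.lora_config
    return {
        "lora_task_id": lc.task_id if lc is not None else None,
        "lora_weights": lc.weights if lc is not None else None,
        "lora_config": lc.config if lc is not None else None,
    }


print(lora_fields(ToyExecutorRequest()))                          # all None
print(lora_fields(ToyExecutorRequest(ToyLoraConfig(task_id=7))))  # task_id=7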
