Commit 7640804

fix lora from pretrained. (#7714)
* fix lora from pretrained.
* updates.
* fix comments.
* fix comments.
1 parent d97016e commit 7640804

3 files changed: +15 additions, -8 deletions
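
Before the per-file diffs, a minimal sketch of how the new do_qat flag is meant to be threaded into LoRA fine-tuning; the checkpoint name, target-module pattern, rank, and alpha below are hypothetical, and only the do_qat field itself comes from this commit.

# Hedged sketch, not part of the commit; argument values are illustrative.
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")  # hypothetical checkpoint
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # hypothetical target pattern
    r=8,
    lora_alpha=16,
    merge_weights=False,
    do_qat=False,  # new field from this commit; quantizer-scale split mappings are only added when True
)
model = LoRAModel(base, lora_config)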

llm/finetune_generation.py

Lines changed: 1 addition & 0 deletions

@@ -417,6 +417,7 @@ def neft_post_hook(module, input, output):
             merge_weights=False,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             dtype=dtype,
+            do_qat=quant_args.do_qat,
         )
         model = LoRAModel(model, lora_config)
     else:

paddlenlp/peft/lora/lora_config.py

Lines changed: 1 addition & 0 deletions

@@ -71,6 +71,7 @@ class LoRAConfig:
             "help": "The model multi head dimension.Only for LoRAMergedLinear and ColumnParallelLoRAMergedLinear."
         },
     )
+    do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})

     @property
     def __dict__(self):
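
As a standalone illustration of the field pattern used above (plain dataclasses, not PaddleNLP code), the flag defaults to False and carries its help text in the field metadata; TinyLoRAConfig is a hypothetical stand-in for LoRAConfig.

from dataclasses import dataclass, field, fields

@dataclass
class TinyLoRAConfig:  # hypothetical, minimal stand-in for LoRAConfig
    r: int = 8
    do_qat: bool = field(
        default=False,
        metadata={"help": "Whether the lora model would do quant-aware training"},
    )

cfg = TinyLoRAConfig()
print(cfg.do_qat)  # False unless explicitly enabled
print(next(f.metadata["help"] for f in fields(cfg) if f.name == "do_qat"))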

paddlenlp/peft/lora/lora_model.py

Lines changed: 13 additions & 8 deletions

@@ -218,8 +218,11 @@ def _convert_tensor_parallel(self, lora_state_dict):
             lora_name_action_mappings[v] = name_action_mappings[k]

         for name, action in lora_name_action_mappings.items():
-            tensor = lora_state_dict.pop(name)
-            lora_state_dict[name] = action(tensor)
+            if name in lora_state_dict:
+                tensor = lora_state_dict.pop(name)
+                lora_state_dict[name] = action(tensor)
+            else:
+                logger.warning(f"{name} not found in lora_state_dict!")
         return lora_state_dict

     def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
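
The hunk above replaces an unconditional pop, which raised a KeyError whenever a checkpoint lacked a mapped key (for example, quantizer scales in a non-QAT checkpoint). A self-contained sketch of the same guard pattern, using the standard logging module instead of PaddleNLP's logger; function and variable names are illustrative.

import logging

logger = logging.getLogger(__name__)

def apply_tp_actions(lora_state_dict, lora_name_action_mappings):
    """Apply each tensor-parallel action only if its key exists in the checkpoint."""
    for name, action in lora_name_action_mappings.items():
        if name in lora_state_dict:
            tensor = lora_state_dict.pop(name)
            lora_state_dict[name] = action(tensor)
        else:
            logger.warning(f"{name} not found in lora_state_dict!")
    return lora_state_dict

# Example: one mapped key is missing, so it is skipped with a warning instead of a KeyError.
state = {"linear_0.lora_A": [[1.0]]}
mappings = {"linear_0.lora_A": lambda t: t, "linear_0.weight_quanter._scale": lambda t: t}
apply_tp_actions(state, mappings)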
@@ -326,9 +329,10 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
             self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)

             # for lora qat
-            self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
-            self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
-            self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+            if self.lora_config.do_qat:
+                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
+                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
         elif isinstance(module, RowParallelLinear):
             # recover the original output_features
             lora_module = RowParallelLoRALinear(
@@ -345,9 +349,10 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
             self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)

             # for lora qat
-            self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
-            self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
-            self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+            if self.lora_config.do_qat:
+                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
+                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
         elif QuantizationLinear is not None and isinstance(module, QuantizationLinear):
             lora_module = QuantizationLoRALinear(
                 in_features=module.in_features,
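
Tying the two lora_model.py changes together: the split mappings registered here are the keys that _convert_tensor_parallel later expects, so gating the quantizer-scale entries on do_qat keeps non-QAT checkpoints from looking for scale tensors at all. A hedged sketch under that assumption; the attribute names mirror the diff, while the function itself is illustrative.

def expected_split_mapping(module_name: str, do_qat: bool) -> dict:
    """Return which checkpoint keys get a tensor-parallel split action (True = column-wise)."""
    mapping = {module_name + ".lora_B": True}  # always registered for column-parallel LoRA layers
    if do_qat:
        # Quantizer scale tensors exist only in quant-aware-training checkpoints.
        mapping[module_name + ".weight_quanter._scale"] = True
        mapping[module_name + ".activation_quanter._scale"] = False
        mapping[module_name + ".activation_quanter.quanter._scale"] = False
    return mapping

print(expected_split_mapping("llama.layers.0.q_proj", do_qat=False))
# {'llama.layers.0.q_proj.lora_B': True}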
