Skip to content

Commit df8b94f

Browse files
kgreenewaldKristjan Greenewald Kristjan.H.Greenewald@ibm.comkmehantKristjan Greenewald Kristjan.H.Greenewald@ibm.com
authored
feat: Add ALoRA support (#513)
* Update peft_config.py Signed-off-by: Greenewald <greenewk@umich.edu> Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * Update pyproject.toml Signed-off-by: Greenewald <greenewk@umich.edu> Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * Update pyproject.toml Signed-off-by: Greenewald <greenewk@umich.edu> Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * Update sft_trainer.py add alora Signed-off-by: Greenewald <greenewk@umich.edu> Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * alora support Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * Remove error.log Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * Update tuning/sft_trainer.py Co-authored-by: Mehant Kammakomati <kmehant@gmail.com> Signed-off-by: Greenewald <greenewk@umich.edu> * Update peft_config.py Getting rid of alora config definition in this repo Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update pyproject.toml Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Optional alora Signed-off-by: Greenewald <greenewk@umich.edu> * Optional alora package Signed-off-by: Greenewald <greenewk@umich.edu> * Optional alora package Signed-off-by: Greenewald <greenewk@umich.edu> * invocation error message fix Signed-off-by: Greenewald <greenewk@umich.edu> * Update pyproject.toml Signed-off-by: Greenewald <greenewk@umich.edu> * alora inference Signed-off-by: Greenewald <greenewk@umich.edu> * alora test draft Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py 
Signed-off-by: Greenewald <greenewk@umich.edu> * Update config_utils.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update config_utils.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update pyproject.toml Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update config_utils.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update config_utils.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update config_utils.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * alora saving Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update config_utils.py Signed-off-by: 
Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update run_inference.py Signed-off-by: Greenewald <greenewk@umich.edu> * pr fixes Signed-off-by: Greenewald <greenewk@umich.edu> * pr fixes Signed-off-by: Greenewald <greenewk@umich.edu> * checking for alora Signed-off-by: Greenewald <greenewk@umich.edu> * run test only if alora installed Signed-off-by: Greenewald <greenewk@umich.edu> * Documentation Signed-off-by: Greenewald <greenewk@umich.edu> * Update README.md Signed-off-by: Greenewald <greenewk@umich.edu> * Update README.md Signed-off-by: Greenewald <greenewk@umich.edu> * Update README.md Signed-off-by: Greenewald <greenewk@umich.edu> * pip install alora Signed-off-by: Greenewald <greenewk@umich.edu> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update config_utils.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update setup_dataprocessor.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update setup_dataprocessor.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update setup_dataprocessor.py Signed-off-by: Greenewald <greenewk@umich.edu> * lint fixes Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * lint fixes Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * Update sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update config_utils.py Signed-off-by: Greenewald <greenewk@umich.edu> * Update test_sft_trainer.py Signed-off-by: Greenewald <greenewk@umich.edu> * pylint Signed-off-by: Greenewald <greenewk@umich.edu> * pylint Signed-off-by: Greenewald <greenewk@umich.edu> * Update run_inference.py Signed-off-by: Greenewald <greenewk@umich.edu> * fmt fixes Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * lint fixes Signed-off-by: Kristjan Greenewald 
<kristjan.h.greenewald@ibm.com> * Delete mykey.asc Signed-off-by: Greenewald <greenewk@umich.edu> * Delete mypubkey.asc Signed-off-by: Greenewald <greenewk@umich.edu> * requested changes Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * restructure inference Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * typo Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> * another typo Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> --------- Signed-off-by: Greenewald <greenewk@umich.edu> Signed-off-by: Kristjan Greenewald <kristjan.h.greenewald@ibm.com> Co-authored-by: Kristjan Greenewald Kristjan.H.Greenewald@ibm.com <kgreenewald@login2.bluevela.rmf.ibm.com> Co-authored-by: Mehant Kammakomati <kmehant@gmail.com> Co-authored-by: Kristjan Greenewald Kristjan.H.Greenewald@ibm.com <kgreenewald@p2-r09-n2.bluevela.rmf.ibm.com>
1 parent 81177ce commit df8b94f

File tree

8 files changed

+428
-32
lines changed

8 files changed

+428
-32
lines changed

README.md

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- [Tips on Parameters to Set](#tips-on-parameters-to-set)
1010
- [Tuning Techniques](#tuning-techniques)
1111
- [LoRA Tuning Example](#lora-tuning-example)
12+
- [Activated LoRA Tuning Example](#activated-lora-tuning-example)
1213
- [GPTQ-LoRA with AutoGPTQ Tuning Example](#gptq-lora-with-autogptq-tuning-example)
1314
- [Fine Tuning](#fine-tuning)
1415
- [FMS Acceleration](#fms-acceleration)
@@ -454,7 +455,7 @@ To summarize you can pick either python for single-GPU jobs or use accelerate la
454455

455456
### Tips on Parameters to Set
456457

457-
#### Saving checkpoints while training
458+
#### Saving checkpoints while training (does not apply to Activated LoRA)
458459

459460
By default, [`save_strategy`](tuning/config/configs.py) is set to `"epoch"` in the TrainingArguments. This means that checkpoints will be saved on each epoch. This can also be set to `"steps"` to save on every `"save_steps"` or `"no"` to not save any checkpoints.
460461

@@ -700,6 +701,132 @@ post_process_vLLM_adapters_new_tokens(
700701

701702
_________________________
702703

704+
### Activated LoRA Tuning Example
705+
706+
Activated LoRA (aLoRA) is a new low rank adapter architecture that allows for reusing existing base model KV cache for more efficient inference. This approach is best suited for inference pipelines which rely on the base model for most tasks/generations, but use aLoRA adapter(s) to perform specialized task(s) within the chain. For example, checking or rewriting generated outputs of the base model.
707+
708+
[Paper](https://arxiv.org/abs/2504.12397)
709+
710+
[IBM Research Blogpost](https://research.ibm.com/blog/inference-friendly-aloras)
711+
712+
[Github](https://github.com/IBM/activated-lora)
713+
714+
**Usage** Usage is very similar to standard LoRA, with the key difference that an invocation_string must be specified so that the model knows when to turn on, i.e. "activate", the adapter weights. The model will scan any input strings (during training or at test time) for this invocation_string, and activate the adapter weights 1 token after the start of the sequence. If there are multiple instances of the invocation_string in the same input, it will activate at the last such instance.
715+
716+
**Note** Often (not always) aLoRA requires higher rank (r) than LoRA. r=32 can be a good starting point for challenging tasks.
717+
718+
**Installation** The Activated LoRA requirements are an optional install in pyproject.toml (activated-lora)
719+
720+
Set `peft_method` to `"alora"`.
721+
722+
You *must* pass in an invocation_string argument. This invocation_string *must be present* in both training data inputs and the input at test time. A good solution is to set invocation_string = response_template; this will ensure that every training input will have the invocation_string present. We keep these separate arguments for flexibility. It is most robust if the invocation_string begins and ends with special tokens.
723+
724+
You can additionally pass any arguments from [aLoraConfig](https://github.com/IBM/activated-lora/blob/fms-hf-tuning/alora/config.py#L35), see the LoRA section for examples.
725+
726+
Example command to run, here using the ([Granite Instruct response template](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct/blob/main/tokenizer_config.json#L188)) as the invocation sequence:
727+
728+
```bash
729+
python tuning/sft_trainer.py \
730+
--model_name_or_path $MODEL_PATH \
731+
--tokenizer_name_or_path $MODEL_PATH \ # This field is optional and if not specified, tokenizer from model_name_or_path will be used
732+
--training_data_path $TRAIN_DATA_PATH \
733+
--output_dir $OUTPUT_PATH \
734+
--num_train_epochs 40 \
735+
--per_device_train_batch_size 4 \
736+
--learning_rate 1e-4 \
737+
--response_template "<|start_of_role|>assistant<|end_of_role|>" \ #this example uses special tokens in the Granite tokenizer, adjust for other models
738+
--invocation_string "<|start_of_role|>assistant<|end_of_role|>" \
739+
--dataset_text_field "output" \
740+
--peft_method "alora" \
741+
--r 32 \
742+
--lora_dropout 0.05 \
743+
--lora_alpha 16 \
744+
--target_modules q_proj k_proj v_proj
745+
```
746+
747+
Equally you can pass in a JSON configuration for running tuning. See [build doc](./build/README.md) for more details. The above can also be passed in as JSON:
748+
```json
749+
{
750+
"model_name_or_path": $MODEL_PATH,
751+
"training_data_path": $TRAIN_DATA_PATH,
752+
"output_dir": $OUTPUT_PATH,
753+
"num_train_epochs": 40.0,
754+
"per_device_train_batch_size": 4,
755+
"learning_rate": 1e-4,
756+
"response_template": "<|start_of_role|>assistant<|end_of_role|>",
757+
"invocation_string": "<|start_of_role|>assistant<|end_of_role|>",
758+
"dataset_text_field": "output",
759+
"peft_method": "alora",
760+
"r": 32,
761+
"lora_dropout": 0.05,
762+
"lora_alpha": 16,
763+
"target_modules": ["q_proj", "k_proj", "v_proj"]
764+
}
765+
```
766+
767+
Notice the `target_modules` are the names of the modules to apply the adapter to.
768+
- If this is specified, only the modules with the specified names will be replaced. When passing a list of strings, either an exact match will be performed or it is checked if the name of the module ends with any of the passed strings. If this is specified as `all-linear`, then all linear/Conv1D modules are chosen, excluding the output layer.
769+
- If this is not specified, modules will be chosen according to the model architecture. If the architecture is not known, an error will be raised — in this case, you should specify the target modules manually. See [HuggingFace docs](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig) for more details.
770+
771+
772+
#### How to get list of aLoRA target_modules of a model
773+
See [How to get list of LoRA target_modules of a model](#how-to-get-list-of-lora-target_modules-of-a-model).
774+
775+
#### Recommended target modules per model architecture
776+
As per [aLoRA paper](https://arxiv.org/abs/2504.12397), by using the key, query and value projection matrices, we can achieve good quality with efficient GPU utilization. Hence, while thinking about what aLoRA adapters to specify, we recommend starting with key, query and value matrices.
777+
778+
#### Intermediate checkpoint saving
779+
Note that `sft_trainer.py` will always save the final trained model for you. If you want to save intermediate checkpoints from within the training process, the below applies.
780+
781+
For now, `save_strategy` is not supported (it is always reset to `none`). You can either save the model once training is complete, or pass in a custom callback in `additional_callbacks` directly to `tuning.sft_trainer.train` to perform saving. For example the following (from [alora github](https://github.com/IBM/activated-lora/blob/fms-hf-tuning/train_scripts/finetune_example_callback.py)) saves and updates the best performing model so far, checking whenever eval is called according to `eval_strategy`:
782+
```py
783+
class SaveBestModelCallback(TrainerCallback):
784+
def __init__(self):
785+
self.best_eval_loss = float("inf") # Track best loss
786+
787+
def on_evaluate(self, args, state, control, **kwargs):
788+
"""Save the best model manually during evaluation."""
789+
790+
model = kwargs["model"]
791+
metrics = kwargs["metrics"]
792+
793+
eval_loss = metrics.get("eval_loss")
794+
if eval_loss is not None and eval_loss < self.best_eval_loss:
795+
self.best_eval_loss = eval_loss # Update best loss
796+
797+
# Manually save best model
798+
model.save_pretrained(args.output_dir)
799+
```
800+
#### Inference with aLoRA models
801+
*Important* Inference with aLoRA models requires ensuring that the invocation string is present in the input (usually at the end).
802+
803+
Example inference:
804+
```py
805+
# Load the model
806+
loaded_model = TunedCausalLM.load(ALORA_MODEL, BASE_MODEL_NAME, use_alora=True)
807+
808+
# Retrieve the invocation string from the model config
809+
invocation_string = loaded_model.peft_model.peft_config[
810+
loaded_model.peft_model.active_adapter
811+
].invocation_string
812+
813+
# In this case, we have the invocation string at the end of the input
814+
input_string = "Simply put, the theory of relativity states that \n" + invocation_string
815+
816+
# Run inference on the text
817+
output_inference = loaded_model.run(
818+
input_string,
819+
max_new_tokens=50,
820+
)
821+
```
822+
823+
#### Running aLoRA models on VLLM
824+
825+
Coming soon! For now, there is inference support in this package, or see [aLoRA github](https://github.com/IBM/activated-lora/experiments/inference_example.py) for example code demonstrating KV cache reuse from prior base model calls.
826+
827+
__________
828+
829+
703830

704831
### GPTQ-LoRA with AutoGPTQ Tuning Example
705832

@@ -1037,4 +1164,4 @@ Further details on enabling and using the trackers mentioned above can be found
10371164

10381165
## More Examples
10391166

1040-
A good simple example can be found [here](examples/kfto-kueue-sft-trainer.yaml) which launches a Kubernetes-native `PyTorchJob` using the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator/) with [Kueue](https://github.com/kubernetes-sigs/kueue) for the queue management of tuning jobs.
1167+
A good simple example can be found [here](examples/kfto-kueue-sft-trainer.yaml) which launches a Kubernetes-native `PyTorchJob` using the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator/) with [Kueue](https://github.com/kubernetes-sigs/kueue) for the queue management of tuning jobs.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ dependencies = [
3535
"tokenizers>=0.13.3,<1.0",
3636
"tqdm>=4.66.2,<5.0",
3737
"trl>=0.13,<0.18",
38-
"peft>=0.8.0,<0.14",
38+
"peft>=0.8.0,<=0.14",
3939
"protobuf>=5.28.0,<6.0.0",
4040
"datasets>=3.5.0,<4.0",
4141
"simpleeval>=0.9.13,<2.0",
@@ -51,6 +51,7 @@ fms-accel = ["fms-acceleration>=0.6"]
5151
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
5252
mamba = ["mamba_ssm[causal-conv1d]>=2.0.0,<3.0.0"]
5353
scanner-dev = ["HFResourceScanner>=0.1.0"]
54+
activated-lora = ["alora>=0.1.0"]
5455

5556

5657
[tool.setuptools.packages.find]

scripts/run_inference.py

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,19 @@ def __exit__(self, exc_type, exc_value, exc_tb):
138138

139139
### Funcs for loading and running models
140140
class TunedCausalLM:
141-
def __init__(self, model, tokenizer, device):
141+
def __init__(self, model, tokenizer, device, use_alora=False):
142142
self.peft_model = model
143143
self.tokenizer = tokenizer
144144
self.device = device
145+
self.use_alora = use_alora
145146

146147
@classmethod
147148
def load(
148149
cls,
149150
checkpoint_path: str,
150151
base_model_name_or_path: str = None,
151152
use_flash_attn: bool = False,
153+
use_alora: bool = False,
152154
) -> "TunedCausalLM":
153155
"""Loads an instance of this model.
154156
@@ -222,14 +224,36 @@ def load(
222224
tokenizer_and_embedding_resize(
223225
{}, tokenizer=tokenizer, model=base_model
224226
)
225-
model = PeftModel.from_pretrained(
226-
base_model,
227-
checkpoint_path,
228-
attn_implementation="flash_attention_2"
229-
if use_flash_attn
230-
else None,
231-
torch_dtype=torch.bfloat16 if use_flash_attn else None,
232-
)
227+
if use_alora:
228+
# Third Party
229+
try:
230+
# Third Party
231+
from alora.peft_model_alora import ( # pylint: disable=import-outside-toplevel
232+
aLoRAPeftModelForCausalLM,
233+
)
234+
235+
model = aLoRAPeftModelForCausalLM.from_pretrained(
236+
base_model,
237+
checkpoint_path,
238+
attn_implementation="flash_attention_2"
239+
if use_flash_attn
240+
else None,
241+
torch_dtype=torch.bfloat16 if use_flash_attn else None,
242+
)
243+
except ImportError as exc:
244+
raise ImportError(
245+
"The alora package is required for this operation. "
246+
"Please install it with pip install alora."
247+
) from exc
248+
else:
249+
model = PeftModel.from_pretrained(
250+
base_model,
251+
checkpoint_path,
252+
attn_implementation="flash_attention_2"
253+
if use_flash_attn
254+
else None,
255+
torch_dtype=torch.bfloat16 if use_flash_attn else None,
256+
)
233257
except (OSError, ValueError) as e:
234258
print("Failed to initialize checkpoint model!")
235259
raise e
@@ -259,10 +283,14 @@ def load(
259283
)
260284

261285
model.to(device)
262-
return cls(model, tokenizer, device)
286+
return cls(model, tokenizer, device, use_alora)
263287

264288
def run(
265-
self, text: str, *, max_new_tokens: int, ret_gen_text_only: bool = False
289+
self,
290+
text: str,
291+
*,
292+
max_new_tokens: int,
293+
ret_gen_text_only: bool = False,
266294
) -> str:
267295
"""Runs inference on an instance of this model.
268296
@@ -279,12 +307,36 @@ def run(
279307
str
280308
Text generation result.
281309
"""
282-
tok_res = self.tokenizer(text, return_tensors="pt")
283-
input_ids = tok_res.input_ids.to(self.device)
284-
285-
peft_outputs = self.peft_model.generate(
286-
input_ids=input_ids, max_new_tokens=max_new_tokens
287-
)
310+
if not self.use_alora:
311+
tok_res = self.tokenizer(text, return_tensors="pt")
312+
input_ids = tok_res.input_ids.to(self.device)
313+
peft_outputs = self.peft_model.generate(
314+
input_ids=input_ids, max_new_tokens=max_new_tokens
315+
)
316+
else: # pass in alora_offsets needed for alora model
317+
# Retrieve invocation string
318+
invocation_string = self.peft_model.peft_config[
319+
self.peft_model.active_adapter
320+
].invocation_string
321+
# Find the invocation string in input
322+
if invocation_string in text:
323+
before, after = text.rsplit(invocation_string, 1)
324+
after = invocation_string + after
325+
else:
326+
raise ValueError(
327+
f"aLoRA invocation string '{invocation_string}' not found in input '{text}'."
328+
)
329+
# Tokenize separately to enforce correct token boundary
330+
before_ids = self.tokenizer(before, return_tensors="pt").input_ids
331+
after_ids = self.tokenizer(invocation_string, return_tensors="pt").input_ids
332+
alora_offsets = [after_ids.shape[1] - 1]
333+
input_ids = torch.cat([before_ids, after_ids], dim=1).to(self.device)
334+
335+
peft_outputs = self.peft_model.generate(
336+
input_ids=input_ids,
337+
max_new_tokens=max_new_tokens,
338+
alora_offsets=alora_offsets,
339+
)
288340
if ret_gen_text_only:
289341
tok_to_decode = peft_outputs[:, input_ids.shape[1] :]
290342
else:
@@ -308,6 +360,11 @@ def main():
308360
help="JSON file to write results to",
309361
default="inference_result.json",
310362
)
363+
parser.add_argument(
364+
"--use_alora",
365+
help="Whether to use alora",
366+
default=False,
367+
)
311368
parser.add_argument(
312369
"--base_model_name_or_path",
313370
help="Override for base model to be used for non-merged models \
@@ -341,6 +398,7 @@ def main():
341398
checkpoint_path=args.model,
342399
base_model_name_or_path=args.base_model_name_or_path,
343400
use_flash_attn=args.use_flash_attn,
401+
use_alora=args.use_alora,
344402
)
345403

346404
# Run inference on the text; if multiple were provided, process them all

0 commit comments

Comments
 (0)