
Commit 02fd50c

feat: alora migration documentation and nit fixes
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent 63760db commit 02fd50c


6 files changed: +70 -53 lines


docs/tuning-techniques.md

Lines changed: 11 additions & 11 deletions
@@ -214,17 +214,17 @@ Activated LoRA (aLoRA) is a new low rank adapter architecture that allows for re
 
 [Github](https://github.com/IBM/activated-lora)
 
-**Usage** Usage is very similar to standard LoRA, with the key difference that an invocation_string must be specified so that the model knows when to turn on i.e "activate" the adapter weights. The model will scan any input strings (during training or at test time) for this invocation_string, and activate the adapter weights 1 token after the start of the sequence. If there are multiple instances of the invocation_string in the same input, it will activate at the last such instance.
+**Usage** Usage is very similar to standard LoRA, with the key difference that an alora_invocation_string must be specified so that the model knows when to turn on, i.e. "activate", the adapter weights. The model will scan any input strings (during training or at test time) for this alora_invocation_string, and activate the adapter weights 1 token after the start of the sequence. If there are multiple instances of the alora_invocation_string in the same input, it will activate at the last such instance.
 
 **Note** Often (not always) aLoRA requires higher rank (r) than LoRA. r=32 can be a good starting point for challenging tasks.
 
-**Installation** The Activated LoRA requirements are an optional install in pyproject.toml (activated-lora)
+**Installation** aLoRA support is provided via the [HF PEFT](https://github.com/huggingface/peft) library, at versions that include this [patch](https://github.com/huggingface/peft/pull/2609).
 
 Set `peft_method` to `"alora"`.
 
-You *must* pass in an invocation_string argument. This invocation_string *must be present* in both training data inputs and the input at test time. A good solution is to set invocation_string = response_template, this will ensure that every training input will have the invocation_string present. We keep these separate arguments for flexibility. It is most robust if the invocation_string begins and ends with special tokens.
+You *must* pass in an alora_invocation_string argument. This alora_invocation_string *must be present* in both the training data inputs and the input at test time. A good solution is to set alora_invocation_string = response_template; this ensures that every training input will have the alora_invocation_string present. We keep these as separate arguments for flexibility. It is most robust if the alora_invocation_string begins and ends with special tokens.
 
-You can additionally pass any arguments from [aLoraConfig](https://github.com/IBM/activated-lora/blob/fms-hf-tuning/alora/config.py#L35), see the LoRA section for examples.
+You can additionally pass any arguments from `LoraConfig`; see the LoRA section for examples.
 
 Example command to run, here using the ([Granite Instruct response template](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct/blob/main/tokenizer_config.json#L188)) as the invocation sequence:
 

@@ -236,9 +236,9 @@ python tuning/sft_trainer.py \
 --output_dir $OUTPUT_PATH \
 --num_train_epochs 40 \
 --per_device_train_batch_size 4 \
----learning_rate 1e-4 \
+--learning_rate 1e-4 \
 --response_template "<|start_of_role|>assistant<|end_of_role|>" \ #this example uses special tokens in the Granite tokenizer, adjust for other models
---invocation_string "<|start_of_role|>assistant<|end_of_role|>" \
+--alora_invocation_string "<|start_of_role|>assistant<|end_of_role|>" \
 --dataset_text_field "output" \
 --peft_method "alora" \
 --r 32 \
@@ -257,7 +257,7 @@ Equally you can pass in a JSON configuration for running tuning. See [build doc]
     "per_device_train_batch_size": 4,
     "learning_rate": 1e-4,
     "response_template": "<|start_of_role|>assistant<|end_of_role|>",
-    "invocation_string": "<|start_of_role|>assistant<|end_of_role|>",
+    "alora_invocation_string": "<|start_of_role|>assistant<|end_of_role|>",
     "dataset_text_field": "output",
     "peft_method": "alora",
     "r": 32,
@@ -306,15 +306,15 @@ class SaveBestModelCallback(TrainerCallback):
 Example inference:
 ```py
 # Load the model
-loaded_model = TunedCausalLM.load(ALORA_MODEL, BASE_MODEL_NAME, use_alora=True)
+loaded_model = TunedCausalLM.load(ALORA_MODEL, BASE_MODEL_NAME)
 
 # Retrieve the invocation string from the model config
-invocation_string = loaded_model.peft_model.peft_config[
+alora_invocation_string = loaded_model.peft_model.peft_config[
     loaded_model.peft_model.active_adapter
-].invocation_string
+].alora_invocation_string
 
 # In this case, we have the invocation string at the end of the input
-input_string = "Simply put, the theory of relativity states that \n" + invocation_string
+input_string = "Simply put, the theory of relativity states that \n" + alora_invocation_string
 
 # Run inference on the text
 output_inference = loaded_model.run(
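
To make the "activate at the last such instance" rule above concrete, here is a small illustrative sketch of locating the last occurrence of the invocation token ids in a tokenized input. This is not the PEFT implementation; the model name is a placeholder and the exact activation boundary is approximate.

```py
# Illustrative only: shows the "activate at the last instance" rule described above.
# Assumes a tokenizer compatible with the adapter; the model name is a placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.0-8b-instruct")

alora_invocation_string = "<|start_of_role|>assistant<|end_of_role|>"
invocation_tokens = tokenizer.encode(alora_invocation_string, add_special_tokens=False)

prompt = "Simply put, the theory of relativity states that \n" + alora_invocation_string
input_ids = tokenizer.encode(prompt, add_special_tokens=False)

# Find the start index of the last occurrence of the invocation token ids.
last_start = max(
    (
        i
        for i in range(len(input_ids) - len(invocation_tokens) + 1)
        if input_ids[i : i + len(invocation_tokens)] == invocation_tokens
    ),
    default=None,
)

# The adapter weights would apply from roughly this position onward;
# earlier positions are processed with the base model weights only.
print(last_start)
```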

tests/test_sft_trainer.py

Lines changed: 1 addition & 6 deletions
@@ -97,7 +97,7 @@
     load_and_validate_data_config,
 )
 from tuning.data.data_handlers import DataHandler, DataHandlerType
-from tuning.utils.import_utils import is_alora_available, is_fms_accelerate_available
+from tuning.utils.import_utils import is_fms_accelerate_available
 
 MODEL_NAME = MAYKEYE_TINY_LLAMA_CACHED
 
@@ -153,7 +153,6 @@
 
 if hasattr(HFLoraConfig, "alora_invocation_tokens"):
     PEFT_ALORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
-    PEFT_ALORA_ARGS.alora_invocation_tokens = [42]
 else:
     PEFT_ALORA_ARGS = None
 
@@ -745,10 +744,6 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):
     assert "Simply put, the theory of relativity states that" in output_inference
 
 
-@pytest.mark.skipif(
-    not is_alora_available(),
-    reason="Only runs if alora is installed",
-)
 @pytest.mark.parametrize(
     "target_modules,expected",
     target_modules_val_map,
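
With `is_alora_available()` gone, the `hasattr(HFLoraConfig, "alora_invocation_tokens")` check above is the remaining feature gate. One way a test could still be skipped on PEFT versions without aLoRA support, sketched here as an assumption rather than what the suite actually does, is to key off the `PEFT_ALORA_ARGS is None` sentinel:

```py
# Illustrative sketch only: reuse the PEFT_ALORA_ARGS sentinel defined above to
# skip aLoRA tests when the installed peft has no `alora_invocation_tokens` field.
# The test name and assertion are hypothetical.
import pytest

@pytest.mark.skipif(
    PEFT_ALORA_ARGS is None,
    reason="Installed peft version does not support Activated LoRA",
)
def test_alora_config_has_expected_rank():
    assert PEFT_ALORA_ARGS.r == 8
```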

tuning/config/peft_config.py

Lines changed: 26 additions & 0 deletions
@@ -55,6 +55,32 @@ class LoraConfig(HFLoraConfig):
     lora_alpha: int = 32
     lora_dropout: float = 0.05
 
+    # Activated LoRA fields (optional)
+    alora_invocation_string: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Human readable invocation string for aLoRA. If set, the training code "
+                "will tokenize this with the model tokenizer and persist the resulting "
+                "token ids in `alora_invocation_tokens` so the adapter can be activated "
+                "at inference time. This field is optional; users may instead set the "
+                "`alora_invocation_tokens` directly (list of ints)."
+            )
+        },
+    )
+
+    alora_invocation_tokens: Optional[List[int]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Token ids for the aLoRA invocation sequence. If provided, these will be "
+                "used directly and will take precedence over `alora_invocation_string`. "
+                "If not provided but `alora_invocation_string` is, the training flow will "
+                "tokenize the string and populate this field before training continues."
+            )
+        },
+    )
+
     # HACK: The following list of arguments listed below
     # is a fix which reduces the field annotation from
     # Optional[List[str], str] type to Optional[List[str]] type
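
Taken together, the two fields let users supply either a human-readable string or raw token ids. Below is a minimal sketch of the precedence the help strings describe (explicit token ids win, otherwise the string is tokenized); the function name and the pass-in tokenizer are illustrative, not part of the library API.

```py
# Minimal sketch of the precedence documented in the help strings above:
# explicit `alora_invocation_tokens` take precedence; otherwise
# `alora_invocation_string` is tokenized. Names are illustrative.
from typing import List, Optional


def resolve_invocation_tokens(
    alora_invocation_string: Optional[str],
    alora_invocation_tokens: Optional[List[int]],
    tokenizer,
) -> List[int]:
    if alora_invocation_tokens:
        # Token ids provided directly: use them as-is.
        return alora_invocation_tokens
    if alora_invocation_string:
        # Fall back to tokenizing the human-readable string.
        tokens = tokenizer.encode(alora_invocation_string, add_special_tokens=False)
        if not tokens:
            raise ValueError("`alora_invocation_string` produced no tokens")
        return tokens
    raise ValueError(
        "Either `alora_invocation_string` or `alora_invocation_tokens` is required"
    )
```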

tuning/sft_trainer.py

Lines changed: 28 additions & 24 deletions
@@ -363,18 +363,24 @@ def train(
     additional_metrics["model_load_time"] = time.time() - model_load_time
 
     # Convert legacy aLoRA string → token IDs (PEFT-native aLoRA)
-    if peft_config is not None and hasattr(peft_config, "alora_invocation_string"):
-        inv_str = getattr(peft_config, "alora_invocation_string")
-        if not inv_str:
-            raise ValueError(
-                "`--invocation_string` is required when using --peft_method alora."
-            )
-        alora_tokens = tokenizer.encode(inv_str, add_special_tokens=False)
-        if not alora_tokens:
-            raise ValueError(
-                "`--invocation_string` produced no tokens; check your tokenizer/template."
-            )
-        setattr(peft_config, "alora_invocation_tokens", alora_tokens)
+    if peft_config is not None:
+        inv_str = getattr(peft_config, "alora_invocation_string", None)
+        has_string = isinstance(inv_str, str) and inv_str.strip() != ""
+        has_tokens = hasattr(peft_config, "alora_invocation_tokens") and bool(
+            getattr(peft_config, "alora_invocation_tokens")
+        )
+
+        if has_string and not has_tokens:
+            alora_tokens = tokenizer.encode(inv_str, add_special_tokens=False)
+            if not alora_tokens:
+                raise ValueError(
+                    "`alora_invocation_string` produced no tokens; check your tokenizer/template."
+                )
+            setattr(peft_config, "alora_invocation_tokens", alora_tokens)
+
+        elif not has_tokens:
+            # Only raise if neither tokens nor string is present
+            raise ValueError("`alora_invocation_string` is required when using aLoRA.")
 
     peft_config = get_hf_peft_config(
         task_type,
@@ -587,13 +593,13 @@ def get_parser():
         help='Pass a json string representing K:V pairs to be associated\
             to the tuning run in the tracker. e.g. \'{"gpu":"A100-80G"}\'',
     )
-    parser.add_argument(
-        "--invocation_string",
-        type=str,
-        default=None,
-        help="Pass a invocation string that will be used to activate the aLoRA.\
-            This needs to be present in each training data row.",
-    )
+    # parser.add_argument(
+    #     "--alora_invocation_string",
+    #     type=str,
+    #     default=None,
+    #     help="Pass an invocation string that will be used to activate the aLoRA.\
+    #         This needs to be present in each training data row.",
+    # )
     return parser
 
 
@@ -651,7 +657,7 @@ def parse_arguments(parser, json_config=None):
         peft_method = json_config.get("peft_method")
         exp_metadata = json_config.get("exp_metadata")
         quantization_method = json_config.get("quantization_method")
-        invocation_string = json_config.get("invocation_string")
+        # alora_invocation_string = json_config.get("alora_invocation_string")
     else:
         (
             model_args,
@@ -673,13 +679,11 @@
         peft_method = additional.peft_method
         exp_metadata = additional.exp_metadata
         quantization_method = additional.quantization_method
-        invocation_string = additional.invocation_string
+        # alora_invocation_string = additional.alora_invocation_string
 
     if peft_method == peft_config.PEFT_METHOD.ALORA.value:
-        if invocation_string is None:
-            raise ValueError("invocation_string is required for aLoRA usage")
         tune_config = lora_config
-        setattr(tune_config, "alora_invocation_string", invocation_string)
+        # setattr(tune_config, "alora_invocation_string", alora_invocation_string)
     elif peft_method == peft_config.PEFT_METHOD.LORA.value:
         tune_config = lora_config
     elif peft_method == peft_config.PEFT_METHOD.PT.value:
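
Since the dedicated `--invocation_string` argparse flag is commented out here, the value presumably reaches training through the `LoraConfig` dataclass fields instead (the docs' CLI example still passes `--alora_invocation_string`). A rough sketch of that pattern with `HfArgumentParser`, using a simplified stand-in dataclass rather than the real `tuning.config.peft_config.LoraConfig`:

```py
# Rough sketch (assumption): dataclass fields parsed by HfArgumentParser surface as
# CLI flags, which would keep `--alora_invocation_string` usable after the explicit
# argparse argument was removed. MiniLoraConfig is a simplified stand-in.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class MiniLoraConfig:
    r: int = 8
    lora_alpha: int = 32
    alora_invocation_string: Optional[str] = field(default=None)


parser = HfArgumentParser(MiniLoraConfig)
(cfg,) = parser.parse_args_into_dataclasses(
    ["--r", "32", "--alora_invocation_string", "<|start_of_role|>assistant<|end_of_role|>"]
)
print(cfg.r, cfg.alora_invocation_string)
```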

tuning/utils/config_utils.py

Lines changed: 4 additions & 1 deletion
@@ -61,7 +61,10 @@ def create_tuning_config(peft_method, **kwargs):
         "pt",
         "None",
     ], f"peft config {peft_method} not defined in peft.py"
-    if peft_method in ("alora", "lora"):
+    if peft_method in (
+        peft_config.PEFT_METHOD.ALORA.value,
+        peft_config.PEFT_METHOD.LORA.value,
+    ):
         tune_config = peft_config.LoraConfig()
         update_config(tune_config, **kwargs)
     elif peft_method == "pt":

tuning/utils/import_utils.py

Lines changed: 0 additions & 11 deletions
@@ -32,14 +32,3 @@ def is_fms_accelerate_available(
         if not _is_package_available(n):
             return False
     return True
-
-
-def is_alora_available() -> bool:
-    try:
-        # Third Party
-        from peft import LoraConfig  # pylint: disable=import-outside-toplevel
-
-        # Check if LoraConfig has the new Activated LoRA field
-        return hasattr(LoraConfig, "alora_invocation_tokens")
-    except ImportError:
-        return False
