Commit 06ad63a

feat: alora migration documentation and nit fixes
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent 63760db commit 06ad63a

5 files changed (+8, -23 lines)

docs/tuning-techniques.md

Lines changed: 3 additions & 3 deletions
@@ -218,13 +218,13 @@ Activated LoRA (aLoRA) is a new low rank adapter architecture that allows for re
 
 **Note** Often (not always) aLoRA requires a higher rank (r) than LoRA; r=32 can be a good starting point for challenging tasks.
 
-**Installation** The Activated LoRA requirements are an optional install in pyproject.toml (activated-lora)
+**Installation** Native aLoRA support requires a PEFT release that includes `alora_invocation_tokens` support ([PR#2609](https://github.com/huggingface/peft/pull/2609)).
 
 Set `peft_method` to `"alora"`.
 
 You *must* pass in an invocation_string argument. This invocation_string *must be present* in both the training data inputs and the input at test time. A good solution is to set invocation_string = response_template; this ensures that every training input has the invocation_string present. We keep these as separate arguments for flexibility. It is most robust if the invocation_string begins and ends with special tokens.
 
-You can additionally pass any arguments from [aLoraConfig](https://github.com/IBM/activated-lora/blob/fms-hf-tuning/alora/config.py#L35), see the LoRA section for examples.
+You can additionally pass any arguments from `LoraConfig`; see the LoRA section for examples.
 
 Example command to run, here using the [Granite Instruct response template](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct/blob/main/tokenizer_config.json#L188) as the invocation sequence:
 
@@ -306,7 +306,7 @@ class SaveBestModelCallback(TrainerCallback):
 Example inference:
 ```py
 # Load the model
-loaded_model = TunedCausalLM.load(ALORA_MODEL, BASE_MODEL_NAME, use_alora=True)
+loaded_model = TunedCausalLM.load(ALORA_MODEL, BASE_MODEL_NAME)
 
 # Retrieve the invocation string from the model config
 invocation_string = loaded_model.peft_model.peft_config[
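
The documentation change above points users at upstream PEFT rather than the external activated-lora package. As a rough illustration of that native path (not part of this diff): a minimal sketch, assuming PEFT's `LoraConfig` accepts `alora_invocation_tokens` as a list of token ids (per PR#2609), with the model name, template string, and hyperparameters as placeholders only:

```py
# Sketch only: build an upstream peft.LoraConfig for aLoRA by tokenizing the
# invocation string into alora_invocation_tokens. Names and values here are
# illustrative placeholders, not part of this commit.
from transformers import AutoTokenizer
from peft import LoraConfig

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.0-8b-instruct")
invocation_string = "<|start_of_role|>assistant<|end_of_role|>"  # e.g. the response template

alora_config = LoraConfig(
    r=32,  # aLoRA often needs a higher rank than LoRA
    lora_alpha=32,
    lora_dropout=0.05,
    alora_invocation_tokens=tokenizer.encode(invocation_string, add_special_tokens=False),
)
```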

tests/test_sft_trainer.py

Lines changed: 1 addition & 6 deletions
@@ -97,7 +97,7 @@
     load_and_validate_data_config,
 )
 from tuning.data.data_handlers import DataHandler, DataHandlerType
-from tuning.utils.import_utils import is_alora_available, is_fms_accelerate_available
+from tuning.utils.import_utils import is_fms_accelerate_available
 
 MODEL_NAME = MAYKEYE_TINY_LLAMA_CACHED
 
@@ -153,7 +153,6 @@
 
 if hasattr(HFLoraConfig, "alora_invocation_tokens"):
     PEFT_ALORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
-    PEFT_ALORA_ARGS.alora_invocation_tokens = [42]
 else:
     PEFT_ALORA_ARGS = None
 
@@ -745,10 +744,6 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):
     assert "Simply put, the theory of relativity states that" in output_inference
 
 
-@pytest.mark.skipif(
-    not is_alora_available(),
-    reason="Only runs if alora is installed",
-)
 @pytest.mark.parametrize(
     "target_modules,expected",
     target_modules_val_map,

tuning/sft_trainer.py

Lines changed: 0 additions & 2 deletions
@@ -676,8 +676,6 @@ def parse_arguments(parser, json_config=None):
         invocation_string = additional.invocation_string
 
     if peft_method == peft_config.PEFT_METHOD.ALORA.value:
-        if invocation_string is None:
-            raise ValueError("invocation_string is required for aLoRA usage")
         tune_config = lora_config
         setattr(tune_config, "alora_invocation_string", invocation_string)
     elif peft_method == peft_config.PEFT_METHOD.LORA.value:

tuning/utils/config_utils.py

Lines changed: 4 additions & 1 deletion
@@ -61,7 +61,10 @@ def create_tuning_config(peft_method, **kwargs):
         "pt",
         "None",
     ], f"peft config {peft_method} not defined in peft.py"
-    if peft_method in ("alora", "lora"):
+    if peft_method in (
+        peft_config.PEFT_METHOD.ALORA.value,
+        peft_config.PEFT_METHOD.LORA.value,
+    ):
         tune_config = peft_config.LoraConfig()
         update_config(tune_config, **kwargs)
     elif peft_method == "pt":
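
For context on the branch above, a usage sketch (not part of the diff), assuming the enum values resolve to the `"alora"`/`"lora"` strings accepted by the assertion and that the extra kwargs map to `LoraConfig` fields:

```py
# Sketch only: both "alora" and "lora" now take the same branch and return a
# tuning LoraConfig, with extra kwargs applied through update_config.
from tuning.utils import config_utils

alora_tune_config = config_utils.create_tuning_config(
    "alora", r=32, lora_alpha=32, lora_dropout=0.05
)
lora_tune_config = config_utils.create_tuning_config("lora", r=8, lora_alpha=32)
```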

tuning/utils/import_utils.py

Lines changed: 0 additions & 11 deletions
@@ -32,14 +32,3 @@ def is_fms_accelerate_available(
         if not _is_package_available(n):
             return False
     return True
-
-
-def is_alora_available() -> bool:
-    try:
-        # Third Party
-        from peft import LoraConfig  # pylint: disable=import-outside-toplevel
-
-        # Check if LoraConfig has the new Activated LoRA field
-        return hasattr(LoraConfig, "alora_invocation_tokens")
-    except ImportError:
-        return False
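
With the helper removed, callers that need to detect native aLoRA support can do the same check inline, as the updated test module does; a minimal sketch:

```py
# Inline capability check replacing is_alora_available(): native aLoRA support
# is present when upstream peft's LoraConfig exposes alora_invocation_tokens.
from peft import LoraConfig as HFLoraConfig

HAS_NATIVE_ALORA = hasattr(HFLoraConfig, "alora_invocation_tokens")
```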
