diff --git a/build/Dockerfile b/build/Dockerfile index e893398d5..836373b54 100644 --- a/build/Dockerfile +++ b/build/Dockerfile @@ -21,7 +21,6 @@ ARG PYTHON_VERSION=3.12 ARG WHEEL_VERSION="" ## Enable Aimstack or MLflow if requested via ENABLE_AIM/MLFLOW set to "true" ARG ENABLE_AIM=false -ARG ENABLE_ALORA=false ARG ENABLE_MLFLOW=false ARG ENABLE_FMS_ACCELERATION=true ARG ENABLE_SCANNER=false @@ -127,7 +126,6 @@ ARG USER_UID ARG ENABLE_FMS_ACCELERATION ARG ENABLE_AIM ARG ENABLE_MLFLOW -ARG ENABLE_ALORA ARG ENABLE_SCANNER ARG ENABLE_CLEARML @@ -179,10 +177,6 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \ python -m pip install --user "$(head bdist_name)[aim]"; \ fi -RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \ - python -m pip install --user "$(head bdist_name)[activated-lora]"; \ - fi - RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \ python -m pip install --user "$(head bdist_name)[mlflow]"; \ fi diff --git a/build/nvcr.Dockerfile b/build/nvcr.Dockerfile index 7abf4ee98..e277b7f25 100644 --- a/build/nvcr.Dockerfile +++ b/build/nvcr.Dockerfile @@ -30,7 +30,6 @@ ARG SOURCE_DIR=${WORKDIR}/fms-hf-tuning ARG ENABLE_FMS_ACCELERATION=true ARG ENABLE_AIM=true -ARG ENABLE_ALORA=true ARG ENABLE_MLFLOW=true ARG ENABLE_SCANNER=true ARG ENABLE_CLEARML=true @@ -61,9 +60,6 @@ RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \ python -m fms_acceleration.cli install fms_acceleration_odm; \ fi -RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \ - pip install --no-cache-dir ${SOURCE_DIR}[activated-lora]; \ - fi RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \ pip install --no-cache-dir ${SOURCE_DIR}[aim]; \ fi diff --git a/docs/tuning-techniques.md b/docs/tuning-techniques.md index 12f23491c..abbc9ceff 100644 --- a/docs/tuning-techniques.md +++ b/docs/tuning-techniques.md @@ -214,17 +214,17 @@ Activated LoRA (aLoRA) is a new low rank adapter architecture that allows for re [Github](https://github.com/IBM/activated-lora) -**Usage** Usage is very similar to standard LoRA, 
with the key difference that an invocation_string must be specified so that the model knows when to turn on i.e "activate" the adapter weights. The model will scan any input strings (during training or at test time) for this invocation_string, and activate the adapter weights 1 token after the start of the sequence. If there are multiple instances of the invocation_string in the same input, it will activate at the last such instance. +**Usage** Usage is very similar to standard LoRA, with the key difference that an alora_invocation_string must be specified so that the model knows when to turn on i.e "activate" the adapter weights. The model will scan any input strings (during training or at test time) for this alora_invocation_string, and activate the adapter weights 1 token after the start of the sequence. If there are multiple instances of the alora_invocation_string in the same input, it will activate at the last such instance. **Note** Often (not always) aLoRA requires higher rank (r) than LoRA. r=32 can be a good starting point for challenging tasks. -**Installation** The Activated LoRA requirements are an optional install in pyproject.toml (activated-lora) +**Installation** aLoRA support is provided natively by the [HF PEFT](https://github.com/huggingface/peft) library, in versions that include this [patch](https://github.com/huggingface/peft/pull/2609). Set `peft_method` to `"alora"`. -You *must* pass in an invocation_string argument. This invocation_string *must be present* in both training data inputs and the input at test time. A good solution is to set invocation_string = response_template, this will ensure that every training input will have the invocation_string present. We keep these separate arguments for flexibility. It is most robust if the invocation_string begins and ends with special tokens. +You *must* pass in an alora_invocation_string argument. This alora_invocation_string *must be present* in both training data inputs and the input at test time. 
A good solution is to set alora_invocation_string = response_template; this ensures that every training input will have the alora_invocation_string present. We keep these separate arguments for flexibility. It is most robust if the alora_invocation_string begins and ends with special tokens. -You can additionally pass any arguments from [aLoraConfig](https://github.com/IBM/activated-lora/blob/fms-hf-tuning/alora/config.py#L35), see the LoRA section for examples. +You can additionally pass any arguments from `LoraConfig`; see the LoRA section for examples. Example command to run, here using the ([Granite Instruct response template](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct/blob/main/tokenizer_config.json#L188)) as the invocation sequence: @@ -236,9 +236,9 @@ python tuning/sft_trainer.py \ --output_dir $OUTPUT_PATH \ --num_train_epochs 40 \ --per_device_train_batch_size 4 \ ----learning_rate 1e-4 \ +--learning_rate 1e-4 \ --response_template "<|start_of_role|>assistant<|end_of_role|>" \ #this example uses special tokens in the Granite tokenizer, adjust for other models ---invocation_string "<|start_of_role|>assistant<|end_of_role|>" \ +--alora_invocation_string "<|start_of_role|>assistant<|end_of_role|>" \ --dataset_text_field "output" \ --peft_method "alora" \ --r 32 \ @@ -257,7 +257,7 @@ Equally you can pass in a JSON configuration for running tuning. 
See [build doc] "per_device_train_batch_size": 4, "learning_rate": 1e-4, "response_template": "<|start_of_role|>assistant<|end_of_role|>", - "invocation_string": "<|start_of_role|>assistant<|end_of_role|>", + "alora_invocation_string": "<|start_of_role|>assistant<|end_of_role|>", "dataset_text_field": "output", "peft_method": "alora", "r": 32, @@ -306,15 +306,15 @@ class SaveBestModelCallback(TrainerCallback): Example inference: ```py # Load the model -loaded_model = TunedCausalLM.load(ALORA_MODEL, BASE_MODEL_NAME, use_alora=True) +loaded_model = TunedCausalLM.load(ALORA_MODEL, BASE_MODEL_NAME) # Retrieve the invocation string from the model config -invocation_string = loaded_model.peft_model.peft_config[ +alora_invocation_string = loaded_model.peft_model.peft_config[ loaded_model.peft_model.active_adapter -].invocation_string +].alora_invocation_string # In this case, we have the invocation string at the end of the input -input_string = "Simply put, the theory of relativity states that \n" + invocation_string +input_string = "Simply put, the theory of relativity states that \n" + alora_invocation_string # Run inference on the text output_inference = loaded_model.run( diff --git a/pyproject.toml b/pyproject.toml index 704c2184a..10fc72b4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "tokenizers<=0.22", "tqdm>=4.66.2,<5.0", "trl>=0.19.1,<0.20.0", -"peft>=0.17.0,<0.18.0", +"peft @ git+https://github.com/huggingface/peft.git@293aea5df6db240856a77f89955d1a89ce38b50d", "datasets>=4.0.0,<5.0.0", "simpleeval>=0.9.13,<2.0", "pillow>=11.0.0,<12.0", @@ -52,7 +52,6 @@ fms-accel = ["fms-acceleration>=0.6.2"] gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"] mamba = ["mamba_ssm[causal-conv1d]>=2.0.0,<3.0.0"] scanner-dev = ["HFResourceScanner>=0.1.0"] -activated-lora = ["alora>=0.3.0"] [tool.setuptools.packages.find] exclude = ["tests", "tests.*"] diff --git a/scripts/run_inference.py b/scripts/run_inference.py index a4d1d06b7..691107ba4 
100644 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -138,11 +138,10 @@ def __exit__(self, exc_type, exc_value, exc_tb): ### Funcs for loading and running models class TunedCausalLM: - def __init__(self, model, tokenizer, device, use_alora=False): + def __init__(self, model, tokenizer, device): self.peft_model = model self.tokenizer = tokenizer self.device = device - self.use_alora = use_alora @classmethod def load( @@ -150,7 +149,6 @@ def load( checkpoint_path: str, base_model_name_or_path: str = None, use_flash_attn: bool = False, - use_alora: bool = False, ) -> "TunedCausalLM": """Loads an instance of this model. @@ -224,36 +222,14 @@ def load( tokenizer_and_embedding_resize( {}, tokenizer=tokenizer, model=base_model ) - if use_alora: - # Third Party - try: - # Third Party - from alora.peft_model_alora import ( # pylint: disable=import-outside-toplevel - aLoRAPeftModelForCausalLM, - ) - - model = aLoRAPeftModelForCausalLM.from_pretrained( - base_model, - checkpoint_path, - attn_implementation="flash_attention_2" - if use_flash_attn - else None, - torch_dtype=torch.bfloat16 if use_flash_attn else None, - ) - except ImportError as exc: - raise ImportError( - "The alora package is required for this operation. " - "Please install it with pip install alora." 
- ) from exc - else: - model = PeftModel.from_pretrained( - base_model, - checkpoint_path, - attn_implementation="flash_attention_2" - if use_flash_attn - else None, - torch_dtype=torch.bfloat16 if use_flash_attn else None, - ) + model = PeftModel.from_pretrained( + base_model, + checkpoint_path, + attn_implementation="flash_attention_2" + if use_flash_attn + else None, + torch_dtype=torch.bfloat16 if use_flash_attn else None, + ) except (OSError, ValueError) as e: print("Failed to initialize checkpoint model!") raise e @@ -283,7 +259,7 @@ def load( ) model.to(device) - return cls(model, tokenizer, device, use_alora) + return cls(model, tokenizer, device) def run( self, @@ -307,42 +283,16 @@ def run( str Text generation result. """ - if not self.use_alora: - tok_res = self.tokenizer(text, return_tensors="pt") - input_ids = tok_res.input_ids.to(self.device) - peft_outputs = self.peft_model.generate( - input_ids=input_ids, max_new_tokens=max_new_tokens - ) - else: # pass in alora_offsets needed for alora model - # Retrieve invocation string - invocation_string = self.peft_model.peft_config[ - self.peft_model.active_adapter - ].invocation_string - # Find the invocation string in input - if invocation_string in text: - before, after = text.rsplit(invocation_string, 1) - after = invocation_string + after - else: - raise ValueError( - f"aLoRA invocation string '{invocation_string}' not found in input '{text}'." 
- ) - # Tokenize separately to enforce correct token boundary - before_ids = self.tokenizer(before, return_tensors="pt").input_ids - after_ids = self.tokenizer(invocation_string, return_tensors="pt").input_ids - alora_offsets = [after_ids.shape[1] - 1] - input_ids = torch.cat([before_ids, after_ids], dim=1).to(self.device) - - peft_outputs = self.peft_model.generate( - input_ids=input_ids, - max_new_tokens=max_new_tokens, - alora_offsets=alora_offsets, - ) - if ret_gen_text_only: - tok_to_decode = peft_outputs[:, input_ids.shape[1] :] - else: - tok_to_decode = peft_outputs + tok_res = self.tokenizer(text, return_tensors="pt") + input_ids = tok_res.input_ids.to(self.device) + peft_outputs = self.peft_model.generate( + input_ids=input_ids, max_new_tokens=max_new_tokens + ) + tok_to_decode = ( + peft_outputs[:, input_ids.shape[1] :] if ret_gen_text_only else peft_outputs + ) decoded_result = self.tokenizer.batch_decode( - tok_to_decode, skip_special_tokens=False + tok_to_decode, skip_special_tokens=ret_gen_text_only )[0] return decoded_result @@ -360,11 +310,6 @@ def main(): help="JSON file to write results to", default="inference_result.json", ) - parser.add_argument( - "--use_alora", - help="Whether to use alora", - default=False, - ) parser.add_argument( "--base_model_name_or_path", help="Override for base model to be used for non-merged models \ @@ -398,7 +343,6 @@ def main(): checkpoint_path=args.model, base_model_name_or_path=args.base_model_name_or_path, use_flash_attn=args.use_flash_attn, - use_alora=args.use_alora, ) # Run inference on the text; if multiple were provided, process them all diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index f9f1d3810..9cf8191aa 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -27,6 +27,7 @@ # Third Party from datasets.exceptions import DatasetGenerationError, DatasetNotFoundError +from peft import LoraConfig as HFLoraConfig from transformers.trainer_callback import TrainerCallback 
import pytest import torch @@ -96,7 +97,7 @@ load_and_validate_data_config, ) from tuning.data.data_handlers import DataHandler, DataHandlerType -from tuning.utils.import_utils import is_alora_available, is_fms_accelerate_available +from tuning.utils.import_utils import is_fms_accelerate_available MODEL_NAME = MAYKEYE_TINY_LLAMA_CACHED @@ -147,15 +148,13 @@ ) PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05) -try: # Optional package - # Third Party - from alora.config import aLoraConfig - PEFT_ALORA_ARGS = aLoraConfig( - r=8, lora_alpha=32, lora_dropout=0.05, invocation_string="Label:" - ) -except ImportError: - pass +INVOCATION_STR = "Label:" + +if hasattr(HFLoraConfig, "alora_invocation_tokens"): + PEFT_ALORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05) +else: + PEFT_ALORA_ARGS = None @pytest.mark.parametrize( @@ -745,22 +744,25 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected): assert "Simply put, the theory of relativity states that" in output_inference -@pytest.mark.skipif( - not is_alora_available(), - reason="Only runs if alora is installed", -) @pytest.mark.parametrize( "target_modules,expected", target_modules_val_map, ids=["default", "custom_target_modules", "all_linear_target_modules"], ) def test_run_causallm_alora_and_inference(request, target_modules, expected): - """Check if we can bootstrap and alora tune causallm models""" + """Check if we can bootstrap and alora tune causallm models via PEFT-native aLoRA.""" with tempfile.TemporaryDirectory() as tempdir: train_args = copy.deepcopy(TRAIN_ALORA_ARGS) train_args.output_dir = tempdir base_alora_args = copy.deepcopy(PEFT_ALORA_ARGS) + tokenizer = transformers.AutoTokenizer.from_pretrained( + MODEL_NAME, use_fast=True, legacy=True + ) + base_alora_args.alora_invocation_tokens = tokenizer.encode( + INVOCATION_STR, add_special_tokens=False + ) + if "default" not in request._pyfuncitem.callspec.id: 
base_alora_args.target_modules = target_modules @@ -775,16 +777,16 @@ def test_run_causallm_alora_and_inference(request, target_modules, expected): for module in expected: assert module in adapter_config.get("target_modules") - # Load the model - loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME, use_alora=True) - invocation_string = loaded_model.peft_model.peft_config[ - loaded_model.peft_model.active_adapter - ].invocation_string - # Run inference on the text - output_inference = loaded_model.run( - "Simply put, the theory of relativity states that \n" + invocation_string, - max_new_tokens=50, - ) + # aLoRA-specific: saved adapter config must contain the exact tokens + expected_tokens = tokenizer.encode(INVOCATION_STR, add_special_tokens=False) + assert adapter_config.get("alora_invocation_tokens") == expected_tokens + + # Load tuned model (no special aLoRA wrapper/flag needed) + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) + + # Inference must include the invocation string so aLoRA activates + prompt = "Simply put, the theory of relativity states that \n" + INVOCATION_STR + output_inference = loaded_model.run(prompt, max_new_tokens=50) assert len(output_inference) > 0 assert "Simply put, the theory of relativity states that \n" in output_inference diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 330606946..d50c16a58 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -161,19 +161,6 @@ def train( ] = resume_from_checkpoint odm_config = ODMConfig(odm=ODM(**_dataconfig.dataprocessor.odm)) - USE_ALORA = False - try: - # Third Party - from alora.config import aLoraConfig # pylint: disable=import-outside-toplevel - from alora.peft_model_alora import ( # pylint: disable=import-outside-toplevel - aLoRAPeftModelForCausalLM, - ) - - if isinstance(peft_config, aLoraConfig): - USE_ALORA = True - except ImportError: - pass - # Validate parameters if (not isinstance(model_args.model_name_or_path, str)) or ( 
model_args.model_name_or_path == "" @@ -375,6 +362,20 @@ def train( # Calculate and save additional metrics to track later. additional_metrics["model_load_time"] = time.time() - model_load_time + # Convert legacy aLoRA string → token IDs (PEFT-native aLoRA) + if peft_config is not None and hasattr(peft_config, "alora_invocation_string"): + inv_str = getattr(peft_config, "alora_invocation_string") + if not inv_str: + raise ValueError( + "`--alora_invocation_string` is required when using --peft_method alora." + ) + alora_tokens = tokenizer.encode(inv_str, add_special_tokens=False) + if not alora_tokens: + raise ValueError( + "`--alora_invocation_string` produced no tokens; check your tokenizer/template." + ) + setattr(peft_config, "alora_invocation_tokens", alora_tokens) + peft_config = get_hf_peft_config( task_type, peft_config, @@ -463,21 +464,6 @@ def train( } training_args = SFTConfig(**transformer_kwargs, **additional_args) - # activated LoRA - if USE_ALORA: - response_token_ids = ( - tokenizer( - peft_config.invocation_string, - return_tensors="pt", - add_special_tokens=False, - ) - )["input_ids"] - model = aLoRAPeftModelForCausalLM( - model, peft_config, response_token_ids=response_token_ids - ) - - peft_config = None - if train_args.enable_reduce_loss_sum: TrainerClass = SumLossSFTTrainer else: @@ -602,7 +588,7 @@ def get_parser(): to the tuning run in the tracker. e.g. 
\'{"gpu":"A100-80G"}\'', ) parser.add_argument( - "--invocation_string", + "--alora_invocation_string", type=str, default=None, help="Pass a invocation string that will be used to activate the aLoRA.\ @@ -665,7 +651,7 @@ def parse_arguments(parser, json_config=None): peft_method = json_config.get("peft_method") exp_metadata = json_config.get("exp_metadata") quantization_method = json_config.get("quantization_method") - invocation_string = json_config.get("invocation_string") + alora_invocation_string = json_config.get("alora_invocation_string") else: ( model_args, @@ -687,25 +673,11 @@ def parse_arguments(parser, json_config=None): peft_method = additional.peft_method exp_metadata = additional.exp_metadata quantization_method = additional.quantization_method - invocation_string = additional.invocation_string + alora_invocation_string = additional.alora_invocation_string if peft_method == peft_config.PEFT_METHOD.ALORA.value: - if invocation_string is None: - raise ValueError("invocation_string is not passed required for aLoRA usage") - try: - # Third Party - from alora.config import ( # pylint: disable=import-outside-toplevel - aLoraConfig, - ) - - tune_config = aLoraConfig( - **vars(lora_config), invocation_string=invocation_string - ) - except ImportError as exc: - raise ImportError( - "The alora package is required for this operation. " - "Please install it with pip install alora." 
- ) from exc + tune_config = lora_config + setattr(tune_config, "alora_invocation_string", alora_invocation_string) elif peft_method == peft_config.PEFT_METHOD.LORA.value: tune_config = lora_config elif peft_method == peft_config.PEFT_METHOD.PT.value: @@ -863,7 +835,7 @@ def main(): ) sys.exit(INTERNAL_ERROR_EXIT_CODE) - if isinstance(tune_config, LoraConfig): # aLoraConfig subclasses LoraConfig + if isinstance(tune_config, LoraConfig): try: if training_args.save_model_dir: # Write number of added tokens to artifacts diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index 45fda027f..806e593dc 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -41,7 +41,7 @@ def update_config(config, **kwargs): if hasattr(config, param_name): setattr(config, param_name, v) else: - # In case of specialized config we can warm user + # In case of specialized config we can warn user print(f"Warning: {config_name} does not accept parameter: {k}") @@ -50,7 +50,7 @@ def create_tuning_config(peft_method, **kwargs): Args: peft_method: str lora, pt or None - kawrgs: parameters to initialize library configs with + kwargs: parameters to initialize library configs with Return: peft_config.LoraConfig | peft_config.PromptTuningConfig | None """ @@ -61,22 +61,10 @@ def create_tuning_config(peft_method, **kwargs): "pt", "None", ], f"peft config {peft_method} not defined in peft.py" - if peft_method == "alora": - try: - # Third Party - from alora.config import ( # pylint: disable=import-outside-toplevel - aLoraConfig, - ) - - tune_config = aLoraConfig() - update_config(tune_config, **kwargs) - except ImportError as exc: - raise ImportError( - "alora package is required for this operation. " - "Please install it with pip install alora." 
- ) from exc - - elif peft_method == "lora": + if peft_method in ( + peft_config.PEFT_METHOD.ALORA.value, + peft_config.PEFT_METHOD.LORA.value, + ): tune_config = peft_config.LoraConfig() update_config(tune_config, **kwargs) elif peft_method == "pt": @@ -95,30 +83,23 @@ def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path): tokenizer_name_or_path: str Return: HF PEFT config or None """ - USE_ALORA = False - try: - # Third Party - from alora.config import aLoraConfig # pylint: disable=import-outside-toplevel - - if isinstance(tuning_config, aLoraConfig): - USE_ALORA = True - except ImportError: - pass - if USE_ALORA: - alora_config = tuning_config - if alora_config.target_modules == ["all-linear"]: - alora_config.target_modules = "all-linear" - alora_config.task_type = task_type - hf_peft_config = alora_config - elif isinstance(tuning_config, peft_config.LoraConfig): - if getattr(tuning_config, "target_modules") == ["all-linear"]: - setattr(tuning_config, "target_modules", "all-linear") - - if getattr(tuning_config, "task_type") is None: - setattr(tuning_config, "task_type", task_type) - - hf_peft_config = tuning_config - elif isinstance(tuning_config, peft_config.PromptTuningConfig): + if isinstance(tuning_config, peft_config.LoraConfig): + if getattr(tuning_config, "target_modules", None) == ["all-linear"]: + tuning_config.target_modules = "all-linear" + + if getattr(tuning_config, "task_type", None) in (None, ""): + tuning_config.task_type = task_type + + if getattr(tuning_config, "alora_invocation_tokens", None): + if not tuning_config.alora_invocation_tokens: + raise ValueError("alora_invocation_tokens is set but empty.") + tuning_config.task_type = "CAUSAL_LM" + + if hasattr(tuning_config, "alora_invocation_string"): + delattr(tuning_config, "alora_invocation_string") + return tuning_config + + if isinstance(tuning_config, peft_config.PromptTuningConfig): hf_peft_config = HFPromptTuningConfig( task_type=task_type, 
tokenizer_name_or_path=tokenizer_name_or_path, diff --git a/tuning/utils/import_utils.py b/tuning/utils/import_utils.py index e01e55630..36dd606c6 100644 --- a/tuning/utils/import_utils.py +++ b/tuning/utils/import_utils.py @@ -32,11 +32,3 @@ def is_fms_accelerate_available( if not _is_package_available(n): return False return True - - -def is_alora_available(package_name: str = "alora"): - names = [package_name] - for n in names: - if not _is_package_available(n): - return False - return True