6 changes: 0 additions & 6 deletions build/Dockerfile
@@ -21,7 +21,6 @@ ARG PYTHON_VERSION=3.12
 ARG WHEEL_VERSION=""
 ## Enable Aimstack or MLflow if requested via ENABLE_AIM/MLFLOW set to "true"
 ARG ENABLE_AIM=false
-ARG ENABLE_ALORA=false
 ARG ENABLE_MLFLOW=false
 ARG ENABLE_FMS_ACCELERATION=true
 ARG ENABLE_SCANNER=false
@@ -127,7 +126,6 @@ ARG USER_UID
 ARG ENABLE_FMS_ACCELERATION
 ARG ENABLE_AIM
 ARG ENABLE_MLFLOW
-ARG ENABLE_ALORA
 ARG ENABLE_SCANNER
 ARG ENABLE_CLEARML

@@ -179,10 +177,6 @@ RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
     python -m pip install --user "$(head bdist_name)[aim]"; \
     fi

-RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \
-    python -m pip install --user "$(head bdist_name)[activated-lora]"; \
-    fi
-
 RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
     python -m pip install --user "$(head bdist_name)[mlflow]"; \
     fi
4 changes: 0 additions & 4 deletions build/nvcr.Dockerfile
@@ -30,7 +30,6 @@ ARG SOURCE_DIR=${WORKDIR}/fms-hf-tuning

 ARG ENABLE_FMS_ACCELERATION=true
 ARG ENABLE_AIM=true
-ARG ENABLE_ALORA=true
 ARG ENABLE_MLFLOW=true
 ARG ENABLE_SCANNER=true
 ARG ENABLE_CLEARML=true
@@ -61,9 +60,6 @@ RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
     python -m fms_acceleration.cli install fms_acceleration_odm; \
     fi

-RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \
-    pip install --no-cache-dir ${SOURCE_DIR}[activated-lora]; \
-    fi
 RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
     pip install --no-cache-dir ${SOURCE_DIR}[aim]; \
     fi
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
     "tokenizers<=0.22",
     "tqdm>=4.66.2,<5.0",
     "trl>=0.19.1,<0.20.0",
-    "peft>=0.17.0,<0.18.0",
+    "peft @ git+https://github.com/huggingface/peft.git@293aea5df6db240856a77f89955d1a89ce38b50d",
     "datasets>=4.0.0,<5.0.0",
     "simpleeval>=0.9.13,<2.0",
     "pillow>=11.0.0,<12.0",
@@ -52,7 +52,6 @@ fms-accel = ["fms-acceleration>=0.6.2"]
 gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
 mamba = ["mamba_ssm[causal-conv1d]>=2.0.0,<3.0.0"]
 scanner-dev = ["HFResourceScanner>=0.1.0"]
-activated-lora = ["alora>=0.3.0"]

 [tool.setuptools.packages.find]
 exclude = ["tests", "tests.*"]
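Because native aLoRA support here rides on an unreleased peft commit rather than a tagged release, downstream code can feature-detect it instead of hard-pinning. A minimal sketch, assuming only that the pinned peft build exposes the new field on `LoraConfig` (the same `hasattr` guard the test suite below uses):

```python
# Minimal feature-detection sketch; assumes a peft build with native aLoRA
# exposes `alora_invocation_tokens` on LoraConfig, mirroring the hasattr
# guard in tests/test_sft_trainer.py.
from peft import LoraConfig

if not hasattr(LoraConfig, "alora_invocation_tokens"):
    raise RuntimeError(
        "This peft build predates native aLoRA; "
        "install the commit pinned in pyproject.toml."
    )
```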
94 changes: 19 additions & 75 deletions scripts/run_inference.py
@@ -138,19 +138,17 @@ def __exit__(self, exc_type, exc_value, exc_tb):

 ### Funcs for loading and running models
 class TunedCausalLM:
-    def __init__(self, model, tokenizer, device, use_alora=False):
+    def __init__(self, model, tokenizer, device):
         self.peft_model = model
         self.tokenizer = tokenizer
         self.device = device
-        self.use_alora = use_alora

     @classmethod
     def load(
         cls,
         checkpoint_path: str,
         base_model_name_or_path: str = None,
         use_flash_attn: bool = False,
-        use_alora: bool = False,
     ) -> "TunedCausalLM":
         """Loads an instance of this model.

@@ -224,36 +222,14 @@ def load(
             tokenizer_and_embedding_resize(
                 {}, tokenizer=tokenizer, model=base_model
             )
-            if use_alora:
-                # Third Party
-                try:
-                    # Third Party
-                    from alora.peft_model_alora import (  # pylint: disable=import-outside-toplevel
-                        aLoRAPeftModelForCausalLM,
-                    )
-
-                    model = aLoRAPeftModelForCausalLM.from_pretrained(
-                        base_model,
-                        checkpoint_path,
-                        attn_implementation="flash_attention_2"
-                        if use_flash_attn
-                        else None,
-                        torch_dtype=torch.bfloat16 if use_flash_attn else None,
-                    )
-                except ImportError as exc:
-                    raise ImportError(
-                        "The alora package is required for this operation. "
-                        "Please install it with pip install alora."
-                    ) from exc
-            else:
-                model = PeftModel.from_pretrained(
-                    base_model,
-                    checkpoint_path,
-                    attn_implementation="flash_attention_2"
-                    if use_flash_attn
-                    else None,
-                    torch_dtype=torch.bfloat16 if use_flash_attn else None,
-                )
+            model = PeftModel.from_pretrained(
+                base_model,
+                checkpoint_path,
+                attn_implementation="flash_attention_2"
+                if use_flash_attn
+                else None,
+                torch_dtype=torch.bfloat16 if use_flash_attn else None,
+            )
         except (OSError, ValueError) as e:
             print("Failed to initialize checkpoint model!")
             raise e
@@ -283,7 +259,7 @@ def load(
         )

         model.to(device)
-        return cls(model, tokenizer, device, use_alora)
+        return cls(model, tokenizer, device)

     def run(
         self,
@@ -307,42 +283,16 @@ def run(
             str
                 Text generation result.
         """
-        if not self.use_alora:
-            tok_res = self.tokenizer(text, return_tensors="pt")
-            input_ids = tok_res.input_ids.to(self.device)
-            peft_outputs = self.peft_model.generate(
-                input_ids=input_ids, max_new_tokens=max_new_tokens
-            )
-        else:  # pass in alora_offsets needed for alora model
-            # Retrieve invocation string
-            invocation_string = self.peft_model.peft_config[
-                self.peft_model.active_adapter
-            ].invocation_string
-            # Find the invocation string in input
-            if invocation_string in text:
-                before, after = text.rsplit(invocation_string, 1)
-                after = invocation_string + after
-            else:
-                raise ValueError(
-                    f"aLoRA invocation string '{invocation_string}' not found in input '{text}'."
-                )
-            # Tokenize separately to enforce correct token boundary
-            before_ids = self.tokenizer(before, return_tensors="pt").input_ids
-            after_ids = self.tokenizer(invocation_string, return_tensors="pt").input_ids
-            alora_offsets = [after_ids.shape[1] - 1]
-            input_ids = torch.cat([before_ids, after_ids], dim=1).to(self.device)
-
-            peft_outputs = self.peft_model.generate(
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                alora_offsets=alora_offsets,
-            )
-        if ret_gen_text_only:
-            tok_to_decode = peft_outputs[:, input_ids.shape[1] :]
-        else:
-            tok_to_decode = peft_outputs
+        tok_res = self.tokenizer(text, return_tensors="pt")
+        input_ids = tok_res.input_ids.to(self.device)
+        peft_outputs = self.peft_model.generate(
+            input_ids=input_ids, max_new_tokens=max_new_tokens
+        )
+        tok_to_decode = (
+            peft_outputs[:, input_ids.shape[1] :] if ret_gen_text_only else peft_outputs
+        )
         decoded_result = self.tokenizer.batch_decode(
-            tok_to_decode, skip_special_tokens=False
+            tok_to_decode, skip_special_tokens=ret_gen_text_only
         )[0]
         return decoded_result

@@ -360,11 +310,6 @@ def main():
         help="JSON file to write results to",
         default="inference_result.json",
     )
-    parser.add_argument(
-        "--use_alora",
-        help="Whether to use alora",
-        default=False,
-    )
     parser.add_argument(
         "--base_model_name_or_path",
         help="Override for base model to be used for non-merged models \
@@ -398,7 +343,6 @@ def main():
         checkpoint_path=args.model,
         base_model_name_or_path=args.base_model_name_or_path,
         use_flash_attn=args.use_flash_attn,
-        use_alora=args.use_alora,
     )

     # Run inference on the text; if multiple were provided, process them all
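With the alora wrapper gone, a tuned aLoRA adapter loads through the plain `PeftModel` path and activates purely off the invocation tokens stored in its adapter config. A usage sketch of the updated `TunedCausalLM`; the checkpoint path, base model name, and invocation string are placeholders, not values from this PR:

```python
# Hypothetical usage of the updated TunedCausalLM; paths and names below are
# placeholders. Run from scripts/ (or adjust the import path accordingly).
from run_inference import TunedCausalLM

model = TunedCausalLM.load(
    checkpoint_path="out/alora-checkpoint",  # placeholder adapter dir
    base_model_name_or_path="ibm-granite/granite-3b-code-base",  # placeholder
)
# PEFT-native aLoRA switches on when the adapter's invocation tokens occur in
# the input, so the prompt must end with the invocation string.
print(
    model.run(
        "Simply put, the theory of relativity states that \nLabel:",
        max_new_tokens=50,
        ret_gen_text_only=True,
    )
)
```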
45 changes: 26 additions & 19 deletions tests/test_sft_trainer.py
@@ -27,6 +27,7 @@

 # Third Party
 from datasets.exceptions import DatasetGenerationError, DatasetNotFoundError
+from peft import LoraConfig as HFLoraConfig
 from transformers.trainer_callback import TrainerCallback
 import pytest
 import torch
@@ -147,15 +148,14 @@
 )

 PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
-try:  # Optional package
-    # Third Party
-    from alora.config import aLoraConfig
-
-    PEFT_ALORA_ARGS = aLoraConfig(
-        r=8, lora_alpha=32, lora_dropout=0.05, invocation_string="Label:"
-    )
-except ImportError:
-    pass
+INVOCATION_STR = "Label:"
+
+if hasattr(HFLoraConfig, "alora_invocation_tokens"):
+    PEFT_ALORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
+    PEFT_ALORA_ARGS.alora_invocation_tokens = [42]
+else:
+    PEFT_ALORA_ARGS = None


 @pytest.mark.parametrize(
@@ -755,12 +755,19 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):
     ids=["default", "custom_target_modules", "all_linear_target_modules"],
 )
 def test_run_causallm_alora_and_inference(request, target_modules, expected):
-    """Check if we can bootstrap and alora tune causallm models"""
+    """Check if we can bootstrap and alora tune causallm models via PEFT-native aLoRA."""
     with tempfile.TemporaryDirectory() as tempdir:
         train_args = copy.deepcopy(TRAIN_ALORA_ARGS)
         train_args.output_dir = tempdir
         base_alora_args = copy.deepcopy(PEFT_ALORA_ARGS)

+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            MODEL_NAME, use_fast=True, legacy=True
+        )
+        base_alora_args.alora_invocation_tokens = tokenizer.encode(
+            INVOCATION_STR, add_special_tokens=False
+        )
+
         if "default" not in request._pyfuncitem.callspec.id:
             base_alora_args.target_modules = target_modules

@@ -775,16 +782,16 @@ def test_run_causallm_alora_and_inference(request, target_modules, expected):
         for module in expected:
             assert module in adapter_config.get("target_modules")

-        # Load the model
-        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME, use_alora=True)
-        invocation_string = loaded_model.peft_model.peft_config[
-            loaded_model.peft_model.active_adapter
-        ].invocation_string
-        # Run inference on the text
-        output_inference = loaded_model.run(
-            "Simply put, the theory of relativity states that \n" + invocation_string,
-            max_new_tokens=50,
-        )
+        # aLoRA-specific: saved adapter config must contain the exact tokens
+        expected_tokens = tokenizer.encode(INVOCATION_STR, add_special_tokens=False)
+        assert adapter_config.get("alora_invocation_tokens") == expected_tokens
+
+        # Load tuned model (no special aLoRA wrapper/flag needed)
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Inference must include the invocation string so aLoRA activates
+        prompt = "Simply put, the theory of relativity states that \n" + INVOCATION_STR
+        output_inference = loaded_model.run(prompt, max_new_tokens=50)

         assert len(output_inference) > 0
         assert "Simply put, the theory of relativity states that \n" in output_inference
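The pattern the test exercises generalizes: PEFT-native aLoRA is an ordinary `LoraConfig` whose `alora_invocation_tokens` are the invocation string encoded without special tokens. A standalone sketch under that assumption; the model name is a placeholder:

```python
# Sketch of building a PEFT-native aLoRA config; requires a peft build with
# `alora_invocation_tokens` (e.g. the commit pinned in pyproject.toml).
from peft import LoraConfig
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")  # placeholder

alora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)
# Encode without special tokens so the IDs match the string exactly where it
# appears inside a longer prompt, mirroring the test above.
alora_config.alora_invocation_tokens = tokenizer.encode(
    "Label:", add_special_tokens=False
)
```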
62 changes: 18 additions & 44 deletions tuning/sft_trainer.py
@@ -137,19 +137,6 @@ def train(
     if _dataconfig.dataprocessor.type == "odm":
         odm_config = ODMConfig(odm=ODM(**_dataconfig.dataprocessor.odm))

-    USE_ALORA = False
-    try:
-        # Third Party
-        from alora.config import aLoraConfig  # pylint: disable=import-outside-toplevel
-        from alora.peft_model_alora import (  # pylint: disable=import-outside-toplevel
-            aLoRAPeftModelForCausalLM,
-        )
-
-        if isinstance(peft_config, aLoraConfig):
-            USE_ALORA = True
-    except ImportError:
-        pass
-
     # Validate parameters
     if (not isinstance(model_args.model_name_or_path, str)) or (
         model_args.model_name_or_path == ""
@@ -351,6 +338,20 @@
     # Calculate and save additional metrics to track later.
     additional_metrics["model_load_time"] = time.time() - model_load_time

+    # Convert legacy aLoRA string → token IDs (PEFT-native aLoRA)
+    if peft_config is not None and hasattr(peft_config, "alora_invocation_string"):
+        inv_str = getattr(peft_config, "alora_invocation_string")
+        if not inv_str:
+            raise ValueError(
+                "`--invocation_string` is required when using --peft_method alora."
+            )
+        alora_tokens = tokenizer.encode(inv_str, add_special_tokens=False)
+        if not alora_tokens:
+            raise ValueError(
+                "`--invocation_string` produced no tokens; check your tokenizer/template."
+            )
+        setattr(peft_config, "alora_invocation_tokens", alora_tokens)
+
     peft_config = get_hf_peft_config(
         task_type,
         peft_config,
@@ -439,21 +440,6 @@
     }
     training_args = SFTConfig(**transformer_kwargs, **additional_args)

-    # activated LoRA
-    if USE_ALORA:
-        response_token_ids = (
-            tokenizer(
-                peft_config.invocation_string,
-                return_tensors="pt",
-                add_special_tokens=False,
-            )
-        )["input_ids"]
-        model = aLoRAPeftModelForCausalLM(
-            model, peft_config, response_token_ids=response_token_ids
-        )
-
-        peft_config = None
-
     if train_args.enable_reduce_loss_sum:
         TrainerClass = SumLossSFTTrainer
     else:
@@ -684,21 +670,9 @@ def parse_arguments(parser, json_config=None):

     if peft_method == peft_config.PEFT_METHOD.ALORA.value:
         if invocation_string is None:
-            raise ValueError("invocation_string is not passed required for aLoRA usage")
-        try:
-            # Third Party
-            from alora.config import (  # pylint: disable=import-outside-toplevel
-                aLoraConfig,
-            )
-
-            tune_config = aLoraConfig(
-                **vars(lora_config), invocation_string=invocation_string
-            )
-        except ImportError as exc:
-            raise ImportError(
-                "The alora package is required for this operation. "
-                "Please install it with pip install alora."
-            ) from exc
+            raise ValueError("invocation_string is required for aLoRA usage")
+        tune_config = lora_config
+        setattr(tune_config, "alora_invocation_string", invocation_string)
     elif peft_method == peft_config.PEFT_METHOD.LORA.value:
         tune_config = lora_config
     elif peft_method == peft_config.PEFT_METHOD.PT.value:
@@ -859,7 +833,7 @@ def main():
         )
         sys.exit(INTERNAL_ERROR_EXIT_CODE)

-    if isinstance(tune_config, LoraConfig):  # aLoraConfig subclasses LoraConfig
+    if isinstance(tune_config, LoraConfig):
         try:
             if training_args.save_model_dir:
                 # Write number of added tokens to artifacts
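End to end, `parse_arguments()` keeps the user-facing `--invocation_string` flag and stashes it on the `LoraConfig`, and `train()` later encodes it into `alora_invocation_tokens` once the tokenizer exists. A condensed sketch of that handoff; the `LoraConfig` and tokenizer here are stand-ins for the objects the trainer actually builds:

```python
# Condensed sketch of the two-step handoff in this PR: parse_arguments()
# attaches the raw string, train() converts it to token IDs.
from peft import LoraConfig


def attach_invocation_string(lora_config: LoraConfig, invocation_string: str):
    # parse_arguments(): aLoRA is plain LoRA plus the legacy string attribute.
    if invocation_string is None:
        raise ValueError("invocation_string is required for aLoRA usage")
    lora_config.alora_invocation_string = invocation_string
    return lora_config


def resolve_invocation_tokens(peft_config, tokenizer):
    # train(): convert the legacy string into PEFT-native invocation tokens.
    if peft_config is None or not hasattr(peft_config, "alora_invocation_string"):
        return peft_config
    inv_str = peft_config.alora_invocation_string
    if not inv_str:
        raise ValueError(
            "`--invocation_string` is required when using --peft_method alora."
        )
    tokens = tokenizer.encode(inv_str, add_special_tokens=False)
    if not tokens:
        raise ValueError(
            "`--invocation_string` produced no tokens; check your tokenizer/template."
        )
    peft_config.alora_invocation_tokens = tokens
    return peft_config
```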