Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,4 @@ USER ${USER}
COPY --from=python-installations /home/${USER}/.local /home/${USER}/.local
ENV PYTHONPATH="/home/${USER}/.local/lib/python${PYTHON_VERSION}/site-packages"

CMD [ "python", "/app/accelerate_launch.py" ]
CMD [ "python", "/app/accelerate_launch.py" ]
98 changes: 98 additions & 0 deletions build/nvcr.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Global Args #################################################################
## If the nvcr container is updated, ensure to check the torch and python
## installation version inside the dockerfile before pushing changes.
ARG NVCR_IMAGE_VERSION=25.02-py3

# This is based on what is inside the NVCR image already
ARG PYTHON_VERSION=3.12

## Base Layer ##################################################################
FROM nvcr.io/nvidia/pytorch:${NVCR_IMAGE_VERSION} AS dev

ARG USER=root
ARG USER_UID=0
ARG WORKDIR=/app
ARG SOURCE_DIR=/app/fms-hf-tuning
ARG SOURCE_BRANCH=main

ARG ENABLE_FMS_ACCELERATION=true
ARG ENABLE_AIM=true
ARG ENABLE_ALORA=true
ARG ENABLE_MLFLOW=true
ARG ENABLE_SCANNER=true
ARG ENABLE_CLEARML=true
ARG ENABLE_TRITON_KERNELS=true
ARG ENABLE_MAMBA_SUPPORT=false

# Use bash for RUN instructions: the `[[ ... ]]` conditionals below are a
# bashism and fail under the default /bin/sh (dash on Ubuntu-based NGC images).
# -o pipefail ensures piped commands don't mask upstream failures.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]

# Ensures to always build mamba_ssm from source
ENV PIP_NO_BINARY=mamba-ssm,mamba_ssm

RUN python -m pip install --no-cache-dir --upgrade pip

# upgrade torch as the base layer contains only torch 2.7
RUN pip install --no-cache-dir --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128

# Install main package + flash attention
RUN git clone --branch ${SOURCE_BRANCH} --depth 1 https://github.com/foundation-model-stack/fms-hf-tuning.git ${SOURCE_DIR}
# WORKDIR replaces the original `RUN cd ${SOURCE_DIR}`, which was a no-op:
# each RUN starts a fresh shell, so a bare `cd` never persists (hadolint DL3003).
WORKDIR ${SOURCE_DIR}
RUN pip install --no-cache-dir ${SOURCE_DIR} && \
    pip install --no-cache-dir ${SOURCE_DIR}[flash-attn]

# Optional extras
RUN if [[ "${ENABLE_FMS_ACCELERATION}" == "true" ]]; then \
    pip install --no-cache-dir ${SOURCE_DIR}[fms-accel] && \
    python -m fms_acceleration.cli install fms_acceleration_peft && \
    python -m fms_acceleration.cli install fms_acceleration_foak && \
    python -m fms_acceleration.cli install fms_acceleration_aadp && \
    python -m fms_acceleration.cli install fms_acceleration_moe; \
    fi

RUN if [[ "${ENABLE_ALORA}" == "true" ]]; then \
    pip install --no-cache-dir ${SOURCE_DIR}[activated-lora]; \
    fi
RUN if [[ "${ENABLE_AIM}" == "true" ]]; then \
    pip install --no-cache-dir ${SOURCE_DIR}[aim]; \
    fi
RUN if [[ "${ENABLE_MLFLOW}" == "true" ]]; then \
    pip install --no-cache-dir ${SOURCE_DIR}[mlflow]; \
    fi
RUN if [[ "${ENABLE_SCANNER}" == "true" ]]; then \
    pip install --no-cache-dir ${SOURCE_DIR}[scanner-dev]; \
    fi
RUN if [[ "${ENABLE_CLEARML}" == "true" ]]; then \
    pip install --no-cache-dir ${SOURCE_DIR}[clearml]; \
    fi
RUN if [[ "${ENABLE_MAMBA_SUPPORT}" == "true" ]]; then \
    pip install --no-cache-dir ${SOURCE_DIR}[mamba]; \
    fi
RUN if [[ "${ENABLE_TRITON_KERNELS}" == "true" ]]; then \
    pip install --no-cache-dir "git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels"; \
    fi

# Group-writable app dir and world-writable cache so the image runs under
# arbitrary UIDs (e.g. OpenShift). NOTE(review): 777 on /.cache is deliberately
# permissive for a dev image; tighten for production use.
RUN chmod -R g+rwX ${WORKDIR} /tmp
RUN mkdir -p /.cache && chmod -R 777 /.cache

# Set Triton environment variables for qLoRA
ENV TRITON_HOME="/tmp/triton_home"
ENV TRITON_DUMP_DIR="/tmp/triton_dump_dir"
ENV TRITON_CACHE_DIR="/tmp/triton_cache_dir"
ENV TRITON_OVERRIDE_DIR="/tmp/triton_override_dir"

WORKDIR ${WORKDIR}

# this is just a dev image so this is okay.
# Fix: exec form needs one argument per array element, and "infinity" was
# misspelled — the original `CMD ["sleep inifinity"]` fails at container start
# with "executable file not found".
CMD ["sleep", "infinity"]
22 changes: 9 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ classifiers=[
"Programming Language :: Python :: 3.12"
]
dependencies = [
"numpy>=1.26.4,<2.0",
"accelerate>=0.20.3,!=0.34,<1.7",
"transformers>=4.53.0,<=4.55.4",
"torch>2.6.0,<=2.8.0",
"numpy>=1.26.4,<2.2.0",
"accelerate>=1.9.0,<2.0.0",
"transformers>=4.55.0,<=4.55.4",
"torch>2.7.0,<2.9.0",
"sentencepiece>=0.1.99,<0.3",
"tokenizers>=0.13.3,<1.0",
"tokenizers<=0.22",
"tqdm>=4.66.2,<5.0",
"trl>=0.13,<0.20",
"peft>=0.15.0,<=0.15.2",
"protobuf>=5.28.0,<6.0.0",
"datasets>=3.5.0,<4.0",
"trl>=0.19.1,<0.20.0",
"peft>=0.17.0,<0.18.0",
"datasets>=4.0.0,<5.0.0",
"simpleeval>=0.9.13,<2.0",
"pillow>=11.0.0,<12.0",
"kernels<=0.9.0",
]

[project.optional-dependencies]
Expand All @@ -54,14 +54,10 @@ mamba = ["mamba_ssm[causal-conv1d]>=2.0.0,<3.0.0"]
scanner-dev = ["HFResourceScanner>=0.1.0"]
activated-lora = ["alora>=0.3.0"]


[tool.setuptools.packages.find]
exclude = ["tests", "tests.*"]
namespaces = false

[tool.setuptools_scm]
version_file = "tuning/_version.py"

[project.urls]
Homepage = "https://github.com/foundation-model-stack/fms-hf-tuning"
Repository = "https://github.com/foundation-model-stack/fms-hf-tuning"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
Expand All @@ -432,6 +433,7 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
Expand All @@ -454,7 +456,9 @@ def test_parse_arguments_peft_method(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_pt)

assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
Expand All @@ -471,6 +475,7 @@ def test_parse_arguments_peft_method(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_lora)
assert isinstance(tune_config, peft_config.LoraConfig)
assert not tune_config.target_modules
Expand Down
8 changes: 8 additions & 0 deletions tuning/config/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ class ModelArguments:
tokenizer classes."
},
)
flash_attn_implementation: Optional[str] = field(
default="flash_attention_2",
metadata={
"help": "Flash Attention implementation to choose.\
For almost all models don't need to pass or use default i.e. flash_attention_2.\
Requires use_flash_attn=True flag to be enabled."
},
)


@dataclass
Expand Down
26 changes: 26 additions & 0 deletions tuning/config/peft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,30 @@

# Standard
from dataclasses import dataclass, field
from enum import Enum
from typing import List

# Third Party
from transformers.utils.quantization_config import Mxfp4Config as HfMxfp4Config


class QUANT_METHOD(Enum):
    """Quantization methods recognized by the tuning config.

    Enum values are the string identifiers expected in user-supplied config.
    NOTE(review): class name is UPPER_SNAKE rather than PEP 8 CapWords; kept
    as-is since external callers may already reference it by this name.
    """

    # MXFP4 quantization (bridged to transformers' Mxfp4Config elsewhere in
    # this module).
    MXFP4 = "mxfp4"


class PEFT_METHOD(Enum):
    """Parameter-efficient fine-tuning methods selectable via config.

    Enum values are the short string identifiers expected in user-supplied
    config.
    """

    # Prompt tuning.
    PT = "pt"
    # Low-rank adaptation (LoRA).
    LORA = "lora"
    # Activated LoRA — presumably backed by the `alora` optional dependency;
    # TODO confirm against the `activated-lora` extra.
    ALORA = "alora"


@dataclass
class Mxfp4Config:
    """Flat, parse-friendly mirror of the HF MXFP4 quantization settings.

    Kept as a plain dataclass so it can be populated from CLI/JSON config,
    then converted to the ``transformers`` equivalent via :meth:`to_hf_config`.
    """

    # Forwarded verbatim to transformers' Mxfp4Config; defaults to True.
    dequantize: bool = True

    def to_hf_config(self):
        """Return the equivalent ``transformers`` ``Mxfp4Config`` instance."""
        hf_kwargs = {"dequantize": self.dequantize}
        return HfMxfp4Config(**hf_kwargs)


@dataclass
class LoraConfig:
Expand Down Expand Up @@ -55,6 +77,10 @@ class LoraConfig:
"modules except for the output layer."
},
)
target_parameters: List[str] = field(
default=None,
metadata={"help": "The names/regex of the parameters to apply LORA to"},
)
bias = "none"
lora_dropout: float = 0.05

Expand Down
2 changes: 1 addition & 1 deletion tuning/data/data_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _load_dataset(
load_path = builder if builder else data_path

try:
return datasets.load_dataset(path=load_path, **load_kwargs)
return datasets.load_dataset(load_path, **load_kwargs)
except DatasetNotFoundError as e:
# Reraise with a more context-specific message if needed
raise e
Expand Down
Loading