build/nvcr.Dockerfile (2 changes: 1 addition & 1 deletion)

@@ -46,7 +46,7 @@ RUN python -m pip install --upgrade pip
 RUN pip install --upgrade --force-reinstall torch torchaudio torchvision --index-url https://download.pytorch.org/whl/cu128

 # Install main package + flash attention
-RUN COPY . ${SOURCE_DIR}
+COPY . ${SOURCE_DIR}
 RUN cd ${SOURCE_DIR}
 RUN pip install --no-cache-dir ${SOURCE_DIR} && \
     pip install --no-cache-dir ${SOURCE_DIR}[flash-attn]
tuning/data/setup_dataprocessor.py (21 changes: 13 additions & 8 deletions)

@@ -159,7 +159,9 @@ def process_dataconfig_file(


 # Data Format 1: Pretokenized Data
-def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
+def _get_pretokenized_dataset_handlers(
+    data_args: DataArguments, is_eval_present, is_eval_tokenized
+):

     # if the provided train dataset is pretokenized
     # however user provides formatting flags, error out
@@ -168,6 +170,7 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):
         or data_args.data_formatter_template
         or data_args.dataset_text_field
         or data_args.instruction_template
+        or data_args.dataset_conversation_field
     ):
         raise ValueError(
             "fields response_template, data_formatter_template,"
@@ -177,7 +180,7 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):

     # if the train dataset is pretokenized
     # ensure validation dataset is pretokenized otherwise error out
-    if is_eval_tokenized:
+    if is_eval_present and not is_eval_tokenized:
         raise ValueError(
             "validation data should be pretokenized to be used \
                 along with pretokenized train data"
@@ -189,7 +192,9 @@ def _get_pretokenized_dataset_handlers(data_args, is_eval_tokenized):

 ### Data format 2
 # pylint: disable=unused-argument
-def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):
+def _get_dataset_formatting_handlers(
+    data_args: DataArguments, packing, is_padding_free=False
+):

     if data_args.response_template is None:
         if packing is False:
@@ -253,7 +258,7 @@ def _get_chat_dataset_handlers(data_args, tokenizer_kwargs):
     fn_kwargs["formatted_text_column_name"] = data_args.dataset_text_field
     fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs
     if data_args.dataset_conversation_field is not None:
-        fn_kwargs["conversation_column"] = data_args.dataset_conversation_field
+        fn_kwargs["conversation_column_name"] = data_args.dataset_conversation_field

     kwargs = {"fn_kwargs": fn_kwargs, "batched": False, "remove_columns": "all"}

@@ -284,14 +289,14 @@ def _get_default_dataset_handlers(data_args, tokenizer_kwargs):


 ### Vsion Data Format
-def _get_vision_dataset_handlers(data_args, processor_kwargs):
+def _get_vision_dataset_handlers(data_args: DataArguments, processor_kwargs):

     handlers = []

     # First data handler configuration
     handler_fn_kwargs1 = {
-        "dataset_text_field": data_args.dataset_text_field,
-        "conversation_column": data_args.dataset_text_field,
+        "formatted_text_column_name": data_args.dataset_text_field,
+        "conversation_column_name": data_args.dataset_conversation_field,
     }
     handler_kwargs1 = {
         "fn_kwargs": handler_fn_kwargs1,
@@ -403,7 +408,7 @@ def _process_raw_data_args(
     if is_traindata_tokenized:
         # Data Format 1: Pretokenized Data
         handlers, dataset_text_field = _get_pretokenized_dataset_handlers(
-            data_args, (is_eval_dataset_present and not is_evaldata_tokenized)
+            data_args, is_eval_dataset_present, is_evaldata_tokenized
         )
     elif processor and data_args.dataset_text_field and data_args.dataset_image_field:
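For readers following the signature change above, here is a minimal sketch of how the revised pretokenized-data checks behave with the new three-argument form. This is an illustrative stand-in only: `check_pretokenized` and `StubDataArguments` are hypothetical names, and the stub mirrors just the fields touched by the diff, not the real DataArguments class. The behavior of the eval-data check is unchanged; the diff only moves the boolean combination from the caller into the function and adds dataset_conversation_field to the forbidden-flags list.

# Illustrative sketch, not code from the repository.
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubDataArguments:
    # Hypothetical stub with only the fields referenced in the diff.
    response_template: Optional[str] = None
    data_formatter_template: Optional[str] = None
    dataset_text_field: Optional[str] = None
    instruction_template: Optional[str] = None
    dataset_conversation_field: Optional[str] = None


def check_pretokenized(data_args, is_eval_present, is_eval_tokenized):
    # Pretokenized train data must not be combined with formatting flags,
    # now including dataset_conversation_field.
    if (
        data_args.response_template
        or data_args.data_formatter_template
        or data_args.dataset_text_field
        or data_args.instruction_template
        or data_args.dataset_conversation_field
    ):
        raise ValueError("formatting flags are not allowed with pretokenized data")
    # Error only when eval data exists and is not pretokenized; the caller now
    # passes the two booleans separately instead of a pre-combined expression.
    if is_eval_present and not is_eval_tokenized:
        raise ValueError("validation data should be pretokenized as well")


# Usage: no eval data at all, so neither check fires.
check_pretokenized(StubDataArguments(), is_eval_present=False, is_eval_tokenized=False)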
tuning/sft_trainer.py (45 changes: 18 additions & 27 deletions)

@@ -255,6 +255,16 @@ def train(

     model_load_time = time.time()
     try:
+        model_kwargs = dict(  # pylint: disable=use-dict-literal
+            cache_dir=train_args.cache_dir,
+            torch_dtype=get_torch_dtype(model_args.torch_dtype),
+            attn_implementation=model_args.flash_attn_implementation
+            if model_args.use_flash_attn
+            else None,
+        )
+        if quantization_config is not None:
+            model_kwargs["quantization_config"] = quantization_config.to_hf_config()
+
         logger.info("Loading the model {} now".format(model_args.model_name_or_path))
         try:
             logger.info(
@@ -263,18 +273,8 @@
                 )
             )
             # try to load model as a vision model
-            model_loader = AutoModelForVision2Seq.from_pretrained
-
-            model = model_loader(
-                model_args.model_name_or_path,
-                cache_dir=train_args.cache_dir,
-                torch_dtype=get_torch_dtype(model_args.torch_dtype),
-                quantization_config=quantization_config.to_hf_config()
-                if quantization_config
-                else None,
-                attn_implementation=model_args.flash_attn_implementation
-                if model_args.use_flash_attn
-                else None,
+            model = AutoModelForVision2Seq.from_pretrained(
+                model_args.model_name_or_path, **model_kwargs
             )
             try:
                 if "use_cache" in model.language_model.config:
@@ -290,10 +290,10 @@
             logger.info("Loaded vision model as {} ".format(model))
             logger.info("Loaded vision model processor {} ".format(processor))
             logger.info("Loaded model tokenizer {} ".format(tokenizer))
-        except ValueError:
+        except Exception as e:  # pylint: disable=broad-except
             logger.info(
-                "Couldn't load model {} as a vision model".format(
-                    model_args.model_name_or_path
+                "Couldn't load model {} as a vision model due to {} ".format(
+                    model_args.model_name_or_path, e
                 )
             )
             model = None
@@ -314,16 +314,7 @@
             model_loader = AutoModelForCausalLM.from_pretrained

             model = model_loader(
-                model_args.model_name_or_path,
-                cache_dir=train_args.cache_dir,
-                torch_dtype=get_torch_dtype(model_args.torch_dtype),
-                quantization_config=quantization_config.to_hf_config()
-                if quantization_config
-                else None,
-                attn_implementation=model_args.flash_attn_implementation
-                if model_args.use_flash_attn
-                else None,
-                use_cache=False,
+                model_args.model_name_or_path, use_cache=False, **model_kwargs
             )

             # TODO: Move these to a config as well
@@ -757,12 +748,12 @@ def main():
             "Tune Config": tune_config,
             "Quantization Config": quantization_config,
             "QLoRA Config": quantized_lora_config,
-            "Tracker Config": tracker_configs,
             "AADP (fms-acceleration) Config": attention_and_distributed_packing_config,
             "Fused Ops Kernels Config": fusedops_kernels_config,
             "Fast MoE Config": fast_moe_config,
-            "Trainer Controller Config": trainer_controller_args,
+            "Tracker Config": tracker_configs,
             "Extra Metadata": exp_metadata,
+            "Trainer Controller Config": trainer_controller_args,
         }
     )
     logger.info(args_dump)
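The model-loading hunks above collapse two near-identical from_pretrained call sites into one shared keyword dictionary. A minimal sketch of that pattern follows; the load_model helper and its parameters are hypothetical stand-ins, and only AutoModelForVision2Seq, AutoModelForCausalLM, and their standard from_pretrained keywords (cache_dir, torch_dtype, attn_implementation, quantization_config, use_cache) are real transformers API. As in the diff, a quantization config is attached only when one is supplied, and the causal-LM path alone passes use_cache=False.

# Illustrative sketch of the shared-kwargs pattern, not code from the repository.
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq


def load_model(model_name_or_path, cache_dir=None, torch_dtype=None,
               attn_impl=None, quant_config=None):
    # One kwargs dict reused by both loaders, so the options stay in sync.
    model_kwargs = dict(
        cache_dir=cache_dir,
        torch_dtype=torch_dtype,
        attn_implementation=attn_impl,  # e.g. "flash_attention_2" or None
    )
    # Attach a quantization config only when one was actually supplied,
    # so the loaders never receive quantization_config=None explicitly.
    if quant_config is not None:
        model_kwargs["quantization_config"] = quant_config

    try:
        # Try the vision-language loader first; any failure falls through,
        # mirroring the broadened except clause in the diff.
        return AutoModelForVision2Seq.from_pretrained(model_name_or_path, **model_kwargs)
    except Exception:  # broad by design, as in the diff
        return AutoModelForCausalLM.from_pretrained(
            model_name_or_path, use_cache=False, **model_kwargs
        )


# Usage (requires a downloadable checkpoint, so left as a comment):
# model = load_model("org/model-id", torch_dtype="bfloat16")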