11 changes: 11 additions & 0 deletions docs/multimodal.qmd
@@ -20,6 +20,7 @@ format:
- [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)

## Usage

@@ -176,6 +177,16 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
base_model: LiquidAI/LFM2-VL-450M
```

### Intern-VL {#sec-intern-vl}

::: {.callout-tip}
Please make sure to install `timm` via `pip3 install timm==1.0.19`
:::

```yaml
base_model: OpenGVLab/InternVL3_5-8B
```
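
A complete QLoRA training config for this model is provided in `examples/internvl3_5/internvl3_5-8b-qlora.yml`.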

## Dataset Format

For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
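
For example, a single sample pairing an image with a question might look like the following (an illustrative row; the exact image fields, e.g. `path` vs. `url`, depend on how your dataset stores images):

```json
{
  "messages": [
    {
      "role": "user",
      "content": [
        {"type": "image", "path": "/path/to/image.jpg"},
        {"type": "text", "text": "What is shown in this image?"}
      ]
    },
    {
      "role": "assistant",
      "content": [
        {"type": "text", "text": "A cat sitting on a windowsill."}
      ]
    }
  ]
}
```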
18 changes: 18 additions & 0 deletions examples/internvl3_5/README.md
@@ -0,0 +1,18 @@
# Finetune OpenGVLab's InternVL with Axolotl

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install Cut Cross Entropy following [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy).

3. Install `timm` via `pip3 install timm==1.0.19`.

4. Run the command below:

```bash
# QLoRA SFT linear layers (1xXYGB @ ~AB GiB)
axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
```

Note: Memory usage is taken from the `device_mem_reserved(gib)` value in the training logs.
63 changes: 63 additions & 0 deletions examples/internvl3_5/internvl3_5-8b-qlora.yml
@@ -0,0 +1,63 @@
base_model: OpenGVLab/InternVL3_5-8B
trust_remote_code: true

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true

# these three lines are currently required to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages

dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

adapter: qlora
lora_model_dir:

sequence_len: 2048

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
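# match only the language model's linear projections; the vision tower is not adapted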
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: true
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
1 change: 1 addition & 0 deletions src/axolotl/integrations/cut_cross_entropy/README.md
@@ -47,6 +47,7 @@ plugins:
- granitemoe
- hunyuan_v1_dense
- hunyuan_v1_moe
- internvl_chat
- llama
- llama4
- llama4_text
1 change: 1 addition & 0 deletions src/axolotl/monkeypatch/multipack.py
@@ -38,6 +38,7 @@
"smollm3",
"gpt_oss",
"arcee",
"internvl_chat",
]


42 changes: 42 additions & 0 deletions src/axolotl/processing_strategies.py
@@ -7,6 +7,12 @@
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor

# InternVLProcessor is only available in newer transformers releases
try:
    from transformers import InternVLProcessor
except ImportError:
    InternVLProcessor = None

from transformers.image_utils import load_image

from axolotl.utils.dict import remove_none_values
@@ -421,6 +427,37 @@ def __init__(
]


class InternVLProcessingStrategy(ProcessingStrategy):
    """Processing Strategy class for InternVL"""

    def __init__(
        self,
        processor: ProcessorMixin,
        chat_template: Optional[str] = None,
        image_size: int | tuple[int, int] | None = None,
        image_resize_algorithm: Resampling | None = None,
    ):
        super().__init__(processor, chat_template, image_size, image_resize_algorithm)

        if not hasattr(processor, "image_ids"):
            raise ValueError("'image_ids' missing from InternVL Processor.")

        self.image_token_ids = processor.image_ids

    def process_labels(self, input_ids):
        labels = input_ids.clone()

        # mask padding so it is ignored by the loss
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        # mask the image placeholder tokens as well
        for ids in self.image_token_ids:
            labels[labels == ids] = -100

        # Note: check whether 'video_token' also needs masking, since it is
        # converted to image patches during media processing

        return labels
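
    # Illustrative example with hypothetical token ids: given
    #   input_ids = [[101, 92544, 92544, 42, 0]]
    # where 92544 is an image-context token and 0 is the pad token,
    # process_labels returns
    #   labels = [[101, -100, -100, 42, -100]]
    # so loss is computed only on the real text tokens.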


def get_processing_strategy(
    processor: ProcessorMixin,
    chat_template,
@@ -463,6 +500,11 @@ def get_processing_strategy(
            **processing_kwargs,
        )

    if InternVLProcessor and isinstance(processor, InternVLProcessor):
        return InternVLProcessingStrategy(
            **processing_kwargs,
        )

    # llama3_2_vision, llama4, llava
    # mistral_v7_tekken, pixtral, lfm2vl
    return ProcessingStrategy(