diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd
index d839ce211c..3c91a92885 100644
--- a/docs/multimodal.qmd
+++ b/docs/multimodal.qmd
@@ -20,6 +20,7 @@ format:
 - [Qwen2.5-VL](#sec-qwen25-vl)
 - [SmolVLM2](#sec-smolvlm2)
 - [LFM2-VL](#sec-lfm2-vl)
+- [Intern-VL](#sec-intern-vl)
 
 ## Usage
 
@@ -176,6 +177,16 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
 base_model: LiquidAI/LFM2-VL-450M
 ```
 
+### Intern-VL {#sec-intern-vl}
+
+::: {.callout-tip}
+Please make sure to install `timm` via `pip3 install timm==1.0.19`
+:::
+
+```yaml
+base_model: OpenGVLab/InternVL3_5-8B
+```
+
 ## Dataset Format
 
 For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
diff --git a/examples/internvl3_5/README.md b/examples/internvl3_5/README.md
new file mode 100644
index 0000000000..4bdc119f7c
--- /dev/null
+++ b/examples/internvl3_5/README.md
@@ -0,0 +1,18 @@
+# Finetune OpenGVLab's InternVL with Axolotl
+
+## Getting started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+2. Install Cut Cross Entropy following the [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy).
+
+3. Install `timm` via `pip3 install timm==1.0.19`.
+
+4. Run the command below:
+
+```bash
+# QLoRA SFT linear layers (1xXYGB @ ~AB GiB)
+axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
+```
+
+Note: Memory usage is taken from `device_mem_reserved(gib)` in the logs.
diff --git a/examples/internvl3_5/internvl3_5-8b-qlora.yml b/examples/internvl3_5/internvl3_5-8b-qlora.yml
new file mode 100644
index 0000000000..34c53d3b14
--- /dev/null
+++ b/examples/internvl3_5/internvl3_5-8b-qlora.yml
@@ -0,0 +1,63 @@
+base_model: OpenGVLab/InternVL3_5-8B
+trust_remote_code: true
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_4bit: true
+
+# these 3 lines are needed for now to handle vision chat templates with images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+    field_messages: messages
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md
index a64bdd0548..9d0a0c280c 100644
--- a/src/axolotl/integrations/cut_cross_entropy/README.md
+++ b/src/axolotl/integrations/cut_cross_entropy/README.md
@@ -47,6 +47,7 @@ plugins:
 - granitemoe
 - hunyuan_v1_dense
 - hunyuan_v1_moe
+- internvl_chat
 - llama
 - llama4
 - llama4_text
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index e4f9ca2be6..ff298bd8fa 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -38,6 +38,7 @@
     "smollm3",
     "gpt_oss",
     "arcee",
+    "internvl_chat",
 ]
 
 
diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py
index 4b06eb4c8d..fc2e53009e 100644
--- a/src/axolotl/processing_strategies.py
+++ b/src/axolotl/processing_strategies.py
@@ -7,6 +7,12 @@
 from PIL.Image import Resampling
 from torch import Tensor, zeros_like
 from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
+
+try:
+    from transformers import InternVLProcessor
+except ImportError:
+    InternVLProcessor = None
+
 from transformers.image_utils import load_image
 
 from axolotl.utils.dict import remove_none_values
@@ -421,6 +427,37 @@ def __init__(
         ]
 
 
+class InternVLProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for InternVL"""
+
+    def __init__(
+        self,
+        processor: ProcessorMixin,
+        chat_template: Optional[str] = None,
+        image_size: int | tuple[int, int] | None = None,
+        image_resize_algorithm: Resampling | None = None,
+    ):
+        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+
+        if not hasattr(processor, "image_ids"):
+            raise ValueError("'image_ids' missing from InternVL Processor.")
+
+        self.image_token_ids = processor.image_ids
+
+    def process_labels(self, input_ids):
+        labels = input_ids.clone()
+
+        labels[labels == self.processor.tokenizer.pad_token_id] = -100
+
+        for ids in self.image_token_ids:
+            labels[labels == ids] = -100
+
+        # Note: check whether 'video_token' also needs masking, as it gets
+        # converted to image patches during media processing
+
+        return labels
+
+
 def get_processing_strategy(
     processor: ProcessorMixin,
     chat_template,
@@ -463,6 +500,11 @@ def get_processing_strategy(
             **processing_kwargs,
         )
 
+    if InternVLProcessor and isinstance(processor, InternVLProcessor):
+        return InternVLProcessingStrategy(
+            **processing_kwargs,
+        )
+
     # llama3_2_vision, llama4, llava
     # mistral_v7_tekken, pixtral, lfm2vl
     return ProcessingStrategy(
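As a quick reference for reviewers, here is a minimal, self-contained sketch of the label masking that the new `InternVLProcessingStrategy.process_labels` performs. The token ids below are made up for illustration only; at runtime the pad id comes from the tokenizer and the image placeholder ids from `processor.image_ids`.

```python
import torch

# Assumed toy ids for illustration; real values come from the tokenizer
# (pad_token_id) and from the InternVL processor (image_ids).
pad_token_id = 0
image_token_ids = [151667]

input_ids = torch.tensor([[151667, 151667, 3923, 374, 420, 30, 0, 0]])
labels = input_ids.clone()

# Padding tokens are excluded from the loss
labels[labels == pad_token_id] = -100

# Each image placeholder token is masked the same way
for token_id in image_token_ids:
    labels[labels == token_id] = -100

print(labels)
# tensor([[-100, -100, 3923,  374,  420,   30, -100, -100]])
```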