Skip to content

Commit 3d51c99

Browse files
Merge pull request #2493 from AI-Hypercomputer:hengtaoguo-doc
PiperOrigin-RevId: 825585112
2 parents 64d6d9b + bcd2f17 commit 3d51c99

File tree

5 files changed

+353
-0
lines changed

5 files changed

+353
-0
lines changed
189 KB
Loading

docs/guides.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ guides/checkpointing_solutions/multi_tier_checkpointing.md
4040
guides/jax_ai_libraries_chosen.md
4141
guides/xprof_user_guide.md
4242
guides/megascale_hang_playbook.md
43+
guides/multimodal.md
4344
```

docs/guides/multimodal.md

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
2+
3+
# Multimodal Support on MaxText
4+
5+
This document provides a guide to use the multimodal functionalities in MaxText including:
6+
- **Checkpoint Conversion**: Convert a MaxText-compatible orbax checkpoint from HuggingFace.
7+
- **Multimodal Decode**: Inference with text+images as input.
8+
- **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset.
9+
10+
The following table provides a list of models and modalities we currently support:
11+
| Models | Input Modalities | Output Modalities |
12+
| :---- | :---- | :---- |
13+
| - Gemma3-4B/12B/27B<br>- Llama4-Scout/Maverick | Text, images | Text |
14+
15+
## Introduction
16+
17+
Multimodal Large Language Models (LLMs) extend traditional text-only models by incorporating multiple input modalities such as images, audio, and video. For each non-text modality, the architecture typically follows a three-stage pipeline:
18+
- **Data Preprocessing**: We apply modality-specific preprocessing steps to prepare the raw input data (e.g., image resizing and normalization), transforming them into a format which neural networks can understand.
19+
- **Modality-Specific Encoders**: Modality-specific encoders will transform the preprocessed data into high-dimensional representations (e.g., vision transformers for images).
20+
- **Projection and Merge**: Projection layers will map these modality-specific embeddings into the shared embedding space of the language model, usually aligned with the dimension of text embeddings. These projected embeddings are then merged with text token embeddings, allowing the unified model to process and reason over multiple modalities simultaneously within a single coherent framework.
21+
22+
23+
<img src="../_static/multimodal_overview.png" alt="Illustration of multimodal MaxText." width="60%">
24+
*Figure 1: Overview of multimodal dataflow in MaxText.*
25+
26+
## Checkpoint Conversion
27+
28+
Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace (README). This tool is used for the Gemma3 model family. Use this command to convert an unscanned checkpoint from HuggingFace to MaxText, and save it to `MAXTEXT_CKPT_GCS_PATH`:
29+
30+
```shell
31+
export HF_ACCESS_TOKEN=hf_...
32+
export MAXTEXT_CKPT_GCS_PATH=gs://...
33+
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
34+
model_name=gemma3-4b \
35+
hf_access_token=$HF_ACCESS_TOKEN \
36+
base_output_directory=$MAXTEXT_CKPT_GCS_PATH \
37+
use_multimodal=true \
38+
scan_layers=false
39+
```
40+
41+
For the Llama4 model family, we are using a separate checkpoint conversion script (of note,we will gradually migrate all checkpoint conversion scripts to the above consolidated tool soon):
42+
43+
```shell
44+
export LOCAL_HF_MODEL_PATH=... # Need to pre-download the safetensors from HuggingFace
45+
export MAXTEXT_CKPT_GCS_PATH=gs://...
46+
python -m MaxText.llama4_ckpt_unscanned \
47+
--model-size=llama4-17b-16e \
48+
--huggingface-checkpoint=True \
49+
--base-model-path=$LOCAL_HF_MODEL_PATH \
50+
--maxtext-model-path=$MAXTEXT_CKPT_GCS_PATH
51+
```
52+
53+
## Multimodal Decode
54+
MaxText supports multimodal decoding, allowing you to input text with multiple images to get a text output. To use this feature, you need three main settings:
55+
- `use_multimodal=True`: Initializes the multimodal preprocessing steps and network components.
56+
- `prompt`: Specifies the position of image placeholder tokens in your input. If you don't manually place them, MaxText will automatically append the required placeholder (e.g., `<start_of_image>` for Gemma3, `<|image|>` for Llama4). The exact placeholder is listed under the `image_placeholder` field in each model's configuration file.
57+
- `image_path`: The path(s) to the image file(s) MaxText will load and process.
58+
59+
Since each model uses a unique native chatting template from its pretraining, we've implemented these specific templates within `multimodal_utils.py` and apply them directly to your prompt.
60+
61+
To run a forward pass and verify the model's output, use the following command:
62+
63+
```shell
64+
# Gemma3 decode
65+
python -m MaxText.decode \
66+
MaxText/configs/base.yml \
67+
model_name=gemma3-4b \
68+
hf_access_token=$HF_ACCESS_TOKEN \
69+
tokenizer_path=assets/tokenizer.gemma3 \
70+
load_parameters_path=$MAXTEXT_CKPT_GCS_PATH/0/items \
71+
per_device_batch_size=1 \
72+
run_name=ht_test \
73+
max_prefill_predict_length=272 \
74+
max_target_length=300 \
75+
steps=1 \
76+
async_checkpointing=false \
77+
scan_layers=false \
78+
use_multimodal=true \
79+
prompt='Describe image <start_of_image>' \
80+
image_path='MaxText/test_assets/test_image.jpg' \
81+
attention='dot_product'
82+
```
83+
84+
The decoding results will look like this:
85+
```
86+
Input `<start_of_turn>user
87+
Describe image <start_of_image><end_of_turn>
88+
<start_of_turn>model
89+
` -> `Here's a description of the image:
90+
91+
**Overall Impression:** The image is a bright, expansive cityscape view of Seattle, Washington, with`
92+
```
93+
94+
To decode with multiple images at once, you can provide multiple image paths like this:
95+
96+
```
97+
python -m MaxText.decode \
98+
MaxText/configs/base.yml \
99+
model_name=gemma3-4b \
100+
... \
101+
image_path=/path/to/image1.jpg,/path/to/image2.jpg \
102+
prompt="Describe each image in a short sentence." # <start_of_image> will be added to prompt if not provided
103+
# or prompt="Describe each image in a short sentence: <start_of_image> and <start_of_image>"
104+
```
105+
106+
For larger models such as Llama4-Scout/Maverick, we suggest to run the decoding on a TPU cluster such as v5p-16.
107+
108+
## Supervised Fine-Tuning
109+
110+
Supervised Fine-Tuning (SFT) of multimodal LLMs in MaxText focuses specifically on post-training optimization rather than pre-training from scratch, which is currently not supported. The SFT process typically involves training on Visual Question Answering (VQA) datasets where the model learns to generate accurate text responses based on both visual and textual inputs. During this fine-tuning phase, we recommend to freeze the pre-trained encoder layers (such as vision transformers) to preserve their learned visual representations, while the projection layers and LLM decoder components remain trainable. This selective training strategy allows the model to adapt the cross-modal alignment and text generation capabilities without disrupting the robust feature extraction abilities of the encoders, ultimately leading to improved performance on multimodal understanding and reasoning tasks while maintaining computational efficiency. This is achieved by setting `freeze_vision_encoder_params=True` in [sft-vision-chartqa.yml](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/sft-vision-chartqa.yml).
111+
112+
Here, we use [ChartQA](https://huggingface.co/datasets/HuggingFaceM4/ChartQA) as an example to demonstrate SFT functionality:
113+
114+
115+
```shell
116+
python -m MaxText.sft_trainer MaxText/configs/sft-vision-chartqa.yml \
117+
run_name=$idx \
118+
model_name=gemma3-4b \
119+
tokenizer_path="google/gemma-3-4b-pt" \
120+
per_device_batch_size=1 \
121+
max_prefill_predict_length=1024 \
122+
max_target_length=2048 \
123+
steps=200 \
124+
scan_layers=false \
125+
async_checkpointing=False \
126+
attention=dot_product \
127+
dataset_type=hf hf_path=parquet hf_access_token=$HF_ACCESS_TOKEN \
128+
hf_train_files=gs://aireenmei-multipod/dataset/hf/chartqa/train-* \
129+
base_output_directory=$BASE_OUTPUT_DIRECTORY \
130+
load_parameters_path=$UNSCANNED_CKPT_PATH \
131+
dtype=bfloat16 weight_dtype=bfloat16 sharding_tolerance=0.05
132+
```
133+
134+
## Other Recommendations
135+
- **Setting appropriate prefill length**: To prevent truncation and ensure your full input (text + image) is processed, the prefill length should be set longer than the total combined length of your text tokens and image tokens. This combined length makes up the final sequence fed to the decoder. We recommend to estimate the combined sequence length from your full input and then add a buffer when setting your `max_prefill_predict_length` for decoding. Token estimation rules:
136+
- For text tokens, a good estimate is $\text{Text Tokens} \approx 1.3 \times \text{Number of Words in Prompt}$.
137+
- For Gemma3, each image is resized to 896*896 and contributes 256 tokens. $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Images} * 256$.
138+
- For Llama4 models, each image is dynamically tiled based on its size, with each resulting tile contributing 144 tokens. $\text{Total Tokens} \approx \text{Text Tokens} + \text{Number of Tiles of Image1} * 144 + ... + \text{Number of Tiles of ImageN} * 144$.
139+
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb)\n",
8+
"\n",
9+
"# Gemma3 Multimodal Inference/Training Demo"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"## Overview\n",
17+
"\n",
18+
"This notebook demonstrates MaxText's multimodal features, using Gemma3-4B as an example:\n",
19+
"- Convert an orbax checkpoint from HuggingFace.\n",
20+
"- Apply decoding on a single image input.\n",
21+
"- Apply SFT to the converted checkpoint on ChartQA dataset.\n",
22+
"\n",
23+
"Given the relative small size of Gemma3-4B, you can run this colab on a v4-8, v5p-8 or v6e-4 TPU VM. However, we recommend using [XPK](https://github.com/AI-Hypercomputer/maxtext/blob/64d6d9b425e78dde94c37a82bb13ba5606e74b1b/docs/guides/run_maxtext_via_xpk.md) to schedule a training workload on a TPU cluster for better performance."
24+
]
25+
},
26+
{
27+
"cell_type": "markdown",
28+
"metadata": {},
29+
"source": [
30+
"### Get Your Hugging Face Token\n",
31+
"\n",
32+
"To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n",
33+
"\n",
34+
"**Follow these steps to get your token:**\n",
35+
"\n",
36+
"1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n",
37+
" * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
38+
"\n",
39+
"2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n",
40+
"\n",
41+
"3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n",
42+
"\n",
43+
"4. **Copy the generated token**. You will need to paste it in `HF_TOKEN`."
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"metadata": {
50+
"id": "5KPyOE8e9WbO"
51+
},
52+
"outputs": [],
53+
"source": [
54+
"#Install maxtext and dependencies\n",
55+
"# 1. Install uv, a fast Python package installer\n",
56+
"!pip install uv\n",
57+
"\n",
58+
"# 2. Install MaxText and its dependencies\n",
59+
"!uv pip install maxtext --resolution=lowest\n",
60+
"!install_maxtext_github_deps"
61+
]
62+
},
63+
{
64+
"cell_type": "code",
65+
"execution_count": null,
66+
"metadata": {},
67+
"outputs": [],
68+
"source": [
69+
"import os\n",
70+
"import MaxText\n",
71+
"\n",
72+
"# Get the root directory of the MaxText\n",
73+
"MAXTEXT_REPO_ROOT=os.path.dirname(MaxText.__file__)\n",
74+
"\n",
75+
"# Define model name\n",
76+
"MODEL_NAME=\"gemma3-4b\"\n",
77+
"\n",
78+
"# Use either a GCS path or a local path for the model checkpoint\n",
79+
"MODEL_CHECKPOINT_PATH = f\"gs://your-gcs-bucket/{MODEL_NAME}\"\n",
80+
"\n",
81+
"# Replace with your actual Hugging Face token\n",
82+
"HF_TOKEN = \"your_huggingface_token_here\""
83+
]
84+
},
85+
{
86+
"cell_type": "markdown",
87+
"metadata": {},
88+
"source": [
89+
"## Convert Checkpoint from HuggingFace"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": null,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"!python3 -m MaxText.utils.ckpt_conversion.to_maxtext \\\n",
99+
" $MAXTEXT_REPO_ROOT/configs/base.yml \\\n",
100+
" model_name=$MODEL_NAME \\\n",
101+
" hf_access_token=$HF_TOKEN \\\n",
102+
" base_output_directory=$MODEL_CHECKPOINT_PATH \\\n",
103+
" use_multimodal=true \\\n",
104+
" scan_layers=false"
105+
]
106+
},
107+
{
108+
"cell_type": "markdown",
109+
"metadata": {},
110+
"source": [
111+
"## Decode on One Image"
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"metadata": {},
118+
"outputs": [],
119+
"source": [
120+
"!python -m MaxText.decode \\\n",
121+
" $MAXTEXT_REPO_ROOT/configs/base.yml \\\n",
122+
" model_name=$MODEL_NAME \\\n",
123+
" tokenizer_path=assets/tokenizer.gemma3 \\\n",
124+
" load_parameters_path=$MODEL_CHECKPOINT_PATH/0/items \\\n",
125+
" per_device_batch_size=1 \\\n",
126+
" run_name=ht_test max_prefill_predict_length=272 \\\n",
127+
" max_target_length=300 \\\n",
128+
" steps=1 \\\n",
129+
" async_checkpointing=false \\\n",
130+
" scan_layers=false \\\n",
131+
" use_multimodal=true \\\n",
132+
" prompt='Describe image <start_of_image>' \\\n",
133+
" image_path=$MAXTEXT_REPO_ROOT/test_assets/test_image.jpg \\\n",
134+
" attention='dot_product'"
135+
]
136+
},
137+
{
138+
"cell_type": "markdown",
139+
"metadata": {},
140+
"source": [
141+
"## Supervised Finetuning (SFT)"
142+
]
143+
},
144+
{
145+
"cell_type": "markdown",
146+
"metadata": {},
147+
"source": [
148+
"Running the cell below will trigger a 10-step SFT on your TPU VM (v4-8, v5p-8, or v6e-4). However, we recommend using [XPK](https://github.com/AI-Hypercomputer/maxtext/blob/64d6d9b425e78dde94c37a82bb13ba5606e74b1b/docs/guides/run_maxtext_via_xpk.md) to schedule a training workload on a TPU cluster for better performance. After the SFT, the result checkpoint will be saved to `BASE_OUTPUT_DIRECTORY`."
149+
]
150+
},
151+
{
152+
"cell_type": "code",
153+
"execution_count": null,
154+
"metadata": {},
155+
"outputs": [],
156+
"source": [
157+
"# Define SFT output directory\n",
158+
"BASE_OUTPUT_DIRECTORY=f\"gs://your-gcs-bucket/{MODEL_NAME}-sft\"\n",
159+
"PRE_TRAINED_MODEL_TOKENIZER=\"google/gemma-3-4b-it\"\n",
160+
"WORKLOAD_NAME=f\"{MODEL_NAME}-chartqa-sft\"\n",
161+
"STEPS=10\n",
162+
"PER_DEVICE_BATCH_SIZE=1\n",
163+
"\n",
164+
"!python -m MaxText.sft_trainer \\\n",
165+
" $MAXTEXT_REPO_ROOT/configs/sft-vision-chartqa.yml \\\n",
166+
" run_name=$WORKLOAD_NAME \\\n",
167+
" model_name=$MODEL_NAME \\\n",
168+
" tokenizer_path=$PRE_TRAINED_MODEL_TOKENIZER \\\n",
169+
" hf_access_token=$HF_TOKEN \\\n",
170+
" load_parameters_path=$MODEL_CHECKPOINT_PATH/0/items \\\n",
171+
" base_output_directory=$BASE_OUTPUT_DIRECTORY \\\n",
172+
" per_device_batch_size=$PER_DEVICE_BATCH_SIZE \\\n",
173+
" steps=$STEPS \\\n",
174+
" max_prefill_predict_length=1024 \\\n",
175+
" max_target_length=2048 \\\n",
176+
" checkpoint_period=1000 \\\n",
177+
" scan_layers=False \\\n",
178+
" async_checkpointing=True \\\n",
179+
" enable_checkpointing=True \\\n",
180+
" attention=dot_product \\\n",
181+
" max_num_images_per_example=1 \\\n",
182+
" dataset_type=hf profiler=xplane"
183+
]
184+
}
185+
],
186+
"metadata": {
187+
"accelerator": "TPU",
188+
"colab": {
189+
"gpuType": "V5E1",
190+
"provenance": []
191+
},
192+
"kernelspec": {
193+
"display_name": "python3.12",
194+
"language": "python",
195+
"name": "python3"
196+
},
197+
"language_info": {
198+
"codemirror_mode": {
199+
"name": "ipython",
200+
"version": 3
201+
},
202+
"file_extension": ".py",
203+
"mimetype": "text/x-python",
204+
"name": "python",
205+
"nbconvert_exporter": "python",
206+
"pygments_lexer": "ipython3",
207+
"version": "3.12.7"
208+
}
209+
},
210+
"nbformat": 4,
211+
"nbformat_minor": 0
212+
}

src/MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def vision_sft_preprocessing_pipeline(
5757
},
5858
remove_columns=image_column, # Drop the original image columns
5959
)
60+
image_column = "images"
6061

6162
dataset = dataset.select_columns(text_columns + [image_column])
6263
if image_column != "images":

0 commit comments

Comments
 (0)