Commit 2730bca

fix readme and fsdp logic
1 parent 3985d07 commit 2730bca

File tree

2 files changed: +7 −2 lines changed


recipes/quickstart/finetuning/finetune_vision_model.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,7 +1,7 @@
 ## Fine-Tuning Meta Llama Multi Modal Models recipe
 This recipe steps you through how to finetune a Llama 3.2 vision model on the OCR VQA task using the [OCRVQA](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron/viewer/ocrvqa?row=0) dataset.
 
-**Disclaimer** As our vision models already have a very good OCR ability, here we just use the OCRVQA dataset to demonstrate the steps needed for fine-tuning our vision models.
+**Disclaimer**: As our vision models already have a very good OCR ability, here we just use the OCRVQA dataset to demonstrate the steps needed for fine-tuning our vision models.
 
 ### Fine-tuning steps
 
```

src/llama_recipes/finetuning.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -187,7 +187,12 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
+        if is_vision:
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
+        else:
+            # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
```
