
Commit 77dd8af

fix: Remove use_cache and update ReadMe. (#531)

* Update readme and fix

  Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* Assign use_cache for vision model

  Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
1 parent fa070a8


2 files changed (+15, -8 lines)


README.md

Lines changed: 6 additions & 6 deletions
@@ -916,12 +916,12 @@ For information on supported dataset formats and how to tune a vision-language m
 
 ? May be supported, but not tested
 
-Model Name & Size | Model Architecture | Full Finetuning |
--------------------- | ---------------- | --------------- |
-Llama 3.2-11B Vision | MllamaForConditionalGeneration | ✅* |
-Llava 1.5-7B | LlavaForConditionalGeneration | ✅* |
-Granite 3.1-2B Vision | LlavaNextForConditionalGeneration | ✅* |
-Llava Mistral 1.6-7B | LlavaNextForConditionalGeneration | ✅* |
+Model Name & Size | Model Architecture | LoRA Tuning | Full Finetuning |
+-------------------- | ---------------- | --------------- | --------------- |
+Llama 3.2-11B Vision | MllamaForConditionalGeneration | ✅* | ✅* |
+Llava 1.5-7B | LlavaForConditionalGeneration | ✅* | 🚫 |
+Granite 3.1-2B Vision | LlavaNextForConditionalGeneration | ✅* | 🚫 |
+Llava Mistral 1.6-7B | LlavaNextForConditionalGeneration | ✅* | 🚫 |
 
 (*) - Supported with `fms-hf-tuning` v2.8.0 or later.
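The updated table records that LoRA tuning now covers all four vision models, while full finetuning remains supported only for Llama 3.2-11B Vision. As a rough sketch of what attaching a LoRA adapter to one of these models involves, using the Hugging Face `peft` library (the checkpoint name and `target_modules` are illustrative assumptions; `fms-hf-tuning` drives this through its own `sft_trainer.py` entry point rather than this raw API):

```python
# Illustrative sketch: attaching a LoRA adapter to one of the listed vision
# models via the Hugging Face peft library. The target_modules names are
# assumptions and vary by architecture; fms-hf-tuning configures this itself.
from peft import LoraConfig, get_peft_model
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

lora_config = LoraConfig(
    r=8,                                  # low-rank adapter dimension
    lora_alpha=16,                        # adapter scaling factor
    target_modules=["q_proj", "v_proj"],  # attention projections (assumed names)
    lora_dropout=0.05,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```

Because only the small adapter matrices are trained, LoRA can be marked supported (✅*) even for architectures whose full finetuning is 🚫.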

tuning/sft_trainer.py

Lines changed: 9 additions & 2 deletions
@@ -237,9 +237,16 @@ def train(
         attn_implementation="flash_attention_2"
         if model_args.use_flash_attn
         else None,
-        # avoid warning that use_cache is incompatible with gradient checkpointing
-        use_cache=(not train_args.gradient_checkpointing),
     )
+    try:
+        if "use_cache" in model.language_model.config:
+            # avoid warning that use_cache is incompatible with gradient checkpointing
+            model.language_model.config.use_cache = (
+                not train_args.gradient_checkpointing
+            )
+    except AttributeError as e:
+        # When the model doesn't have the use_cache attribute
+        logger.warning("Couldn't update use_cache for vision model: %s", e)
 
     processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)
     tokenizer = processor.tokenizer
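Rather than passing `use_cache` to `from_pretrained`, the fix assigns it on the nested language model's config after loading, with an `AttributeError` guard for models that lack a `language_model` submodule. Below is a minimal sketch of the same pattern in isolation, assuming a LLaVA-style composite model (the checkpoint name is illustrative, and this is not the repository's full `train()` logic):

```python
# Minimal sketch of the pattern above, assuming a LLaVA-style composite model:
# toggle use_cache on the nested text backbone's config after loading, instead
# of passing use_cache= to from_pretrained on the top-level composite model.
from transformers import LlavaForConditionalGeneration

gradient_checkpointing = True  # mirrors train_args.gradient_checkpointing

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

try:
    # KV caching is pointless under gradient checkpointing (activations are
    # recomputed in the backward pass), so disable it to avoid the warning.
    model.language_model.config.use_cache = not gradient_checkpointing
except AttributeError:
    # Text-only models have no .language_model submodule; nothing to update.
    pass
```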
