
Commit 9c7a5b4

fix AutoModel and bump transformers version to 4.45 (meta-llama#686)

1 parent c7c229d
File tree

3 files changed: 5 additions & 3 deletions

.watchman-cookie-devgpu003.cco3.facebook.com-3137776-2746

Whitespace-only changes.

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ black[jupyter]
 datasets
 fire
 peft
-transformers>=4.43.1
+transformers>=4.45.1
 sentencepiece
 py7zr
 scipy
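
The pin moves from 4.43.1 to 4.45.1 because the 4.45 release line of transformers is the one that ships the Mllama (Llama 3.2 vision) model classes imported in finetuning.py below. As a minimal sketch, not part of this commit and assuming the packaging library is installed, a script could fail fast on an older install:

# Minimal sketch (not from this commit): fail fast if the installed
# transformers is older than the new pin.
from importlib.metadata import version
from packaging.version import Version

if Version(version("transformers")) < Version("4.45.1"):
    raise RuntimeError("transformers>=4.45.1 is required; "
                       "upgrade with: pip install 'transformers>=4.45.1'")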

src/llama_recipes/finetuning.py

Lines changed: 4 additions & 2 deletions
@@ -21,8 +21,8 @@
     AutoTokenizer,
     BitsAndBytesConfig,
     AutoProcessor,
+    LlamaForCausalLM,
     MllamaForConditionalGeneration,
-    AutoModel,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
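
The import swap matters because AutoModel maps a Llama checkpoint to the bare LlamaModel backbone, which has no language-modeling head and returns hidden states rather than next-token logits; LlamaForCausalLM adds the lm_head needed to compute the causal-LM loss during finetuning. A minimal sketch of the difference, not from the repo, with an illustrative checkpoint id:

# Minimal sketch (illustrative checkpoint id, not from this commit).
from transformers import AutoModel, LlamaForCausalLM

ckpt = "meta-llama/Llama-3.1-8B"  # assumption: any Llama-architecture checkpoint

backbone = AutoModel.from_pretrained(ckpt)         # LlamaModel: hidden states, no lm_head
lm_model = LlamaForCausalLM.from_pretrained(ckpt)  # adds lm_head, returns logits and loss

print(type(backbone).__name__)  # -> LlamaModel
print(type(lm_model).__name__)  # -> LlamaForCausalLM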
@@ -132,9 +132,11 @@ def main(**kwargs):
         )
         processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
         processor.tokenizer.padding_side='right'
+        model.supports_gradient_checkpointing = True
+        model.language_model.supports_gradient_checkpointing = True
     elif config.model_type == "llama":
         is_vision = False
-        model = AutoModel.from_pretrained(
+        model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
             use_cache=use_cache,
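
The two supports_gradient_checkpointing flags are forced to True presumably because PreTrainedModel.gradient_checkpointing_enable() raises a ValueError when a model reports no support, and the Mllama wrapper in transformers 4.45 does not advertise it even though its decoder layers can be checkpointed. A sketch of that pattern, assuming an illustrative mllama checkpoint and mirroring the diff above:

# Sketch of the workaround (illustrative checkpoint id; not a definitive
# reading of the commit's intent).
from transformers import MllamaForConditionalGeneration

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision"  # assumption: any mllama checkpoint
)

# gradient_checkpointing_enable() raises ValueError when a model reports
# supports_gradient_checkpointing = False, so force the flag on the wrapper
# and its inner language model before enabling.
model.supports_gradient_checkpointing = True
model.language_model.supports_gradient_checkpointing = True
model.gradient_checkpointing_enable()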

0 commit comments