     MllamaForConditionalGeneration
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer
+
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
@@ -129,7 +130,6 @@ def main(**kwargs):
         model = MllamaForConditionalGeneration.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
-            #use_cache=use_cache,
             attn_implementation="sdpa" if train_config.use_fast_kernels else None,
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -146,7 +146,7 @@ def main(**kwargs):
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
         )
-
+    print(model)
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -189,11 +189,7 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)

         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer, CLIPEncoderLayer])
-        # if is_vision:
-        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
-        # else:
-        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
         print("FSDP is enabled", my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
@@ -222,7 +218,8 @@ def main(**kwargs):
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
+    print("-------------------")
+    print("FSDP model", model)
     dataset_config = generate_dataset_config(train_config, kwargs)
     if is_vision:
         dataset_processer = processor
@@ -248,7 +245,10 @@ def main(**kwargs):
         print(f"--> Validation Set Length = {len(dataset_val)}")

     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
     print("length of dataset_train", len(dataset_train))
@@ -268,7 +268,10 @@ def main(**kwargs):
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

         val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
         if custom_data_collator:
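
For context, a minimal standalone sketch of the sharding granularity the new wrapping policy asks for, assuming the generic `transformer_auto_wrap_policy` from `torch.distributed.fsdp.wrap` rather than the `fsdp_auto_wrap_policy` helper used in the diff (which is not reproduced here); the layer-class set mirrors the list passed above.

```python
# Sketch only: build an FSDP auto-wrap policy that treats each listed
# transformer block class as its own FSDP unit, covering both the Mllama text
# decoder layers (self- and cross-attention) and the vision encoder layers.
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaSelfAttentionDecoderLayer,
    MllamaCrossAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        LlamaDecoderLayer,
        MllamaSelfAttentionDecoderLayer,
        MllamaCrossAttentionDecoderLayer,
        MllamaVisionEncoderLayer,
    },
)
# `wrap_policy` would then be passed to FSDP(..., auto_wrap_policy=wrap_policy).
```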