26
26
BitsAndBytesConfig ,
27
27
LlamaForCausalLM ,
28
28
LlamaConfig ,
29
- AutoConfig ,
30
- AutoModel ,
31
- LlavaNextForConditionalGeneration ,
32
- LlavaNextProcessor
33
-
29
+ AutoProcessor ,
30
+ MllamaForConditionalGeneration
34
31
)
35
32
from transformers .models .llama .modeling_llama import LlamaDecoderLayer
36
33
from transformers .models .clip .modeling_clip import CLIPEncoder , CLIPEncoderLayer
@@ -126,20 +123,32 @@ def main(**kwargs):
126
123
127
124
# Load the pre-trained model and set up its configuration
128
125
use_cache = False if train_config .enable_fsdp else None
129
- model = LlavaNextForConditionalGeneration .from_pretrained (
126
+ if "11B" in train_config .model_name or "90B" in train_config .model_name :
127
+ is_vision = True
128
+ model = MllamaForConditionalGeneration .from_pretrained (
130
129
train_config .model_name ,
131
130
quantization_config = bnb_config ,
132
131
#use_cache=use_cache,
133
132
attn_implementation = "sdpa" if train_config .use_fast_kernels else None ,
134
133
device_map = "auto" if train_config .quantization and not train_config .enable_fsdp else None ,
135
134
torch_dtype = torch .float16 if train_config .use_fp16 else torch .bfloat16 ,
136
135
)
136
+ processor = AutoProcessor .from_pretrained (train_config .model_name if train_config .tokenizer_name is None else train_config .tokenizer_name )
137
+ processor .tokenizer .padding_side = 'right'
138
+ else :
139
+ model = LlamaForCausalLM .from_pretrained (
140
+ train_config .model_name ,
141
+ quantization_config = bnb_config ,
142
+ use_cache = use_cache ,
143
+ attn_implementation = "sdpa" if train_config .use_fast_kernels else None ,
144
+ device_map = "auto" if train_config .quantization and not train_config .enable_fsdp else None ,
145
+ torch_dtype = torch .float16 if train_config .use_fp16 else torch .bfloat16 ,
146
+ )
137
147
138
148
# Load the tokenizer and add special tokens
139
149
tokenizer = AutoTokenizer .from_pretrained (train_config .model_name if train_config .tokenizer_name is None else train_config .tokenizer_name )
140
150
tokenizer .pad_token_id = tokenizer .eos_token_id
141
- processor = LlavaNextProcessor .from_pretrained (train_config .model_name if train_config .tokenizer_name is None else train_config .tokenizer_name )
142
- processor .tokenizer .padding_side = 'right'
151
+
143
152
# If there is a mismatch between tokenizer vocab size and embedding matrix,
144
153
# throw a warning and then expand the embedding matrix
145
154
if len (tokenizer ) > model .get_input_embeddings ().weight .shape [0 ]:
@@ -183,18 +192,16 @@ def main(**kwargs):
183
192
device_id = torch .xpu .current_device ()
184
193
elif torch .cuda .is_available ():
185
194
device_id = torch .cuda .current_device ()
186
- # print(dir(model))
187
- # for layer in model.named_children():
188
- # print(f"Layer: {layer}")
189
-
190
- # layernorm = model.CLIPVisionTransformer.CLIPEncoder.LayerNorm
191
- # for name, param in layernorm.named_parameters():
192
- # print(f"Parameter: {name}, Shape: {param.shape}, Dtype: {param.dtype}")
193
- # exit()
195
+ if train_config .use_peft :
196
+ wrapping_policy = my_auto_wrapping_policy
197
+ else :
198
+ if is_vision :
199
+ wrapping_policy = ModuleWrapPolicy ([CLIPEncoderLayer , LlamaDecoderLayer ])
200
+ else :
201
+ wrapping_policy = ModuleWrapPolicy ([LlamaDecoderLayer ])
194
202
model = FSDP (
195
203
model ,
196
- auto_wrap_policy = ModuleWrapPolicy ([CLIPEncoderLayer , LlamaDecoderLayer ]),
197
- #auto_wrap_policy= my_auto_wrapping_policy, #if train_config.use_peft else wrapping_policy,
204
+ auto_wrap_policy = wrapping_policy ,
198
205
cpu_offload = CPUOffload (offload_params = True ) if fsdp_config .fsdp_cpu_offload else None ,
199
206
mixed_precision = mixed_precision_policy if not fsdp_config .pure_bf16 else None ,
200
207
sharding_strategy = fsdp_config .sharding_strategy ,
@@ -205,10 +212,9 @@ def main(**kwargs):
205
212
param_init_fn = (lambda module : module .to_empty (device = torch .device ("cuda" ), recurse = False ))
206
213
if train_config .low_cpu_fsdp and rank != 0 else None ,
207
214
)
208
- #print(model)
209
215
if fsdp_config .fsdp_activation_checkpointing :
210
216
model .enable_input_require_grads ()
211
- model .gradient_checkpointing_enable ()
217
+ # model.gradient_checkpointing_enable()
212
218
apply_fsdp_checkpointing (model )
213
219
elif not train_config .quantization and not train_config .enable_fsdp :
214
220
if is_xpu_available ():
@@ -217,23 +223,23 @@ def main(**kwargs):
217
223
model .to ("cuda" )
218
224
219
225
dataset_config = generate_dataset_config (train_config , kwargs )
226
+ if is_vision :
227
+ dataset_processer = processor
228
+ else :
229
+ dataset_processer = tokenizer
230
+
231
+ # Load and preprocess the dataset for training and validation
220
232
221
- # Load and preprocess the dataset for training and validation
222
- # dataset_train = get_preprocessed_dataset(
223
- # processor,
224
- # dataset_config,
225
- # split="train",
226
- # )
227
233
dataset_train = get_preprocessed_dataset (
228
- processor ,
234
+ dataset_processer ,
229
235
dataset_config ,
230
236
split = "train" ,
231
237
)
232
238
if not train_config .enable_fsdp or rank == 0 :
233
239
print (f"--> Training Set Length = { len (dataset_train )} " )
234
240
235
241
dataset_val = get_preprocessed_dataset (
236
- processor ,
242
+ dataset_processer ,
237
243
dataset_config ,
238
244
split = "test" ,
239
245
)
0 commit comments