14
14
FullyShardedDataParallel as FSDP ,
15
15
ShardingStrategy
16
16
)
17
-
17
+ from torch .distributed .fsdp .wrap import (
18
+ always_wrap_policy ,
19
+ ModuleWrapPolicy ,
20
+ transformer_auto_wrap_policy ,
21
+ )
18
22
from torch .distributed .fsdp .fully_sharded_data_parallel import CPUOffload
19
23
from torch .optim .lr_scheduler import StepLR
20
24
from transformers import (
29
33
30
34
)
31
35
from transformers .models .llama .modeling_llama import LlamaDecoderLayer
32
-
36
+ from transformers . models . clip . modeling_clip import CLIPEncoder , CLIPEncoderLayer
33
37
from llama_recipes .configs import fsdp_config as FSDP_CONFIG
34
38
from llama_recipes .configs import train_config as TRAIN_CONFIG
35
39
from llama_recipes .configs import quantization_config as QUANTIZATION_CONFIG
@@ -121,11 +125,11 @@ def main(**kwargs):
121
125
bnb_config = quant_config .create_bnb_config (train_config .quantization )
122
126
123
127
# Load the pre-trained model and set up its configuration
124
- # use_cache = False if train_config.enable_fsdp else None
128
+ use_cache = False if train_config .enable_fsdp else None
125
129
model = LlavaNextForConditionalGeneration .from_pretrained (
126
130
train_config .model_name ,
127
131
quantization_config = bnb_config ,
128
- # use_cache=use_cache,
132
+ # use_cache=use_cache,
129
133
attn_implementation = "sdpa" if train_config .use_fast_kernels else None ,
130
134
device_map = "auto" if train_config .quantization and not train_config .enable_fsdp else None ,
131
135
torch_dtype = torch .float16 if train_config .use_fp16 else torch .bfloat16 ,
@@ -172,16 +176,25 @@ def main(**kwargs):
172
176
freeze_transformer_layers (model , train_config .num_freeze_layers )
173
177
174
178
mixed_precision_policy , wrapping_policy = get_policies (fsdp_config , rank )
175
- my_auto_wrapping_policy = fsdp_auto_wrap_policy (model , LlamaDecoderLayer )
176
-
179
+ my_auto_wrapping_policy = fsdp_auto_wrap_policy (model , [ CLIPEncoderLayer ] )
180
+ print ( "FSDP is enabled" , my_auto_wrapping_policy )
177
181
device_id = 0
178
182
if is_xpu_available ():
179
183
device_id = torch .xpu .current_device ()
180
184
elif torch .cuda .is_available ():
181
185
device_id = torch .cuda .current_device ()
186
+ # print(dir(model))
187
+ # for layer in model.named_children():
188
+ # print(f"Layer: {layer}")
189
+
190
+ # layernorm = model.CLIPVisionTransformer.CLIPEncoder.LayerNorm
191
+ # for name, param in layernorm.named_parameters():
192
+ # print(f"Parameter: {name}, Shape: {param.shape}, Dtype: {param.dtype}")
193
+ # exit()
182
194
model = FSDP (
183
195
model ,
184
- auto_wrap_policy = my_auto_wrapping_policy if train_config .use_peft else wrapping_policy ,
196
+ auto_wrap_policy = ModuleWrapPolicy ([CLIPEncoderLayer , LlamaDecoderLayer ]),
197
+ #auto_wrap_policy= my_auto_wrapping_policy, #if train_config.use_peft else wrapping_policy,
185
198
cpu_offload = CPUOffload (offload_params = True ) if fsdp_config .fsdp_cpu_offload else None ,
186
199
mixed_precision = mixed_precision_policy if not fsdp_config .pure_bf16 else None ,
187
200
sharding_strategy = fsdp_config .sharding_strategy ,
@@ -192,6 +205,7 @@ def main(**kwargs):
192
205
param_init_fn = (lambda module : module .to_empty (device = torch .device ("cuda" ), recurse = False ))
193
206
if train_config .low_cpu_fsdp and rank != 0 else None ,
194
207
)
208
+ #print(model)
195
209
if fsdp_config .fsdp_activation_checkpointing :
196
210
model .enable_input_require_grads ()
197
211
model .gradient_checkpointing_enable ()
@@ -205,6 +219,11 @@ def main(**kwargs):
205
219
dataset_config = generate_dataset_config (train_config , kwargs )
206
220
207
221
# Load and preprocess the dataset for training and validation
222
+ # dataset_train = get_preprocessed_dataset(
223
+ # processor,
224
+ # dataset_config,
225
+ # split="train",
226
+ # )
208
227
dataset_train = get_preprocessed_dataset (
209
228
processor ,
210
229
dataset_config ,
@@ -272,6 +291,7 @@ def main(**kwargs):
272
291
)
273
292
scheduler = StepLR (optimizer , step_size = 1 , gamma = train_config .gamma )
274
293
# Start the training process
294
+
275
295
results = train (
276
296
model ,
277
297
train_dataloader ,
0 commit comments