@@ -137,6 +137,7 @@ def main(**kwargs):
         processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
         processor.tokenizer.padding_side = 'right'
     else:
+        is_vision = False
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
@@ -188,23 +189,20 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [CLIPEncoderLayer])
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
+        # if is_vision:
+        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
+        # else:
+        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
-        if train_config.use_peft:
-            wrapping_policy = my_auto_wrapping_policy
-        else:
-            if is_vision:
-                wrapping_policy = ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer])
-            else:
-                wrapping_policy = ModuleWrapPolicy([LlamaDecoderLayer])
         model = FSDP(
             model,
-            auto_wrap_policy=wrapping_policy,
+            auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
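For readers following the wrapping change: below is a minimal, illustrative sketch of the policy selection the patched FSDP(...) call ends up performing. It uses the stock torch.distributed.fsdp.wrap.transformer_auto_wrap_policy as a stand-in for the repo's fsdp_auto_wrap_policy helper; the pick_auto_wrap_policy function and default_policy argument are hypothetical names introduced only for this sketch, not part of the change itself.

# Illustrative sketch only: approximates what fsdp_auto_wrap_policy(model, [...])
# builds, using the stock PyTorch FSDP wrapping helper.
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def pick_auto_wrap_policy(use_peft: bool, default_policy):
    """Mirror the selection in the FSDP(...) call above.

    PEFT runs wrap every LlamaDecoderLayer and CLIPEncoderLayer in its own
    FSDP unit; full-parameter runs keep the wrapping_policy returned by
    get_policies().
    """
    if use_peft:
        return functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={LlamaDecoderLayer, CLIPEncoderLayer},
        )
    return default_policy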