Skip to content

Commit ee204cc

Browse files
committed
working now
1 parent b566582 commit ee204cc

File tree

3 files changed

+41
-18
lines changed

3 files changed

+41
-18
lines changed

src/llama_recipes/finetuning.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
FullyShardedDataParallel as FSDP,
1515
ShardingStrategy
1616
)
17-
17+
from torch.distributed.fsdp.wrap import (
18+
always_wrap_policy,
19+
ModuleWrapPolicy,
20+
transformer_auto_wrap_policy,
21+
)
1822
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
1923
from torch.optim.lr_scheduler import StepLR
2024
from transformers import (
@@ -29,7 +33,7 @@
2933

3034
)
3135
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
32-
36+
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
3337
from llama_recipes.configs import fsdp_config as FSDP_CONFIG
3438
from llama_recipes.configs import train_config as TRAIN_CONFIG
3539
from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
@@ -121,11 +125,11 @@ def main(**kwargs):
121125
bnb_config = quant_config.create_bnb_config(train_config.quantization)
122126

123127
# Load the pre-trained model and set up its configuration
124-
#use_cache = False if train_config.enable_fsdp else None
128+
use_cache = False if train_config.enable_fsdp else None
125129
model = LlavaNextForConditionalGeneration.from_pretrained(
126130
train_config.model_name,
127131
quantization_config=bnb_config,
128-
# use_cache=use_cache,
132+
#use_cache=use_cache,
129133
attn_implementation="sdpa" if train_config.use_fast_kernels else None,
130134
device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
131135
torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -172,16 +176,25 @@ def main(**kwargs):
172176
freeze_transformer_layers(model, train_config.num_freeze_layers)
173177

174178
mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
175-
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
176-
179+
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [CLIPEncoderLayer])
180+
print("FSDP is enabled",my_auto_wrapping_policy)
177181
device_id = 0
178182
if is_xpu_available():
179183
device_id = torch.xpu.current_device()
180184
elif torch.cuda.is_available():
181185
device_id = torch.cuda.current_device()
186+
# print(dir(model))
187+
# for layer in model.named_children():
188+
# print(f"Layer: {layer}")
189+
190+
# layernorm = model.CLIPVisionTransformer.CLIPEncoder.LayerNorm
191+
# for name, param in layernorm.named_parameters():
192+
# print(f"Parameter: {name}, Shape: {param.shape}, Dtype: {param.dtype}")
193+
# exit()
182194
model = FSDP(
183195
model,
184-
auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
196+
auto_wrap_policy= ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer]),
197+
#auto_wrap_policy= my_auto_wrapping_policy, #if train_config.use_peft else wrapping_policy,
185198
cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
186199
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
187200
sharding_strategy=fsdp_config.sharding_strategy,
@@ -192,6 +205,7 @@ def main(**kwargs):
192205
param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
193206
if train_config.low_cpu_fsdp and rank != 0 else None,
194207
)
208+
#print(model)
195209
if fsdp_config.fsdp_activation_checkpointing:
196210
model.enable_input_require_grads()
197211
model.gradient_checkpointing_enable()
@@ -205,6 +219,11 @@ def main(**kwargs):
205219
dataset_config = generate_dataset_config(train_config, kwargs)
206220

207221
# Load and preprocess the dataset for training and validation
222+
# dataset_train = get_preprocessed_dataset(
223+
# processor,
224+
# dataset_config,
225+
# split="train",
226+
# )
208227
dataset_train = get_preprocessed_dataset(
209228
processor,
210229
dataset_config,
@@ -272,6 +291,7 @@ def main(**kwargs):
272291
)
273292
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
274293
# Start the training process
294+
275295
results = train(
276296
model,
277297
train_dataloader,

src/llama_recipes/utils/fsdp_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch.distributed._tensor.device_mesh import init_device_mesh
44
import os
55

6-
def fsdp_auto_wrap_policy(model, transformer_layer_name):
6+
def fsdp_auto_wrap_policy(model, transformer_layer_names):
77
import functools
88

99
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
@@ -16,16 +16,19 @@ def lambda_policy_fn(module):
1616
):
1717
return True
1818
return False
19-
19+
transformer_wrap_policies = []
2020
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
21-
transformer_wrap_policy = functools.partial(
22-
transformer_auto_wrap_policy,
23-
transformer_layer_cls=(
24-
transformer_layer_name,
25-
),
26-
)
27-
28-
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
21+
for transformer_layer_name in transformer_layer_names:
22+
23+
transformer_wrap_policy = functools.partial(
24+
transformer_auto_wrap_policy,
25+
transformer_layer_cls=(
26+
transformer_layer_name,
27+
),
28+
)
29+
transformer_wrap_policies.append(transformer_wrap_policy)
30+
policies = transformer_wrap_policies
31+
auto_wrap_policy = functools.partial(_or_policy, policies=policies)
2932
return auto_wrap_policy
3033

3134

src/llama_recipes/utils/train_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
358358
# Ensure no gradients are computed for this scope to save memory
359359
with torch.no_grad():
360360
# Forward pass and compute loss
361-
outputs = model(**batch)
361+
outputs = model(**batch,use_cache=False)
362362
loss = outputs.loss
363363
if train_config.save_metrics:
364364
val_step_loss.append(loss.detach().float().item())

0 commit comments

Comments
 (0)