3 changes: 2 additions & 1 deletion .gitignore
@@ -140,4 +140,5 @@ autogen/
#fp8
ops/csrc/fp8/deep_gemm/include/cutlass
ops/csrc/fp8/deep_gemm/include/cute
.ccls-cache
.ccls-cache
llm/log
1 change: 1 addition & 0 deletions docs/zh/llm/benchmark/rl/README.md
37 changes: 37 additions & 0 deletions llm/config/llama/dislora_argument.json
@@ -0,0 +1,37 @@
{
"model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
"dataset_name_or_path": "/home/bjh/bjh/Dislora/cs_5_lite",
"output_dir": "./checkpoints/dislora_ckpts_3",
"dislora": true,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 5,
"num_train_epochs": 1,
"learning_rate": 2e-05,
"lr_scheduler_type": "linear",
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "no",
"save_strategy": "steps",
"save_steps": 500,
"src_length": 256,
"max_length": 512,
"bf16": true,
"do_train": true,
"do_eval": false,
"disable_tqdm": false,
"load_best_model_at_end": false,
"eval_with_do_generation": false,
"recompute": false,
"save_total_limit": 5,
"fp16_opt_level": "O2",
"sharding": "stage3",
"zero_padding": false,
"use_flash_attention": false,
"unified_checkpoint": false,
"dislora_rank": 8,
"dislora_dropout": 0.05,
"target_modules": [".*q_proj.*", ".*v_proj.*", ".*k_proj.*", ".*o_proj.*"],
"s_tsd": 8,
"ortho_lambda": 1.0,
"prefer_small_sigma": true
}
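Presumably this config is launched the same way as the other presets under llm/config in this repo, e.g. `python run_finetune.py llm/config/llama/dislora_argument.json` (path shown relative to the repository root; adjust if running from inside llm/).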
36 changes: 36 additions & 0 deletions llm/config/qwen/dislora_argument.json
@@ -0,0 +1,36 @@
{
"model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
"dataset_name_or_path": "/home/bjh/bjh/Dislora/cs_5_lite",
"output_dir": "./checkpoints/dislora_ckpts",
"dislora": true,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 1,
"num_train_epochs": 1,
"learning_rate": 2e-05,
"lr_scheduler_type": "linear",
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "no",
"save_strategy": "steps",
"save_steps": 500,
"src_length": 256,
"max_length": 512,
"bf16": true,
"do_train": true,
"do_eval": false,
"disable_tqdm": false,
"load_best_model_at_end": false,
"eval_with_do_generation": false,
"recompute": false,
"save_total_limit": 5,
"fp16_opt_level": "O2",
"sharding": "stage3",
"zero_padding": false,
"use_flash_attention": false,
"unified_checkpoint": false,
"dislora_rank": 8,
"dislora_dropout": 0.05,
"s_tsd": 8,
"ortho_lambda": 1.0,
"prefer_small_sigma": true
}
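For reference, the hyperparameters in these two configs map directly onto the DisLoRAConfig that create_peft_model builds later in this diff. A minimal sketch of that mapping, assuming an already-loaded base `model` and the paddlenlp.peft exports added by this PR (dash_flag is shown with the static fallback of 50; in training it is normally derived from the dataset size, as in run_finetune.py below):

from paddlenlp.peft import DisLoRAConfig, DisLoRAModel

dislora_config = DisLoRAConfig(
    # from the llama config; the qwen config omits target_modules and
    # create_peft_model then falls back to get_lora_target_modules(model)
    target_modules=[".*q_proj.*", ".*v_proj.*", ".*k_proj.*", ".*o_proj.*"],
    r=8,                          # dislora_rank
    dislora_alpha=12,             # 1.5 * dislora_rank, as in create_peft_model
    dislora_dropout=0.05,
    dtype="bfloat16",             # "bf16": true
    base_model_name_or_path="Qwen/Qwen2.5-7B-Instruct",
    s_tsd=8,
    dash_flag=50,                 # fallback value; overridden by the dynamic calculation
    ortho_lambda=1.0,
)
model = DisLoRAModel(model, dislora_config)  # wraps the already-loaded base model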
89 changes: 87 additions & 2 deletions llm/run_finetune.py
@@ -31,6 +31,8 @@
)
from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
from paddlenlp.peft import (
DisLoRAConfig,
DisLoRAModel,
LoKrConfig,
LoKrModel,
LoRAConfig,
@@ -311,6 +313,15 @@ def neft_post_hook(module, input, output):
tokenizer.pad_token_id = tokenizer.eos_token_id

train_ds, dev_ds, test_ds = create_dataset(data_args, training_args)

train_dataset_size = None
if train_ds is not None and model_args.dislora:
train_dataset_size = get_dataset_size(train_ds)
if train_dataset_size is not None:
logger.info(f"Original training dataset size: {train_dataset_size}")
else:
logger.warning("Unable to determine training dataset size for dynamic dash_flag calculation")

# TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later.
if training_args.resume_from_checkpoint is not None and data_args.lazy:
logger.info(
@@ -377,7 +388,9 @@ def neft_post_hook(module, input, output):
if eval_zero_padding and test_ds is not None:
test_ds = intoken_dataset(test_ds, tokenizer=tokenizer, max_length=data_args.max_length)

model = create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers)
model = create_peft_model(
model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size
)

def compute_metrics_do_generation(eval_preds):
rouge1 = Rouge1()
@@ -441,6 +454,10 @@ def compute_metrics_do_generation(eval_preds):
return_attention_mask=not model_args.flash_mask,
pad_to_multiple_of=data_args.pad_to_multiple_of,
)

if model_args.dislora and hasattr(model_args, "ortho_lambda"):
training_args.dislora_ortho_lambda = model_args.ortho_lambda

trainer = SFTTrainer(
model=model,
args=training_args,
@@ -531,7 +548,9 @@ def save_to_aistudio(model_args, training_args, trainer):
)


def create_peft_model(model_args, reft_args, training_args, dtype, model_config, model, reft_layers):
def create_peft_model(
model_args, reft_args, training_args, dtype, model_config, model, reft_layers, train_dataset_size
):
if model_args.prefix_tuning:
if training_args.pipeline_parallel_degree > 1:
raise NotImplementedError("Prefix tuning is not implemented for pipeline parallelism.")
@@ -606,6 +625,53 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
else:
model = LoKrModel.from_pretrained(model=model, lokr_path=model_args.lokr_path)

if model_args.dislora:
# Calculate dynamic dash_flag based on training configuration
if train_dataset_size is not None and training_args.do_train:
# Dynamic dash_flag: len(train_data) * num_epochs // (effective_batch_size * 3), i.e. roughly one third of the total optimizer steps
effective_batch_size = (
training_args.per_device_train_batch_size
* training_args.gradient_accumulation_steps
* training_args.dataset_world_size # Consider data parallel
)
calculated_dash_flag = (train_dataset_size * training_args.num_train_epochs) // (effective_batch_size * 3)

# Use calculated value if it's reasonable, otherwise fall back to model_args
if calculated_dash_flag > 0:
dash_flag = calculated_dash_flag
logger.info(
f"Calculated dynamic dash_flag: {dash_flag} based on dataset size: {train_dataset_size}, "
f"epochs: {training_args.num_train_epochs}, effective batch size: {effective_batch_size}"
)
else:
dash_flag = model_args.dash_flag
logger.warning(
f"Calculated dash_flag was {calculated_dash_flag}, using model_args.dash_flag: {dash_flag}"
)
else:
dash_flag = getattr(model_args, "dash_flag", 50)
if train_dataset_size is None:
logger.info(
f"Unable to calculate dynamic dash_flag (dataset size unknown), using configured dash_flag: {dash_flag}"
)
else:
logger.info(f"Not in training mode, using configured dash_flag: {dash_flag}")
if model_args.dislora_path is None:
dislora_config = DisLoRAConfig(
target_modules=model_args.target_modules
if model_args.target_modules
else get_lora_target_modules(model),
r=model_args.dislora_rank,
dislora_alpha=1.5 * model_args.dislora_rank,
dislora_dropout=model_args.dislora_dropout,
dtype=dtype,
base_model_name_or_path=model_args.model_name_or_path,
s_tsd=model_args.s_tsd,
dash_flag=dash_flag, # Use calculated dash_flag
ortho_lambda=model_args.ortho_lambda,
)
model = DisLoRAModel(model, dislora_config)

if model_args.reft:
intervention_dtype = dtype
intervention_params = {
@@ -745,5 +811,24 @@ def create_dataset(data_args, training_args):
return train_ds, dev_ds, test_ds


def get_dataset_size(dataset):
"""Get the size of a dataset, handling both lazy and regular datasets"""
if dataset is None:
return None

try:
if hasattr(dataset, "__len__"):
return len(dataset)
elif hasattr(dataset, "_length"):
return dataset._length
else:
# For lazy datasets, we might need to iterate once to count
logger.warning("Unable to determine dataset size directly for lazy loading dataset")
return None
except Exception as e:
logger.warning(f"Error getting dataset size: {e}")
return None


if __name__ == "__main__":
main()
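
A quick sanity check of the get_dataset_size helper added above (hypothetical toy dataset, not part of the PR; assumes it runs in the same module so get_dataset_size is in scope):

class _ToyDataset:
    """Stand-in for a map-style dataset that exposes __len__."""
    def __init__(self, n):
        self._n = n
    def __len__(self):
        return self._n

assert get_dataset_size(_ToyDataset(3000)) == 3000  # sized dataset: returns len(dataset)
assert get_dataset_size(None) is None               # missing dataset: returns None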