Skip to content

Commit 7544c3a

Browse files
michaelroyzen (Michael Royzen) and kashif
authored
Support sequence sampling in Liger Kernel and pass importance_sampling_level (#5190)
Co-authored-by: Michael Royzen <michaelroyzen@mac.mynetworksettings.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
1 parent 5cffd59 commit 7544c3a

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

tests/test_grpo_trainer.py

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2621,6 +2621,47 @@ def test_training_with_liger_grpo_kernel_and_peft(self, model_name):
2621 | 2621

2622 | 2622
release_memory(model, trainer)
2623 | 2623

2624+
@require_liger_kernel
def test_liger_grpo_kernel_importance_sampling(self):
    """GRPO training with the Liger kernel and sequence-level importance sampling.

    Checks that `importance_sampling_level="sequence"` is accepted together with
    `use_liger_kernel=True`, that the trainer instantiates a
    `LigerFusedLinearGRPOLoss`, and that a short training run updates every model
    parameter.
    """
    model_name = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"

    training_args = GRPOConfig(
        output_dir=self.tmp_dir,
        per_device_train_batch_size=3,  # matches num_generations below
        num_generations=3,
        use_liger_kernel=True,
        max_completion_length=self.max_length,
        importance_sampling_level="sequence",  # the level this test exercises
        report_to="none",
        logging_strategy="no",
    )

    model = AutoModelForCausalLM.from_pretrained(model_name, dtype="float32")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Ensure a pad token exists; fall back to EOS when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

    trainer = GRPOTrainer(
        model=model,
        reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
        args=training_args,
        train_dataset=self.train_dataset,
        eval_dataset=self.eval_dataset,
        processing_class=tokenizer,
    )
    # Imported lazily: liger_kernel is an optional dependency, guarded by the
    # @require_liger_kernel decorator above.
    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

    assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss)

    # Snapshot parameters before training so updates can be detected afterwards.
    previous_trainable_params = {n: param.clone() for n, param in model.named_parameters()}

    trainer.train()

    # Every parameter must differ from its snapshot, i.e. training produced updates.
    for n, param in previous_trainable_params.items():
        new_param = model.get_parameter(n)
        assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

    release_memory(model, trainer)
2664+
2624 | 2665
@pytest.mark.parametrize(
2625 | 2666
"model_name",
2626 | 2667
[

trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -547,10 +547,10 @@ def __init__(
547 | 547
raise NotImplementedError(
548 | 548
"Liger Kernels don't currently support masking token positions based on entropy."
549 | 549
)
550-
if self.use_liger_kernel and not self.importance_sampling_level == "token":
551-
raise NotImplementedError(
552-
"Liger Kernels currently only support token-level importance sampling. Please set"
553-
"`importance_sampling_level` to 'token'."
550+
if self.use_liger_kernel and self.importance_sampling_level not in ("token", "sequence"):
551+
raise ValueError(
552+
f"Unknown importance sampling level: {self.importance_sampling_level}. "
553+
"Possible values are 'token' and 'sequence'."
554 | 554
)
555 | 555

556 | 556
# Datasets
@@ -679,6 +679,7 @@ def cast_outputs_to_original_dtype(module, args, output):
679 | 679
use_ref_model=self.beta != 0.0,
680 | 680
loss_type=self.loss_type,
681 | 681
max_completion_length=self.max_completion_length,
682+
importance_sampling_level=self.importance_sampling_level,
682 | 683
)
683 | 684

684 | 685
# Initialize the metrics

0 commit comments

Comments (0)