|
24 | 24 | GRPOLossOutput, |
25 | 25 | MCAdvantage, |
26 | 26 | ) |
| 27 | +from torchrl.objectives.llm.sdpo import SDPOLoss, SDPOLossOutput |
27 | 28 | from torchrl.objectives.llm.sft import SFTLoss |
28 | 29 |
|
29 | 30 | _has_transformers = importlib.util.find_spec("transformers") is not None |
@@ -427,6 +428,305 @@ def test_cispo(self, mock_transformer_model): |
427 | 428 | ), f"clip_fraction out of range: {loss_vals.clip_fraction}" |
428 | 429 |
|
429 | 430 |
|
def _mock_data_sdpo(vocab_size: int, device: torch.device | str = "cpu") -> TensorDict:
    """Build a small synthetic rollout batch for exercising the SDPO loss.

    Constructs a one-element batch containing a chat history, its tokenized
    form, padding masks, a teacher-context entry, and random reward /
    log-prob tensors, mirroring what an LLM collector would produce.
    """
    from transformers import AutoTokenizer

    from tensordict import MetaData

    dev = torch.device(device)

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

    # Two-turn prompt (system + user) followed by one assistant reply.
    prompt_hist = History(
        role=["system", "user"],
        content=["You are a useful assistant.", "What is 2+2?"],
        batch_size=(2,),
        device=dev,
    )
    reply_hist = History(
        role=["assistant"],
        content=["2 + 2 = 4."],
        batch_size=(1,),
        device=dev,
    )
    combined = prompt_hist.extend(reply_hist, inplace=False)
    chat = ChatHistory(
        prompt=prompt_hist,
        response=reply_hist,
        full=combined,
        device=dev,
    )

    n_batch = 1

    # Broadcast the histories to the batch dimension.
    chat = chat.expand((n_batch,))
    chat_next = ChatHistory(prompt=combined, device=dev).expand((n_batch,))

    # Tokenize both the current step and the "next" step.
    toks = chat.to_tokens(tok)
    toks_next = chat_next.to_tokens(tok)

    padded_ids = toks.get(
        "full", as_padded_tensor=True, padding_side="left", padding_value=0
    )
    n_tokens = padded_ids.shape[-1]

    # Random per-token reward / done / log-prob tensors plus a full mask.
    rewards = torch.randn(n_batch, n_tokens, 1, device=dev)
    dones = torch.zeros(n_batch, n_tokens, 1, dtype=torch.bool, device=dev)
    lps = torch.randn_like(toks, dtype=torch.float32, device=dev)
    attn = torch.ones(n_batch, n_tokens, dtype=torch.bool, device=dev)

    mask_struct = Masks(
        all_attention_mask=attn,
        all_assistant_mask=None,
        padded=MetaData(True),
        device=dev,
    )

    # Teacher context: a feedback-augmented prompt for the self-teacher.
    # In real usage this is produced by the AddFeedbackContext transform.
    ctx = {
        "history": chat,
        "env_feedback": "The solution is correct.",
    }

    return TensorDict(
        {
            "history": chat,
            "tokens": toks % vocab_size,
            "masks": mask_struct,
            "teacher_context": ctx,
            "next": {
                "history": chat_next,
                "tokens": toks_next % vocab_size,
                "reward": rewards,
                "done": dones,
            },
            "log_probs": lps,
        },
        batch_size=(n_batch,),
    )
| 515 | + |
| 516 | + |
class TestSDPO:
    """Tests for the Self-Distillation Policy Optimization (SDPO) loss."""

    @staticmethod
    def _make_actor(mock_transformer_model, vocab_size, device):
        # Wrap a mock HF model as a non-generating, history-mode actor.
        model = mock_transformer_model(vocab_size=vocab_size, device=device)
        return TransformersWrapper(
            model,
            generate=False,
            pad_output=True,
            input_mode="history",
        )

    @pytest.mark.parametrize(
        "divergence_type", ["kl", "reverse_kl", "js"], ids=["kl", "reverse_kl", "js"]
    )
    def test_sdpo_basic(self, mock_transformer_model, divergence_type):
        """SDPO produces a well-formed, finite output for every divergence."""
        vocab_size = 1024
        device = torch.device("cpu")

        actor_network = self._make_actor(mock_transformer_model, vocab_size, device)

        loss_fn = SDPOLoss(
            actor_network,
            divergence_type=divergence_type,
            entropy_bonus=True,
            entropy_coeff=0.01,
        )

        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
        loss_vals = loss_fn(data)

        # Output must be the dedicated SDPO output structure.
        assert isinstance(
            loss_vals, SDPOLossOutput
        ), f"Expected SDPOLossOutput, got {type(loss_vals)}"

        # All expected fields must be present.
        for attr in ("loss_objective", "divergence", "kl_approx", "entropy", "loss_entropy"):
            assert hasattr(loss_vals, attr), f"Missing {attr}"

        # After reduction, every loss term is a scalar.
        for field in ("loss_objective", "divergence", "kl_approx"):
            value = getattr(loss_vals, field)
            assert value.shape == (), f"{field} should be scalar, got {value.shape}"

        # Values must be finite and the divergence non-negative.
        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
        assert torch.isfinite(loss_vals.divergence), "divergence is not finite"
        assert (
            loss_vals.divergence >= 0
        ), f"divergence should be non-negative: {loss_vals.divergence}"

    @pytest.mark.parametrize("topk", [None, 50, 100], ids=["full", "topk50", "topk100"])
    def test_sdpo_topk(self, mock_transformer_model, topk):
        """Top-K logit distillation yields a finite, non-negative divergence."""
        vocab_size = 1024
        device = torch.device("cpu")

        actor_network = self._make_actor(mock_transformer_model, vocab_size, device)

        loss_fn = SDPOLoss(
            actor_network,
            divergence_type="js",
            topk=topk,
        )

        loss_vals = loss_fn(_mock_data_sdpo(vocab_size=vocab_size, device=device))

        assert isinstance(loss_vals, SDPOLossOutput)
        assert torch.isfinite(loss_vals.loss_objective)
        assert loss_vals.divergence >= 0

    def test_sdpo_ema_teacher(self, mock_transformer_model):
        """EMA-teacher regularization initializes, runs, and survives updates."""
        vocab_size = 1024
        device = torch.device("cpu")

        actor_network = self._make_actor(mock_transformer_model, vocab_size, device)

        loss_fn = SDPOLoss(
            actor_network,
            divergence_type="js",
            use_ema_teacher=True,
            ema_decay=0.99,
        )

        # Construction must have populated the EMA parameter copy.
        assert loss_fn._ema_teacher_params is not None
        assert len(loss_fn._ema_teacher_params) > 0

        data = _mock_data_sdpo(vocab_size=vocab_size, device=device)
        loss_vals = loss_fn(data)

        assert isinstance(loss_vals, SDPOLossOutput)
        assert torch.isfinite(loss_vals.loss_objective)

        # A forward pass must still succeed after an EMA update step.
        loss_fn.update_ema_teacher()
        loss_vals_after = loss_fn(data)
        assert torch.isfinite(loss_vals_after.loss_objective)

    def test_sdpo_no_entropy(self, mock_transformer_model):
        """With entropy_bonus disabled, entropy fields come back as None."""
        vocab_size = 1024
        device = torch.device("cpu")

        actor_network = self._make_actor(mock_transformer_model, vocab_size, device)

        loss_fn = SDPOLoss(
            actor_network,
            divergence_type="js",
            entropy_bonus=False,
        )

        loss_vals = loss_fn(_mock_data_sdpo(vocab_size=vocab_size, device=device))

        assert isinstance(loss_vals, SDPOLossOutput)
        assert torch.isfinite(loss_vals.loss_objective)
        assert loss_vals.entropy is None
        assert loss_vals.loss_entropy is None
| 675 | + |
| 676 | + |
class TestAddFeedbackContext:
    """Tests for the AddFeedbackContext transform."""

    def test_add_feedback_direct(self):
        """Direct mode: feedback alone yields a teacher_context entry."""
        from torchrl.envs.llm.transforms.feedback import AddFeedbackContext

        transform = AddFeedbackContext()

        rollout = TensorDict(
            {
                "query": "What is 2+2?",
                ("text", "response"): "The answer is 5.",
                "env_feedback": "Wrong answer. The correct answer is 4.",
                ("next", "reward"): torch.tensor([0.0]),
                ("next", "done"): torch.tensor([True]),
            },
            batch_size=(),
        )

        result = transform(rollout)

        # The transform must attach a non-empty teacher_context.
        assert "teacher_context" in result.keys()
        assert result.get("teacher_context") is not None

    def test_add_feedback_with_success(self):
        """A successful rollout is folded into the teacher context too."""
        from torchrl.envs.llm.transforms.feedback import AddFeedbackContext

        transform = AddFeedbackContext()

        rollout = TensorDict(
            {
                "query": "What is 2+2?",
                ("text", "response"): "The answer is 5.",
                "env_feedback": "Wrong answer.",
                "_successful_rollout": "The answer is 4.",
                ("next", "reward"): torch.tensor([0.0]),
                ("next", "done"): torch.tensor([True]),
            },
            batch_size=(),
        )

        result = transform(rollout)

        # teacher_context should exist and carry both solution and feedback.
        assert "teacher_context" in result.keys()
        assert result.get("teacher_context") is not None
| 728 | + |
| 729 | + |
430 | 730 | class TestSFT: |
431 | 731 | @pytest.fixture(scope="class") |
432 | 732 | def data(self): |
|
0 commit comments