
Distill via embeddings #42

Merged
ravwojdyla merged 4 commits into main from rav-distill-via-emb on Aug 1, 2025

Conversation

Contributor

@ravwojdyla ravwojdyla commented Jul 29, 2025

  • In an effort to improve the loss/NT correlation, we explore an alternative to the 2-term soft-logits + hard-labels distillation (Hinton style) and incorporate a 3rd term: the cosine distance between the teacher and student embeddings (roughly DistilBERT style). See the sketch after this list.

    • TLDR - the loss/NT correlation is better with this style of distillation for the Caduceus model. We believe that with sufficient training time and student size, you can achieve useful levels of performance on the NT downstream tasks
      • Note in our experiments we limit ourselves to at most 24h training runs (Colab limitation)
        • https://wandb.ai/tabtablabs/caduceus_distill/runs/w6balygs
          • Experiment with half the number of blocks (height) and the same width as the teacher - roughly 50% of the teacher's size. After ~17 hrs, we get roughly 88% of the NT scores (ROCAUC and F1)
            • Correlation is better than without the 3rd term; mean (SD): -0.66 (0.11)
  • We’ve opted not to materialize the embeddings and instead embed the teacher-model inference in the distillation loop itself

    • There are a couple of reasons for this; all of them boil down to reducing (immediate) cost
  • DistilBERT assumes equal embedding dimensions (hereinafter d_model) for the teacher and student when computing the distance. We then experiment with a reduced d_model for the student model.

    • We tried model surgery; this largely failed due to Mamba SSM expectations on residual dimensions. With sufficient finesse this approach may work, but suffice it to say it’s non-obvious.
    • In the end, adding a projection layer (before the logits) and monkey-patching the head works, albeit in a rather hacky way.
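To make the 3-term objective concrete, here is a minimal sketch, assuming Hinton-style soft/hard terms plus a DistilBERT-style cosine term on per-sequence (pooled) embeddings, with a hypothetical projection module for when the student's d_model is smaller. The names, shapes, and default weights are illustrative, not the repo's actual API.

import torch
import torch.nn.functional as F

def distillation_loss(
    student_logits: torch.Tensor,   # (B, L, V)
    teacher_logits: torch.Tensor,   # (B, L, V)
    labels: torch.Tensor,           # (B, L) hard token labels
    student_emb: torch.Tensor,      # (B, D_student)
    teacher_emb: torch.Tensor,      # (B, D_teacher)
    proj: torch.nn.Module,          # e.g. nn.Linear(D_student, D_teacher); hypothetical
    temperature: float = 2.0,
    alpha_soft: float = 0.9,
    alpha_sim: float = 1.0,
) -> torch.Tensor:
    alpha_hard = 1.0 - alpha_soft

    # Soft term: KL between temperature-scaled teacher and student distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1).flatten(0, 1),
        F.log_softmax(teacher_logits / temperature, dim=-1).flatten(0, 1),
        log_target=True,
        reduction="batchmean",
    ) * temperature**2

    # Hard term: cross-entropy against the original labels.
    hard = F.cross_entropy(student_logits.flatten(0, 1), labels.flatten())

    # Similarity term: cosine distance between (projected) student and teacher embeddings.
    target = torch.ones(student_emb.shape[0], device=student_emb.device)
    sim = F.cosine_embedding_loss(proj(student_emb), teacher_emb, target)

    return alpha_soft * soft + alpha_hard * hard + alpha_sim * sim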

Code changes:

  • embed teacher inference in the distillation
  • support resuming an experiment from a checkpoint
    • this adjusts the data loader and resumes wandb logging (see the sketch below)
  • log d_model and the number of blocks as wandb hyperparameters
  • fix a bug with the validation schedule over multiple epochs
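For the checkpoint-resume change, a rough sketch of the intended behaviour, assuming a Lightning-style training loop and a saved wandb run id; the project name, helper, and paths are assumptions, not the repo's exact code.

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

def resume_run(model, datamodule, ckpt_path: str, wandb_run_id: str) -> None:
    # Re-attach to the existing wandb run so metrics continue on the same curves.
    logger = WandbLogger(project="caduceus_distill", id=wandb_run_id, resume="must")
    trainer = Trainer(logger=logger)
    # Lightning restores optimizer/scheduler state and epoch/step counters from the
    # checkpoint, which is what lets the data loader pick up where it left off.
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)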

@ravwojdyla ravwojdyla requested a review from yonromai July 29, 2025 13:05
assert temperature > 0, "Temperature must be positive"
assert 0 <= alpha <= 1, "Alpha must be in [0, 1]"
assert 0 <= alpha_soft <= 1, "alpha_soft must be in [0, 1]"
alpha_hard = 1.0 - alpha_soft

Contributor

Shouldn't alpha_hard + alpha_soft + alpha_sim == 1.0? 🤔

Contributor Author

That's a good question!

So here's my thinking: both the soft and hard KL-div losses are on the same scale, so one might argue it's sensible for their weights to sum to 1.0 (although, is that strictly necessary?).

The cosine-distance term, on the other hand, is not comparable to the KL-div loss, so it would be a stretch to enforce that the 3-term weights sum to 1.0. Can you articulate why we would want to enforce a sum of 1.0?

Contributor

Can you articulate why we would want to enforce sum to 1.0?

Mostly a convention: weighted sums use float factors in [0, 1] that sum to one. That way, the weights can be interpreted as relative contributions to the loss, rather than (arbitrary?) absolute weights.

Not a big deal though.

Contributor Author

That way, the weights can be interpreted as relative contributions to the loss

That would be nice, but can we really create that expectation if the scales are different between the KL-div and cosine losses? It may be misleading to actually create that expectation.

I will merge as is for now, future generations can revisit this and hopefully find something better than what I'm doing 🤞
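For reference, the convention discussed above would look roughly like the hypothetical helper below, which normalizes three explicit weights so they read as relative contributions; this is a sketch of the suggestion, not the merged code.

def combine_losses(soft_loss, hard_loss, sim_loss, weights=(0.6, 0.3, 0.1)):
    # Normalize so the three weights always sum to 1.0 and can be read as
    # relative contributions, regardless of the raw values passed in.
    total = sum(weights)
    w_soft, w_hard, w_sim = (w / total for w in weights)
    return w_soft * soft_loss + w_hard * hard_loss + w_sim * sim_loss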


if alpha_sim > 0:
assert student_emb is not None and teacher_emb is not None
hidden_state_sim = F.cosine_embedding_loss(

Contributor

neat!


if alpha_sim > 0 and self.d_model < self.teacher_d_model:
self.d_model_proj: torch.nn.Module = torch.nn.Linear(
# NOTE: we double the d_model bi-directionally of the teacher model

Contributor

Nit: Perhaps a "why" for the doubling would be nice

Contributor

Actually that's probably not needed, hmm
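For context, a hypothetical illustration of the "projection layer + monkey-patched head" idea from the PR description; the class, attribute path, and dimensions are assumptions, not the code in this hunk.

import torch

class ProjectedLMHead(torch.nn.Module):
    """Project student hidden states up to the teacher's width before the vocab head."""

    def __init__(self, d_student: int, d_teacher: int, vocab_size: int):
        super().__init__()
        self.proj = torch.nn.Linear(d_student, d_teacher)
        self.decoder = torch.nn.Linear(d_teacher, vocab_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The projected states now match the teacher's embedding width, so they
        # are also what the cosine similarity term would compare against.
        return self.decoder(self.proj(hidden_states))

# Hypothetical monkey patch; the real attribute path in the Caduceus model differs.
# model.lm_head = ProjectedLMHead(d_student, d_teacher, vocab_size)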

Contributor

@yonromai yonromai left a comment


🔥🔥🔥

@ravwojdyla ravwojdyla force-pushed the rav-distill-via-emb branch from 8da00b2 to 9207270 on July 29, 2025 15:53
@ravwojdyla ravwojdyla merged commit 382c534 into main Aug 1, 2025
2 checks passed
@ravwojdyla ravwojdyla deleted the rav-distill-via-emb branch August 1, 2025 11:37