Conversation
```python
assert temperature > 0, "Temperature must be positive"
assert 0 <= alpha <= 1, "Alpha must be in [0, 1]"
assert 0 <= alpha_soft <= 1, "alpha_soft must be in [0, 1]"
alpha_hard = 1.0 - alpha_soft
```
Shouldn't alpha_hard + alpha_soft + alpha_sim == 1.0? 🤔
That's a good question!
So here's my thinking: both the soft and hard KL-div losses are on the same scale, so one might argue it's sensible for their weights to sum to 1.0 (although, is that strictly necessary?).
The cosine distance term, on the other hand, is not comparable to the KL-div loss, so it would be a stretch to enforce a three-term weight sum of 1.0. Can you articulate why we would want to enforce a sum of 1.0?
> Can you articulate why we would want to enforce a sum of 1.0?

Mostly a convention: weighted sums have float factors in [0, 1] that sum to one. That way, the weights can be interpreted as relative contributions to the loss, rather than (arbitrary?) absolute weights.
Not a big deal though.
> That way, the weights can be interpreted as relative contributions to the loss

That would be nice, but can we really create that expectation if the scales differ between the KL-div and cosine losses? It may be misleading to actually create that expectation.
I will merge as is for now; future generations can revisit this and hopefully find something better than what I'm doing 🤞
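To make the convention above concrete, here is a minimal sketch of the weighting being discussed; the function and argument names are illustrative, not the actual code in this PR:

```python
import torch

def combine_distillation_losses(
    loss_soft: torch.Tensor,   # KL-div on temperature-softened logits
    loss_hard: torch.Tensor,   # loss against the hard labels
    loss_sim: torch.Tensor,    # cosine distance between embeddings
    alpha_soft: float,
    alpha_sim: float,
) -> torch.Tensor:
    # The two logit-based terms share a scale, so their weights sum to 1.0 ...
    assert 0 <= alpha_soft <= 1, "alpha_soft must be in [0, 1]"
    alpha_hard = 1.0 - alpha_soft
    # ... while the cosine term lives on a different scale and gets an
    # independent weight, hence no three-term sum-to-one constraint.
    return alpha_soft * loss_soft + alpha_hard * loss_hard + alpha_sim * loss_sim
```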
```python
if alpha_sim > 0:
    assert student_emb is not None and teacher_emb is not None
    hidden_state_sim = F.cosine_embedding_loss(
```
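The diff above is truncated at the `F.cosine_embedding_loss(` call. A plausible completion, in the spirit of DistilBERT's use of a target of ones (the exact arguments in this PR may differ), would be:

```python
import torch
import torch.nn.functional as F

def cosine_hidden_state_loss(student_emb: torch.Tensor, teacher_emb: torch.Tensor) -> torch.Tensor:
    # Both tensors are assumed to be (batch, d_model); a target of +1 asks the
    # loss to pull each student embedding towards its teacher counterpart.
    target = torch.ones(student_emb.size(0), device=student_emb.device)
    return F.cosine_embedding_loss(student_emb, teacher_emb, target)
```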
```python
if alpha_sim > 0 and self.d_model < self.teacher_d_model:
    self.d_model_proj: torch.nn.Module = torch.nn.Linear(
        # NOTE: we double the d_model bi-directionally of the teacher model
```
Nit: Perhaps a "why" for the doubling would be nice
Actually, that's probably not needed, hmm
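For reference, a minimal sketch of the kind of shape-matching projection the truncated line above sets up; the output dimension used here (`teacher_d_model`) is an assumption, since the "doubling" the NOTE refers to is not shown in the diff:

```python
import torch

class StudentWithProjection(torch.nn.Module):
    def __init__(self, d_model: int, teacher_d_model: int, alpha_sim: float):
        super().__init__()
        self.d_model = d_model
        self.teacher_d_model = teacher_d_model
        # Only needed when the student is narrower than the teacher and the
        # cosine term is active; projects student embeddings into the
        # teacher's embedding space so the distance is well-defined.
        if alpha_sim > 0 and d_model < teacher_d_model:
            self.d_model_proj: torch.nn.Module = torch.nn.Linear(d_model, teacher_d_model)
        else:
            self.d_model_proj = torch.nn.Identity()

    def project(self, student_emb: torch.Tensor) -> torch.Tensor:
        return self.d_model_proj(student_emb)
```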
Force-pushed from 8da00b2 to 9207270
In an effort to improve the loss/NT correlation, we explore an alternative to the 2-term soft logits + hard labels distillation (Hinton style) and incorporate a 3rd term: the cosine distance between the embeddings (roughly DistilBERT style).
We've opted not to materialize the embeddings, and instead embed the teacher model inference into the distillation itself.
DistilBERT assumes equal embedding dimensions (hereinafter `d_model`) between the teacher and student to compute the distance. We then proceed to experiment with a reduced `d_model` for the student model.

Code changes:

- `d_model` and number of blocks as wandb hyperparameters
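As a rough illustration of "embed the teacher model inference into the distillation itself": the teacher can be run inside the training step rather than having its outputs precomputed and stored. The `distillation_step` / `distill_loss_fn` names and the model output interface below are assumptions, not the PR's actual code:

```python
import torch

def distillation_step(student, teacher, batch, distill_loss_fn):
    # The teacher runs inference on the fly inside the training step instead
    # of reading pre-computed (materialized) logits/embeddings from disk.
    teacher.eval()
    with torch.no_grad():
        teacher_logits, teacher_emb = teacher(batch["inputs"])

    student_logits, student_emb = student(batch["inputs"])

    # distill_loss_fn combines the soft, hard and cosine terms as sketched earlier.
    return distill_loss_fn(
        student_logits, teacher_logits, batch["labels"], student_emb, teacher_emb
    )
```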