Hi! I really appreciate your nice work and for sharing the code.
Here is a question that confuses me a little:
```python
dist = one_hot_labels[:, 1:].float() * log_probs[:, 1:]
example_loss_except_other, _ = dist.min(dim=-1)
per_example_loss = - example_loss_except_other.mean()
```
Why does taking the min of `dist` give the loss that excludes the "other" class here?
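To make sure I'm reading the snippet correctly, here is a tiny self-contained example of what I think it computes (the tensor values, batch size, and class count are all made up by me):

```python
import torch

# Hypothetical batch of 2 examples, 4 classes (class 0 = the "other" class).
# Made-up log-probabilities (all negative, as log-softmax outputs are).
log_probs = torch.tensor([[-0.1, -2.0, -0.5, -3.0],
                          [-1.5, -0.2, -2.5, -0.9]])
labels = torch.tensor([1, 3])
one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=4)

# Slicing [:, 1:] drops the "other" class (column 0).
dist = one_hot_labels[:, 1:].float() * log_probs[:, 1:]

# Each row of dist is zero everywhere except at the true label, where it
# holds a negative log-prob, so min(dim=-1) picks out exactly that entry.
# (If the true label were class 0, the row would be all zeros and the min
# would be 0, i.e. that example contributes no loss.)
example_loss_except_other, _ = dist.min(dim=-1)
per_example_loss = -example_loss_except_other.mean()

print(example_loss_except_other)  # tensor([-2.0000, -0.9000])
print(per_example_loss)           # tensor(1.4500)
```

So my understanding is that the min is just a trick to select the true-class log-probability, but please correct me if I'm wrong.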
Thanks again for your great work, and I look forward to your reply!