Two-stage grads - implementing DETR #7446
Unanswered
sholtodouglas asked this question in Q&A
I'm implementing DETR, and a key part of the loss function is a linear sum assignment between the model's outputs and the labels (effectively sorting the labels to best match the generated outputs, since the labels have no inherent order). This matching is non-differentiable and data-dependent, so it would need to be retraced every step. When training on TPUs, the best solution I arrived at (which could well be blindly wrong) was to run the forward pass on the TPU cores, move the outputs back to the CPU and do the matching there (distributed across 8 CPUs with Ray), and then run the loss calculation back on the TPU.
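Concretely, the matching stage looks roughly like this (a simplified, toy-data sketch: `match_batch` is the part that gets farmed out to the Ray workers, and I'm assuming the cost matrices are padded to square so every prediction gets a target):

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import linear_sum_assignment

def match_batch(cost_batch):
    # cost_batch: [batch, num_queries, num_targets] cost matrices, pulled back to host.
    # This per-example loop is what gets distributed over the 8 CPUs with Ray.
    matches = []
    for cost in np.asarray(cost_batch):
        _, col_idx = linear_sum_assignment(cost)      # Hungarian matching, host only
        matches.append(col_idx.astype(np.int32))
    return np.stack(matches)

# Toy example: 2 images, 4 queries, 4 (padded) targets.
# In the real pipeline the cost matrix comes from the pmapped forward pass.
cost = np.random.rand(2, 4, 4)
col_indices = match_batch(cost)                       # [2, 4]

# Back on device: re-order the labels to line up with the predictions,
# then feed the sorted labels into the loss on the TPU.
labels = jnp.arange(2 * 4 * 3).reshape(2, 4, 3)       # dummy [batch, targets, label_dim]
sorted_labels = jnp.take_along_axis(labels, jnp.asarray(col_indices)[..., None], axis=1)
```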
Now, this doesn't work on its own, because it only gives the gradients w.r.t. the outputs, not the original parameters! However, we can't wrap the entire function in a grad call and pmap it either, because the matching process won't run on device (nor does it seem to work to wrap the matching in a no-jit call).
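In principle the output-gradients could be pulled back to the parameters by keeping the forward pass's vjp around, something like the toy, self-contained sketch below (the real model and DETR loss would slot in where `forward` and the L2 stand-in are), but that means keeping the forward residuals alive across the host round-trip, which is awkward under pmap:

```python
import jax
import jax.numpy as jnp

# Toy stand-ins so the structure runs end to end: a linear "model" and an L2 "loss".
def forward(params, images):
    return images @ params['w'] + params['b']          # [batch, num_queries, dim]

def detr_loss(outputs, sorted_labels):
    return jnp.mean((outputs - sorted_labels) ** 2)

params = {'w': jnp.ones((8, 4)), 'b': jnp.zeros(4)}
images = jnp.ones((2, 5, 8))                           # [batch, num_queries, feat]

# Stage 1: forward pass, keeping the pullback so gradients w.r.t. the outputs
# can later be mapped back onto the parameters.
outputs, pullback = jax.vjp(lambda p: forward(p, images), params)

# Stage 2: host-side matching happens here, producing the sorted labels.
sorted_labels = jnp.zeros_like(outputs)                # pretend these came from matching

# Stage 3: gradient of the loss w.r.t. the outputs only...
out_grads = jax.grad(detr_loss)(outputs, sorted_labels)

# ...then chained through the pullback to get gradients w.r.t. the params.
(param_grads,) = pullback(out_grads)
```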
At the moment, the best solution appears to be re-running the forward pass and the loss function together with the now-sorted labels (and the same dropout seed), but this whole process is starting to feel very inefficient, so I thought I'd sense-check myself here. Is there a better way to do this? (The original PyTorch implementation just calls .backward() on the loss, which propagates back through the original forward pass.)
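In other words, something like this toy sketch (again with stand-ins for the real model and loss; the key point is passing the same dropout rng to both passes so the matching stays consistent with the differentiated outputs):

```python
import jax
import jax.numpy as jnp

def forward(params, images, rng):
    drop = jax.random.bernoulli(rng, 0.9, images.shape)   # toy "dropout"
    return (images * drop) @ params['w']

def loss_fn(params, images, sorted_labels, rng):
    outputs = forward(params, images, rng)                 # same rng as the matching pass
    return jnp.mean((outputs - sorted_labels) ** 2)        # stand-in for the DETR loss

params = {'w': jnp.ones((8, 4))}
images = jnp.ones((2, 5, 8))
rng = jax.random.PRNGKey(0)

# Pass 1 (no grad): outputs -> host-side matching -> sorted labels.
outputs = forward(params, images, rng)
sorted_labels = jnp.zeros_like(outputs)                    # pretend these came from matching

# Pass 2: full forward + loss under one grad call, so gradients reach the params.
grads = jax.grad(loss_fn)(params, images, sorted_labels, rng)
```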
Thanks :)