Any suggestions on CTC loss implementations? #7380
PeterZhizhin asked this question in Q&A
I was looking for a CTC loss implementation in JAX and found issue #5411, as well as this Flax discussion: google/flax#843
Since there was no public JAX implementation, I decided to write my own. I started from a PyTorch implementation that seemed reasonable: https://github.com/vadimkantorov/ctc
Here is how I adapted it for JAX: https://gist.github.com/PeterZhizhin/572782917d8a9c5d73ff4187c87d1c5e
The implementation produces results identical to the built-in PyTorch `ctc_loss` function in both the forward and backward passes. However, it is very slow compared to the built-in function: a forward plus backward pass takes 80-90 ms, even with the main loop expressed as `jax.lax.scan` and the whole function wrapped in `@jax.jit`. The built-in PyTorch function takes only 1.5 ms for a forward and backward pass.
If we look at how PyTorch implements the loss, we can see that it calls a cuDNN function: https://github.com/pytorch/pytorch/blob/9a08e87d8bbb9caa31381b838b10e2b6dcc5e4ec/aten/src/ATen/native/cudnn/LossCTC.cpp#L123-L136
I don't see how I could do the same in JAX. Is it possible to write a custom C++ kernel for JAX that calls the built-in cuDNN function? Any suggestions on how to do that, or on how to implement CTC faster than my version, would be greatly appreciated.