Hi all,
I want to mimic the behaviour of an online learning algorithm where my data is made of sequences $x_{1:T} = (x_1, ..., x_T)$. More precisely, I want to do stochastic gradient descent on some parameter $\theta$ in the form $\theta_{t+1} = \theta_t - \nabla_\theta L_{\theta_t}(x_{1:t+1})$, where $L_\theta$ is some objective function. In practice I don't have a recursion to express $L_{\theta_t}(x_{1:t+1})$ as a function of $L_{\theta_t}(x_{1:t})$ and $x_{t+1}$, so for now I recompute the entire loss on the new subsequence $x_{1:t+1}$ and take the gradient by autodifferentiation. I'm still interested in the kind of values of $\theta$ I get when there's a gradient step for every subsequence. In practice I have many sequences $(x_{1:T}^i)_{i \leq N}$.
Here's the relevant part of the code, where I work with minibatches $(x_{1:T}^j)_{j \in B}$:
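(This is a simplified sketch of the structure rather than my exact code; `loss_fn` stands in for `self.loss`, and the model/objective are placeholders.)

```python
import jax
import jax.numpy as jnp

def loss_fn(params, prefix):
    # Placeholder objective standing in for self.loss: it recomputes the
    # full loss on the prefix x_{1:t}, of shape (B, t, D).
    preds = jnp.einsum('btd,d->bt', prefix, params)
    return jnp.mean((preds - prefix[..., 0]) ** 2)

grad_fn = jax.grad(loss_fn)

def train_on_batch(params, batch, lr=1e-2):
    # batch has shape (B, T, D); one gradient step per prefix length t.
    T = batch.shape[1]
    for t in range(1, T + 1):
        prefix = batch[:, :t]  # the prefix shape changes at every iteration
        params = params - lr * grad_fn(params, prefix)
    return params

key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (8, 20, 4))   # B=8, T=20, D=4
params = train_on_batch(jnp.zeros(4), batch)
```

As far as I understand, under `jax.jit` every prefix length is a distinct input shape, so the gradient step gets retraced and recompiled once per value of $t$, which I assume is why the compile times blow up (see below).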
The code works, but:

- under JIT the compile times are unbearably long when $T$ gets larger (even for $T=20$ it's already hard to work with);
- without JIT the code runs out of the box, but very slowly.
I'm used to using `lax.scan` for fixed-sized inputs. I tried using `batch_up_to_timestep = jax.lax.dynamic_slice_in_dim(batch, 0, timestep, axis=1)` when `batch` is in the `carry` and `timestep` in the `x` of the operand for `scan`, but then I get an error that you can't index with a tracer object, and I understand why. From what I gathered on the forum there's no workaround to JIT that kind of code, because there's no way to reduce the operations inside the `for` loop to a single HLO call. So it's not a syntax issue, it's a lower-level limitation of any compiler. Can anyone confirm this, or am I still missing some simple trick? If not, I'm surprised that JAX is that slow whenever you can't write JIT-able code. Is there anything I'm missing that would make this run faster? Maybe I could manually JIT all operations inside `self.loss` that involve fixed-sized inputs, but I feel it's not going to be of any use for the `jax.grad` operation, which is my main interest.
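For concreteness, here is roughly the attempt (a minimal, self-contained reconstruction with made-up surrounding names; only the `dynamic_slice_in_dim` call is verbatim):

```python
import jax
import jax.numpy as jnp

def prefix_losses_scan(batch):
    # batch: (B, T, D). Scan over t = 1..T, trying to slice out x_{1:t}.
    T = batch.shape[1]

    def step(carry, timestep):
        # `carry` is the full batch; `timestep` comes from the scanned-over xs.
        # This is the line that breaks: the slice *size* argument of
        # dynamic_slice_in_dim has to be a static Python int, but inside scan
        # `timestep` is a tracer, so JAX refuses to build a slice whose shape
        # depends on it.
        batch_up_to_timestep = jax.lax.dynamic_slice_in_dim(carry, 0, timestep, axis=1)
        loss = jnp.mean(batch_up_to_timestep ** 2)   # placeholder for self.loss
        return carry, loss

    _, losses = jax.lax.scan(step, batch, jnp.arange(1, T + 1))
    return losses
```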
Thanks a lot in advance!