Segment_sum + jit + grad #9762
I have a jagged array of inputs, which I call descriptors. The dimension of this array is (i, n_i, k). I would like to keep memory use and computation time low through the update loops. Below is a failed attempt at getting around the jagged array by flattening the first two dimensions into an array of shape (j, k), batching the predictions, and then summing over the indices specified by n_i to obtain a predictions array of length i. (Note: this does work without jit.)
I have a working jit version where I pad my jagged array with zeros, giving an array of dimension (i, max(n_i), k). The snippet below uses this array. While epochs with this version are twice as fast as the previous version without jit, n_i can vary from 2 up to 1000 and SGD/Adam never converges. I imagine there should be a way to get around padding the array. Maybe padding with sparsity?
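The original padded snippet isn't reproduced here. One thing worth checking with zero-padding is whether the padded rows are masked out of the per-structure sum: a model can still return a nonzero value for an all-zero descriptor, which would bias both the loss and its gradients. A minimal sketch of such a mask, where `predict`, the shapes, and all variable names are illustrative assumptions rather than the original code:

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Placeholder per-descriptor model (assumption): a single linear layer.
    return x @ params


@jax.jit
def padded_predictions(params, padded_desc, n_i):
    # padded_desc: (i, max_n, k) zero-padded descriptors; n_i: (i,) true counts.
    i, max_n, k = padded_desc.shape
    per_desc = predict(params, padded_desc.reshape(-1, k)).reshape(i, max_n)
    # Zero out the padded slots so they contribute neither to the prediction
    # nor to the gradient, even if the model output for a zero row is nonzero.
    mask = jnp.arange(max_n)[None, :] < n_i[:, None]
    return jnp.sum(per_desc * mask, axis=-1)  # (i,)


key = jax.random.PRNGKey(0)
params = jax.random.normal(key, (4,))
padded = jax.random.normal(key, (3, 5, 4))  # i=3 structures, max_n=5, k=4
counts = jnp.array([3, 2, 5])
print(padded_predictions(params, padded, counts))  # shape (3,)
```

Whether a missing mask is actually what prevents convergence here is speculation, but masking keeps the padded version mathematically equivalent to the jagged one.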
See the dummy example below of the first scenario, with jit commented out.
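The dummy example itself isn't reproduced here. For reference, a minimal sketch of the first scenario might look like the following, again with `predict`, the shapes, and the variable names as illustrative assumptions. The usual catch under `jit` is that `jax.ops.segment_sum` needs `num_segments` to be a static Python int so the output shape is known at trace time:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import ops


def predict(params, x):
    # Placeholder per-descriptor model (assumption): a single linear layer.
    return x @ params


@partial(jax.jit, static_argnums=3)
def segment_predictions(params, flat_desc, segment_ids, num_segments):
    per_desc = predict(params, flat_desc)  # (j,)
    # Sum per-descriptor outputs into one prediction per structure.
    # num_segments must be static so jit knows the output shape.
    return ops.segment_sum(per_desc, segment_ids, num_segments=num_segments)


n_i = [3, 2, 5]  # jagged: descriptors per structure
segment_ids = jnp.repeat(jnp.arange(len(n_i)), jnp.array(n_i))  # (j,)
key = jax.random.PRNGKey(0)
flat_desc = jax.random.normal(key, (sum(n_i), 4))  # (j, k)
params = jax.random.normal(key, (4,))
print(segment_predictions(params, flat_desc, segment_ids, len(n_i)))  # shape (i,)
```

Marking `num_segments` as static triggers a recompile per distinct value, which is cheap if the number of structures per batch stays fixed.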
Replies: 1 comment 4 replies
I'm finding it difficult to guess the correct shapes/sizes of input arrays to reproduce the results/errors you're seeing. Could you edit your question to add a minimal reproducible example?