Description
We had a subtle bug in the Numba implementation of Scan, found in #811.
Scan is supposed to zero out the unwritten entries of its output buffers when n_steps doesn't cover them all. This adds yet another nugget of complexity to Scan, as if it weren't already a mess.
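To make the zero-out rule concrete, here is a minimal NumPy sketch. It assumes the buffer layout used by the Scan code quoted below: the initial state(s) come first (`-min_tap` entries, analogous to `self.mintaps[idx]`), followed by one entry per step. `zero_unwritten` is a hypothetical helper for illustration, not a PyTensor function.

```python
import numpy as np


def zero_unwritten(buf: np.ndarray, steps_done: int, min_tap: int) -> np.ndarray:
    """Hypothetical helper: zero every entry past the last one the loop wrote.

    Assumes -min_tap initial-state entries followed by one entry per step,
    so only the first (steps_done - min_tap) entries are meaningful.
    """
    out = buf.copy()
    out[steps_done - min_tap :] = 0
    return out


# Sit_sot-style buffer (single tap -1): 1 initial state + room for 5 steps.
# Entries 4 and 5 stand in for stale data from steps that never ran.
n_steps = 5
buf = np.arange(1.0, n_steps + 2)
result = zero_unwritten(buf, steps_done=3, min_tap=-1)
print(result)  # [1. 2. 3. 4. 0. 0.]
```

With `min_tap=-1`, the cutoff is `3 - (-1) = 4`, mirroring `output_storage[idx][0][i - self.mintaps[idx] :] = 0` in the snippet.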
Things to investigate:
- Do we need to zero out at all? Could we instead slice away the unwritten entries, like while Scans do?
- If not, could this be handled outside the Scan, by the code that creates such a Scan?
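The two alternatives above can be sketched in NumPy, again assuming Scan's buffer layout (initial states first, then one entry per step). `slice_unwritten` and `pad_outside` are hypothetical names used only for illustration; they are not PyTensor APIs, and a real fix would operate on PyTensor graphs rather than NumPy arrays.

```python
import numpy as np


def slice_unwritten(buf: np.ndarray, steps_done: int, min_tap: int) -> np.ndarray:
    """Option 1 (hypothetical helper): drop unwritten entries, like a while Scan."""
    # Keep the -min_tap initial states plus one entry per completed step.
    return buf[: steps_done - min_tap]


def pad_outside(partial: np.ndarray, expected_len: int) -> np.ndarray:
    """Option 2 (hypothetical helper): the code that built the Scan zero-pads
    the shorter output back to the length it expects."""
    missing = expected_len - partial.shape[0]
    # Pad only the leading (time) axis; np.pad's default mode fills with 0.
    pad_width = [(0, missing)] + [(0, 0)] * (partial.ndim - 1)
    return np.pad(partial, pad_width)


buf = np.array([1.0, 2.0, 3.0, 4.0, 9.0, 9.0])  # last two entries never written
short = slice_unwritten(buf, steps_done=3, min_tap=-1)
padded = pad_outside(short, expected_len=buf.shape[0])
print(short)   # [1. 2. 3. 4.]
print(padded)  # [1. 2. 3. 4. 0. 0.]
```

The point of option 2 is that only callers relying on the zero-filled shape (e.g. truncated backpropagation through time) would pay for the padding, while Scan itself would behave like a while Scan.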
Lines 2179 to 2203 in bc0d670
```python
# This would normally happen only when doing truncated
# backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is
# not actually computed
elif store_steps[idx] > i - self.mintaps[idx]:
    output_storage[idx][0][i - self.mintaps[idx] :] = 0
    # This is a fix for a bug introduced by while. If you say
    # you want to loop up to a condition, you expect the output
    # to have that length ( and not the maximal length possible)
    #
    # Without this the behaviour of a scan op is not consistent
    # if optimization gets applied compared to when optimization
    # do not get applied
    if i < n_steps:
        # The reason I don't use out[idx][0][:i] is because for
        # certain outputs (those with multiple taps),
        # outs[idx][0] has more than n_steps entries, with the
        # initial state at the beginning. When indexing in it I
        # usually have to do something like
        # outs[idx][0][i+offset]. To do something similar here,
        # I would have first to compute the maximal tap for
        # every output and then do outs[0][:i+maximal_tap],
        # which implies I think more computations then this
        # little trick that I used
        output_storage[idx][0] = output_storage[idx][0][: -(n_steps - i)]
```
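The "little trick" in the last comment of the snippet, slicing with `[: -(n_steps - i)]` instead of `[: i + offset]`, can be checked with a small NumPy sketch. It assumes the buffer holds exactly `n_init + n_steps` entries (initial states first); under that assumption both slices keep the same prefix, and the negative slice needs no knowledge of `n_init`. The function names are hypothetical.

```python
import numpy as np


def trim_by_negative_slice(buf: np.ndarray, n_steps: int, steps_done: int) -> np.ndarray:
    # The trick from the snippet: drop the trailing (n_steps - steps_done)
    # entries without knowing how many initial states sit at the front.
    # Only valid for steps_done < n_steps: with steps_done == n_steps the
    # slice becomes buf[:-0] == buf[:0], an empty array, which is why the
    # original code guards it with `if i < n_steps`.
    return buf[: -(n_steps - steps_done)]


def trim_by_offset(buf: np.ndarray, steps_done: int, n_init: int) -> np.ndarray:
    # The alternative the comment avoids: explicitly account for the
    # n_init initial-state entries stored before the per-step entries.
    return buf[: steps_done + n_init]


n_steps = 5
for n_init in (1, 2, 3):  # e.g. taps [-1], [-2, -1], [-3, -1]
    buf = np.arange(n_init + n_steps)
    for steps_done in range(n_steps):
        a = trim_by_negative_slice(buf, n_steps, steps_done)
        b = trim_by_offset(buf, steps_done, n_init)
        assert np.array_equal(a, b)
```

If the buffer were longer than `n_init + n_steps`, the two slices would no longer agree, so the trick silently depends on that sizing.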