You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Fixes#859
Start reading at `NOTE: [advanced indexing (index.Tensor) batch rule]`
in the code for details. This PR rewrites the index.Tensor and index_put
batching rules.
The TL;DR is:
- advanced indexing has different behavior depending on if the "advanced
indices are adjacent":
https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
- we have to take this into account in our batching rules, because
index.Tensor and index_put handle these internally.
Test Plan
- I added new test cases for getitem and aten.ops.index_put via OpInfo
testing.
Future
- primtorch should have a sane decomposition that we can use
- We haven't fixed the index_put_ batching rule yet. TODO later...
- Upstream our test cases (see next section) into pytorch/pytorch
0 commit comments