Skip to content

Commit eb12598

Browse files
authored
Fix index.Tensor, index_put batching rules (#862)
Fixes #859 Start reading at `NOTE: [advanced indexing (index.Tensor) batch rule]` in the code for details. This PR rewrites the index.Tensor and index_put batching rules. The TL;DR is: - advanced indexing has different behavior depending on if the "advanced indices are adjacent": https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing - we have to take this into account in our batching rules, because index.Tensor and index_put handle these internally. Test Plan - I added new test cases for getitem and aten.ops.index_put via OpInfo testing. Future - primtorch should have a sane decomposition that we can use - We haven't fixed the index_put_ batching rule yet. TODO later... - Upstream our test cases (see next section) into pytorch/pytorch
1 parent 26d4cfc commit eb12598

File tree

4 files changed

+383
-71
lines changed

4 files changed

+383
-71
lines changed

0 commit comments

Comments
 (0)