Skip to content

Commit 40db2f6

Browse files
committed
adding the discontinuous mask indices case
1 parent f97d164 commit 40db2f6

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ class TestIndexConstantConverter(DispatchTestCase):
8181
[torch.tensor([True, False])],
8282
torch.randn(2, 2),
8383
),
84+
(
85+
# covers multi axis and discontinuous indices
86+
"mask_index_multi_axis",
87+
[
88+
None,
89+
torch.tensor([[True, False, False, True]]), # axis 1
90+
None,
91+
torch.tensor([True, False]), # axis 3
92+
],
93+
torch.randn(2, 4, 4, 2),
94+
),
8495
]
8596
)
8697
def test_index_constant(self, _, index, input):

0 commit comments

Comments
 (0)