Skip to content

Commit 5d42a39

Browse files
committed
reset HEAD back, wip
1 parent 43bd889 commit 5d42a39

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

backends/qualcomm/quantizer/annotators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,10 @@ def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
432432
def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
433433
annotate_single_in_single_out(node, quantization_config)
434434

435+
@register_annotator([torch.ops.aten.index_select.default])
436+
def annotate_index_select(node: Node, quantization_config: QuantizationConfig) -> None:
437+
import pdb; pdb.set_trace()
438+
annotate_single_in_single_out(node, quantization_config)
435439

436440
@register_annotator([torch.ops.aten.floor.default])
437441
def annotate_floor(node: Node, quantization_config: QuantizationConfig) -> None:

backends/qualcomm/tests/models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,23 @@ def __init__(self):
813813
def forward(self, x):
814814
return torch.special.expm1(x)
815815

816+
class Flip(torch.nn.Module):
817+
def __init__(self):
818+
super().__init__()
819+
self.dims = [0,2]
820+
821+
def forward(self, x):
822+
return torch.flip(x, self.dims)
816823

824+
class FlipDecomp(torch.nn.Module):
825+
def __init__(self):
826+
super().__init__()
827+
self.dims = [0,2]
828+
def forward(self, x):
829+
for dim in self.dims:
830+
idx = torch.arange(x.size(dim) - 1, -1, -1, device=x.device)
831+
x = torch.index_select(x, dim, idx)
832+
return x
817833
class Floor(torch.nn.Module):
818834
def __init__(self):
819835
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,18 @@ def test_qnn_backend_expm1(self):
20302030
module = self.get_qdq_module(module, sample_input)
20312031
self.lower_module_and_test_output(module, sample_input)
20322032

2033+
def test_qnn_backend_flip(self):
2034+
sample_input = (torch.randn(3, 4, 5,6),)
2035+
# golden_module = Flip()
2036+
decomp_module = FlipDecomp()
2037+
decomp_module = self.get_qdq_module(decomp_module, sample_input)
2038+
self.lower_module_and_test_output(decomp_module, sample_input)
2039+
# golden_out = golden_module(sample_input)
2040+
# decomp_out = decomp_module(sample_input)
2041+
# torch.testing.assert_close(golden_out, decomp_out)
2042+
2043+
2044+
20332045
def test_qnn_backend_floor(self):
20342046
sample_input = (torch.randn(3, 4),)
20352047
module = Floor() # noqa: F405

0 commit comments

Comments
 (0)