Skip to content

Commit a327b73

Browse files
committed
Add unit test to validate the size of the Spill-Fill buffer.
1 parent 7010a11 commit a327b73

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,18 @@ def forward(self, input_pos, k_val):
596596
return k_out
597597

598598

class LargeTensorLinear(torch.nn.Module):
    """Two-layer MLP with a deliberately wide (4096) hidden dimension.

    Test model for the QNN backend: the wide hidden layer produces large
    intermediate activations so the HTP spill-fill buffer is actually
    exercised (see test_qnn_backend_spill_fill_buffer_size).

    Input/output shape: (..., 512) -> (..., 512).
    """

    def __init__(self):
        super().__init__()
        # Large hidden width inflates the intermediate tensor size.
        hidden_dim = 4096
        self.linear1 = torch.nn.Linear(512, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, 512)

    def forward(self, x):
        # Adding the same projection to itself creates one more large
        # intermediate tensor without adding parameters.
        x1 = self.linear1(x) + self.linear1(x)
        return self.linear2(x1)
599611
class LayerNorm(torch.nn.Module):
600612
def __init__(self):
601613
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,24 @@ def test_qnn_backend_skip_node_op(self):
15811581
skip_node_op_set={"aten.add.Tensor"},
15821582
)
15831583

def test_qnn_backend_spill_fill_buffer_size(self):
    """Validate that a large fp16 graph yields a non-zero spill-fill buffer size.

    Compiles LargeTensorLinear for HTP with multi-context enabled (spill-fill
    sharing only applies across multiple contexts), then checks that
    update_spill_fill_size reports a non-zero maximum buffer size.
    """
    module = LargeTensorLinear()  # noqa: F405
    sample_input = (torch.randn(1, 256, 512),)
    edge_prog = capture_program(module, sample_input)

    backend_options = generate_htp_compiler_spec(
        use_fp16=True,
        # Spill-fill buffer sharing is a multi-context feature.
        use_multi_contexts=True,
    )
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=self.chipset_table[TestQNN.model],
        backend_options=backend_options,
    )
    partitioner = QnnPartitioner(compiler_specs)
    edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner)
    max_sf_size = update_spill_fill_size(edge_prog.exported_program)
    # A zero size would mean the spill-fill computation found nothing to share.
    self.assertNotEqual(0, max_sf_size)
15841602
def test_qnn_backend_multi_contexts(self):
15851603
module = SimpleModel() # noqa: F405
15861604
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
@@ -2007,6 +2025,25 @@ def calibrator(gm):
20072025
).to_executorch()
20082026
self.verify_output(module, sample_input, exec_prog)
20092027

def test_qnn_backend_spill_fill_buffer_size(self):
    """Validate that a large quantized graph yields a non-zero spill-fill buffer size.

    Quantized counterpart of the fp16 test: the module is first converted to
    a QDQ module, compiled for HTP with use_fp16=False and multi-context
    enabled, then update_spill_fill_size must report a non-zero maximum.
    """
    module = LargeTensorLinear()  # noqa: F405
    sample_input = (torch.randn(1, 256, 512),)
    # Quantize (quantize-dequantize form) before capture.
    module = self.get_qdq_module(module, sample_input)
    edge_prog = capture_program(module, sample_input)

    backend_options = generate_htp_compiler_spec(
        use_fp16=False,
        # Spill-fill buffer sharing is a multi-context feature.
        use_multi_contexts=True,
    )
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=self.chipset_table[TestQNN.model],
        backend_options=backend_options,
    )
    partitioner = QnnPartitioner(compiler_specs)
    edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner)
    max_sf_size = update_spill_fill_size(edge_prog.exported_program)
    # A zero size would mean the spill-fill computation found nothing to share.
    self.assertNotEqual(0, max_sf_size)
20102047
def test_qnn_backend_graph_level_mixed_precision(self):
20112048
module = SimpleModel() # noqa: F405
20122049
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))

backends/qualcomm/utils/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,15 +268,17 @@ def set_spec(module, options):
268268
options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size
269269
set_spec(module, options)
270270

271+
max_sf_size, modules_map = 0, {}
271272
if isinstance(exported_program, list):
272-
max_sf_size, modules_map = 0, {}
273273
for prog in exported_program:
274274
max_sf_buf_size, module_map = get_program_info(prog)
275275
max_sf_size = max(max_sf_size, max_sf_buf_size)
276276
modules_map.update(module_map)
277-
update_program(max_sf_size, modules_map)
278277
else:
279-
update_program(*get_program_info(exported_program))
278+
max_sf_size, module_map = get_program_info(exported_program)
279+
update_program(max_sf_size, module_map)
280+
281+
return max_sf_size
280282

281283

282284
def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:

0 commit comments

Comments
 (0)