@@ -1581,6 +1581,24 @@ def test_qnn_backend_skip_node_op(self):
15811581 skip_node_op_set = {"aten.add.Tensor" },
15821582 )
15831583
1584+ def test_qnn_backend_spill_fill_buffer_size (self ):
1585+ module = LargeTensorLinear () # noqa: F405
1586+ sample_input = (torch .randn (1 , 256 , 512 ),)
1587+ edge_prog = capture_program (module , sample_input )
1588+
1589+ backend_options = generate_htp_compiler_spec (
1590+ use_fp16 = True ,
1591+ use_multi_contexts = True ,
1592+ )
1593+ compiler_specs = generate_qnn_executorch_compiler_spec (
1594+ soc_model = self .chipset_table [TestQNN .model ],
1595+ backend_options = backend_options ,
1596+ )
1597+ partitioner = QnnPartitioner (compiler_specs )
1598+ edge_prog .exported_program = to_backend (edge_prog .exported_program , partitioner )
1599+ max_sf_size = update_spill_fill_size (edge_prog .exported_program )
1600+ self .assertNotEqual (0 , max_sf_size )
1601+
15841602 def test_qnn_backend_multi_contexts (self ):
15851603 module = SimpleModel () # noqa: F405
15861604 sample_input = (torch .ones (1 , 32 , 28 , 28 ), torch .ones (1 , 32 , 28 , 28 ))
@@ -2007,6 +2025,25 @@ def calibrator(gm):
20072025 ).to_executorch ()
20082026 self .verify_output (module , sample_input , exec_prog )
20092027
2028+ def test_qnn_backend_spill_fill_buffer_size (self ):
2029+ module = LargeTensorLinear () # noqa: F405
2030+ sample_input = (torch .randn (1 , 256 , 512 ),)
2031+ module = self .get_qdq_module (module , sample_input )
2032+ edge_prog = capture_program (module , sample_input )
2033+
2034+ backend_options = generate_htp_compiler_spec (
2035+ use_fp16 = False ,
2036+ use_multi_contexts = True ,
2037+ )
2038+ compiler_specs = generate_qnn_executorch_compiler_spec (
2039+ soc_model = self .chipset_table [TestQNN .model ],
2040+ backend_options = backend_options ,
2041+ )
2042+ partitioner = QnnPartitioner (compiler_specs )
2043+ edge_prog .exported_program = to_backend (edge_prog .exported_program , partitioner )
2044+ max_sf_size = update_spill_fill_size (edge_prog .exported_program )
2045+ self .assertNotEqual (0 , max_sf_size )
2046+
20102047 def test_qnn_backend_graph_level_mixed_precision (self ):
20112048 module = SimpleModel () # noqa: F405
20122049 sample_input = (torch .ones (1 , 32 , 28 , 28 ), torch .ones (1 , 32 , 28 , 28 ))
0 commit comments