 require temporary memory allocation.
 """
 
-import torch
-import tempfile
 import os
+import tempfile
 from pathlib import Path
-from torch.export import export
-from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower, EdgeCompileConfig
+
+import torch
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.runtime import Verification, Runtime, Program, Method
+from executorch.exir import (
+    EdgeCompileConfig,
+    EdgeProgramManager,
+    to_edge_transform_and_lower,
+)
+from executorch.runtime import Method, Program, Runtime, Verification
+from torch.export import export
 
 
 class TopKModel(torch.nn.Module):
     """Model that uses torch.topk operation which requires temp memory allocation."""
-
+
     def __init__(self, k=3) -> None:
         super().__init__()
         self.k = k
@@ -32,7 +37,7 @@ def forward(self, x) -> torch.Tensor:
 
 class TopKModelWithOut(torch.nn.Module):
     """Model that uses torch.topk with out parameter which also requires temp memory."""
-
+
     def __init__(self, k=3) -> None:
         super().__init__()
         self.k = k
@@ -47,42 +52,44 @@ def forward(self, x) -> torch.Tensor:
 def test_topk_without_out_parameter():
     """Test torch.topk without out parameter."""
     print("Testing torch.topk without out parameter...")
-
+
     model = TopKModel(k=5)
     example_input = (torch.randn(3, 100),)
-
+
     # Export and compile the model
     with torch.no_grad():
         aten_dialect = export(model, example_input)
-
+
     backend_dialect = to_edge_transform_and_lower(
         aten_dialect,
         compile_config=EdgeCompileConfig(_check_ir_validity=False),
         partitioner=[XnnpackPartitioner()],
     )
-
+
     executorch_dialect = backend_dialect.to_executorch()
-
+
     # Save to temporary file
-    with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
+    with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
         temp_path = f.name
-
+
     try:
         executorch_dialect.save(temp_path)
-
+
         # Load and execute with ExecuTorch runtime
         et_runtime = Runtime.get()
         program = et_runtime.load_program(
             Path(temp_path),
             verification=Verification.Minimal,
         )
-
+
         forward = program.load_method("forward")
         outputs = forward.execute(example_input)
-
-        print(f"✓ Successfully executed topk model: {example_input[0].shape} -> {outputs[0].shape}")
+
+        print(
+            f"✓ Successfully executed topk model: {example_input[0].shape} -> {outputs[0].shape}"
+        )
         return True
-
+
     finally:
         # Clean up temporary file
         if os.path.exists(temp_path):
@@ -92,42 +99,44 @@ def test_topk_without_out_parameter():
 def test_topk_with_out_parameter():
     """Test torch.topk with out parameter (original failing case)."""
     print("Testing torch.topk with out parameter...")
-
+
     model = TopKModelWithOut(k=3)
     example_input = (torch.randn(2, 256),)
-
+
     # Export and compile the model
     with torch.no_grad():
         aten_dialect = export(model, example_input)
-
+
     backend_dialect = to_edge_transform_and_lower(
         aten_dialect,
         compile_config=EdgeCompileConfig(_check_ir_validity=False),
         partitioner=[XnnpackPartitioner()],
     )
-
+
     executorch_dialect = backend_dialect.to_executorch()
-
+
     # Save to temporary file
-    with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
+    with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
         temp_path = f.name
-
+
     try:
         executorch_dialect.save(temp_path)
-
+
         # Load and execute with ExecuTorch runtime
         et_runtime = Runtime.get()
         program = et_runtime.load_program(
             Path(temp_path),
             verification=Verification.Minimal,
         )
-
+
         forward = program.load_method("forward")
         outputs = forward.execute(example_input)
-
-        print(f"✓ Successfully executed topk model with out parameter: {example_input[0].shape} -> {outputs[0].shape}")
+
+        print(
+            f"✓ Successfully executed topk model with out parameter: {example_input[0].shape} -> {outputs[0].shape}"
+        )
         return True
-
+
     finally:
         # Clean up temporary file
         if os.path.exists(temp_path):
@@ -137,42 +146,44 @@ def test_topk_with_out_parameter():
 def test_larger_topk_operation():
     """Test larger topk operation that would require more temporary memory."""
     print("Testing larger topk operation...")
-
+
     model = TopKModel(k=50)
     example_input = (torch.randn(5, 1000),)
-
+
     # Export and compile the model
     with torch.no_grad():
         aten_dialect = export(model, example_input)
-
+
     backend_dialect = to_edge_transform_and_lower(
         aten_dialect,
         compile_config=EdgeCompileConfig(_check_ir_validity=False),
         partitioner=[XnnpackPartitioner()],
     )
-
+
     executorch_dialect = backend_dialect.to_executorch()
-
+
     # Save to temporary file
-    with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
+    with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
         temp_path = f.name
-
+
     try:
         executorch_dialect.save(temp_path)
-
+
         # Load and execute with ExecuTorch runtime
         et_runtime = Runtime.get()
         program = et_runtime.load_program(
             Path(temp_path),
             verification=Verification.Minimal,
         )
-
+
         forward = program.load_method("forward")
         outputs = forward.execute(example_input)
-
-        print(f"✓ Successfully executed large topk model: {example_input[0].shape} -> {outputs[0].shape}")
+
+        print(
+            f"✓ Successfully executed large topk model: {example_input[0].shape} -> {outputs[0].shape}"
+        )
         return True
-
+
     finally:
         # Clean up temporary file
         if os.path.exists(temp_path):
@@ -183,16 +194,16 @@ def main():
     """Run all tests to verify the temp memory allocation fix."""
     print("Testing temp memory allocation fix for torch.topk operations")
     print("=" * 60)
-
+
     tests = [
         test_topk_without_out_parameter,
         test_topk_with_out_parameter,
         test_larger_topk_operation,
     ]
-
+
     passed = 0
     failed = 0
-
+
     for test in tests:
         try:
             if test():
@@ -202,12 +213,14 @@ def main():
         except Exception as e:
             print(f"✗ Test {test.__name__} failed with exception: {e}")
             failed += 1
-
+
     print("\n" + "=" * 60)
     print(f"Test Results: {passed} passed, {failed} failed")
-
+
     if failed == 0:
-        print("✓ All tests passed! The temp memory allocation fix is working correctly.")
+        print(
+            "✓ All tests passed! The temp memory allocation fix is working correctly."
+        )
         return True
     else:
         print("✗ Some tests failed. The fix may not be working correctly.")
@@ -216,4 +229,4 @@ def main():
 
 if __name__ == "__main__":
     success = main()
-    exit(0 if success else 1)
+    exit(0 if success else 1)