1+ #!/usr/bin/env python3
2+ """
3+ Test to verify the fix for temp memory allocation issue in torch.topk operations.
4+
5+ This test specifically checks that the MallocMemoryAllocator fix in pybindings.cpp
6+ resolves the "Memory allocation failed" error when executing operations that
7+ require temporary memory allocation.
8+ """
9+
10+ import torch
11+ import tempfile
12+ import os
13+ from pathlib import Path
14+ from torch .export import export
15+ from executorch .exir import EdgeProgramManager , to_edge_transform_and_lower , EdgeCompileConfig
16+ from executorch .backends .xnnpack .partition .xnnpack_partitioner import XnnpackPartitioner
17+ from executorch .runtime import Verification , Runtime , Program , Method
18+
19+
20+ class TopKModel (torch .nn .Module ):
21+ """Model that uses torch.topk operation which requires temp memory allocation."""
22+
23+ def __init__ (self , k = 3 ) -> None :
24+ super ().__init__ ()
25+ self .k = k
26+
27+ def forward (self , x ) -> torch .Tensor :
28+ # This operation requires temporary memory allocation
29+ top_values , top_indices = torch .topk (x , self .k )
30+ return top_values , top_indices
31+
32+
33+ class TopKModelWithOut (torch .nn .Module ):
34+ """Model that uses torch.topk with out parameter which also requires temp memory."""
35+
36+ def __init__ (self , k = 3 ) -> None :
37+ super ().__init__ ()
38+ self .k = k
39+
40+ def forward (self , x ) -> torch .Tensor :
41+ top_values = torch .ones (x .shape [0 ], self .k , dtype = torch .float32 )
42+ top_indices = torch .ones (x .shape [0 ], self .k , dtype = torch .long )
43+ torch .topk (x .contiguous (), self .k , out = (top_values , top_indices ))
44+ return top_values , top_indices
45+
46+
47+ def test_topk_without_out_parameter ():
48+ """Test torch.topk without out parameter."""
49+ print ("Testing torch.topk without out parameter..." )
50+
51+ model = TopKModel (k = 5 )
52+ example_input = (torch .randn (3 , 100 ),)
53+
54+ # Export and compile the model
55+ with torch .no_grad ():
56+ aten_dialect = export (model , example_input )
57+
58+ backend_dialect = to_edge_transform_and_lower (
59+ aten_dialect ,
60+ compile_config = EdgeCompileConfig (_check_ir_validity = False ),
61+ partitioner = [XnnpackPartitioner ()],
62+ )
63+
64+ executorch_dialect = backend_dialect .to_executorch ()
65+
66+ # Save to temporary file
67+ with tempfile .NamedTemporaryFile (suffix = '.pte' , delete = False ) as f :
68+ temp_path = f .name
69+
70+ try :
71+ executorch_dialect .save (temp_path )
72+
73+ # Load and execute with ExecuTorch runtime
74+ et_runtime = Runtime .get ()
75+ program = et_runtime .load_program (
76+ Path (temp_path ),
77+ verification = Verification .Minimal ,
78+ )
79+
80+ forward = program .load_method ("forward" )
81+ outputs = forward .execute (example_input )
82+
83+ print (f"✓ Successfully executed topk model: { example_input [0 ].shape } -> { outputs [0 ].shape } " )
84+ return True
85+
86+ finally :
87+ # Clean up temporary file
88+ if os .path .exists (temp_path ):
89+ os .unlink (temp_path )
90+
91+
92+ def test_topk_with_out_parameter ():
93+ """Test torch.topk with out parameter (original failing case)."""
94+ print ("Testing torch.topk with out parameter..." )
95+
96+ model = TopKModelWithOut (k = 3 )
97+ example_input = (torch .randn (2 , 256 ),)
98+
99+ # Export and compile the model
100+ with torch .no_grad ():
101+ aten_dialect = export (model , example_input )
102+
103+ backend_dialect = to_edge_transform_and_lower (
104+ aten_dialect ,
105+ compile_config = EdgeCompileConfig (_check_ir_validity = False ),
106+ partitioner = [XnnpackPartitioner ()],
107+ )
108+
109+ executorch_dialect = backend_dialect .to_executorch ()
110+
111+ # Save to temporary file
112+ with tempfile .NamedTemporaryFile (suffix = '.pte' , delete = False ) as f :
113+ temp_path = f .name
114+
115+ try :
116+ executorch_dialect .save (temp_path )
117+
118+ # Load and execute with ExecuTorch runtime
119+ et_runtime = Runtime .get ()
120+ program = et_runtime .load_program (
121+ Path (temp_path ),
122+ verification = Verification .Minimal ,
123+ )
124+
125+ forward = program .load_method ("forward" )
126+ outputs = forward .execute (example_input )
127+
128+ print (f"✓ Successfully executed topk model with out parameter: { example_input [0 ].shape } -> { outputs [0 ].shape } " )
129+ return True
130+
131+ finally :
132+ # Clean up temporary file
133+ if os .path .exists (temp_path ):
134+ os .unlink (temp_path )
135+
136+
137+ def test_larger_topk_operation ():
138+ """Test larger topk operation that would require more temporary memory."""
139+ print ("Testing larger topk operation..." )
140+
141+ model = TopKModel (k = 50 )
142+ example_input = (torch .randn (5 , 1000 ),)
143+
144+ # Export and compile the model
145+ with torch .no_grad ():
146+ aten_dialect = export (model , example_input )
147+
148+ backend_dialect = to_edge_transform_and_lower (
149+ aten_dialect ,
150+ compile_config = EdgeCompileConfig (_check_ir_validity = False ),
151+ partitioner = [XnnpackPartitioner ()],
152+ )
153+
154+ executorch_dialect = backend_dialect .to_executorch ()
155+
156+ # Save to temporary file
157+ with tempfile .NamedTemporaryFile (suffix = '.pte' , delete = False ) as f :
158+ temp_path = f .name
159+
160+ try :
161+ executorch_dialect .save (temp_path )
162+
163+ # Load and execute with ExecuTorch runtime
164+ et_runtime = Runtime .get ()
165+ program = et_runtime .load_program (
166+ Path (temp_path ),
167+ verification = Verification .Minimal ,
168+ )
169+
170+ forward = program .load_method ("forward" )
171+ outputs = forward .execute (example_input )
172+
173+ print (f"✓ Successfully executed large topk model: { example_input [0 ].shape } -> { outputs [0 ].shape } " )
174+ return True
175+
176+ finally :
177+ # Clean up temporary file
178+ if os .path .exists (temp_path ):
179+ os .unlink (temp_path )
180+
181+
182+ def main ():
183+ """Run all tests to verify the temp memory allocation fix."""
184+ print ("Testing temp memory allocation fix for torch.topk operations" )
185+ print ("=" * 60 )
186+
187+ tests = [
188+ test_topk_without_out_parameter ,
189+ test_topk_with_out_parameter ,
190+ test_larger_topk_operation ,
191+ ]
192+
193+ passed = 0
194+ failed = 0
195+
196+ for test in tests :
197+ try :
198+ if test ():
199+ passed += 1
200+ else :
201+ failed += 1
202+ except Exception as e :
203+ print (f"✗ Test { test .__name__ } failed with exception: { e } " )
204+ failed += 1
205+
206+ print ("\n " + "=" * 60 )
207+ print (f"Test Results: { passed } passed, { failed } failed" )
208+
209+ if failed == 0 :
210+ print ("✓ All tests passed! The temp memory allocation fix is working correctly." )
211+ return True
212+ else :
213+ print ("✗ Some tests failed. The fix may not be working correctly." )
214+ return False
215+
216+
217+ if __name__ == "__main__" :
218+ success = main ()
219+ exit (0 if success else 1 )
0 commit comments