Skip to content

Commit 1627bcb

Browse files
committed
fix
1 parent 78063df commit 1627bcb

File tree

2 files changed

+221
-2
lines changed

2 files changed

+221
-2
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ class Module final {
358358

359359
MallocMemoryAllocator runtime_allocator_;
360360

361-
MemoryAllocator temp_allocator_{MemoryAllocator(0, nullptr)};
361+
MallocMemoryAllocator temp_allocator_{};
362362

363363
std::vector<std::vector<uint8_t>> non_const_buffers_;
364364

@@ -1061,7 +1061,7 @@ class ProgramMemory {
10611061

10621062
MallocMemoryAllocator runtime_allocator_;
10631063

1064-
MemoryAllocator temp_allocator_{MemoryAllocator(0, nullptr)};
1064+
MallocMemoryAllocator temp_allocator_{};
10651065

10661066
std::vector<std::vector<uint8_t>> non_const_buffers_;
10671067

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Test to verify the fix for temp memory allocation issue in torch.topk operations.
4+
5+
This test specifically checks that the MallocMemoryAllocator fix in pybindings.cpp
6+
resolves the "Memory allocation failed" error when executing operations that
7+
require temporary memory allocation.
8+
"""
9+
10+
import torch
11+
import tempfile
12+
import os
13+
from pathlib import Path
14+
from torch.export import export
15+
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower, EdgeCompileConfig
16+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
17+
from executorch.runtime import Verification, Runtime, Program, Method
18+
19+
20+
class TopKModel(torch.nn.Module):
    """Module whose forward pass is a single ``torch.topk`` call.

    Per this file's module docstring, the operation is used because running it
    through the runtime needs temporary memory allocation.

    Args:
        k: Number of largest elements to select along the last dimension.
    """

    def __init__(self, k=3) -> None:
        super().__init__()
        self.k = k

    def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``k`` largest values of ``x`` and their indices.

        Args:
            x: Input tensor; ``torch.topk`` selects along its last dimension.

        Returns:
            A ``(values, indices)`` pair as produced by ``torch.topk``.
        """
        top_values, top_indices = torch.topk(x, self.k)
        return top_values, top_indices
31+
32+
33+
class TopKModelWithOut(torch.nn.Module):
    """Module that calls ``torch.topk`` with explicit ``out=`` tensors.

    Per this file's module docstring, this variant also needs temporary memory
    allocation when executed by the runtime.

    Args:
        k: Number of largest elements to select along the last dimension.
    """

    def __init__(self, k=3) -> None:
        super().__init__()
        self.k = k

    def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]:
        """Write the top-``k`` values and indices of ``x`` into fresh buffers.

        Args:
            x: Input tensor. The output buffers are sized ``(x.shape[0], k)``
                with float32 values, so a 2-D float32 input is assumed.

        Returns:
            A ``(values, indices)`` pair holding the ``torch.topk`` results.
        """
        # Pre-allocate the destination tensors that topk fills via ``out=``.
        top_values = torch.ones(x.shape[0], self.k, dtype=torch.float32)
        top_indices = torch.ones(x.shape[0], self.k, dtype=torch.long)
        torch.topk(x.contiguous(), self.k, out=(top_values, top_indices))
        return top_values, top_indices
45+
46+
47+
def test_topk_without_out_parameter():
    """Run ``TopKModel`` through the ExecuTorch runtime and verify its outputs.

    The model is exported, lowered with the XNNPACK partitioner, saved as a
    ``.pte`` file, loaded through ``Runtime`` and executed. The runtime
    outputs are then compared against eager PyTorch so that a run which
    "succeeds" with wrong results is still reported as a failure.

    Returns:
        True when execution succeeds and the outputs match eager mode. The
        truthy return is kept because ``main`` uses it as the pass flag.

    Raises:
        AssertionError: If the runtime outputs disagree with eager PyTorch.
    """
    print("Testing torch.topk without out parameter...")

    model = TopKModel(k=5)
    example_input = (torch.randn(3, 100),)

    # Export and compile the model; also capture the eager reference result.
    with torch.no_grad():
        expected_values, _ = model(*example_input)
        aten_dialect = export(model, example_input)

        backend_dialect = to_edge_transform_and_lower(
            aten_dialect,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )

        executorch_dialect = backend_dialect.to_executorch()

    # A temporary directory removes the program file even if a step raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir) / "topk.pte"
        executorch_dialect.save(str(temp_path))

        # Load and execute with ExecuTorch runtime
        et_runtime = Runtime.get()
        program = et_runtime.load_program(
            temp_path,
            verification=Verification.Minimal,
        )

        forward = program.load_method("forward")
        outputs = forward.execute(example_input)

        values, indices = outputs[0], outputs[1]
        torch.testing.assert_close(values, expected_values)
        # Check indices through a gather so tie ordering cannot cause a
        # spurious mismatch against the eager indices.
        torch.testing.assert_close(
            torch.gather(example_input[0], -1, indices), values
        )

        print(f"✓ Successfully executed topk model: {example_input[0].shape} -> {outputs[0].shape}")
        return True
90+
91+
92+
def test_topk_with_out_parameter():
    """Run ``TopKModelWithOut`` through the ExecuTorch runtime and verify it.

    This is the case described in the original docstring as the one that
    failed before the allocator fix. The model is exported, lowered with the
    XNNPACK partitioner, saved as a ``.pte`` file, loaded through ``Runtime``
    and executed; the outputs are then compared against eager PyTorch.

    Returns:
        True when execution succeeds and the outputs match eager mode. The
        truthy return is kept because ``main`` uses it as the pass flag.

    Raises:
        AssertionError: If the runtime outputs disagree with eager PyTorch.
    """
    print("Testing torch.topk with out parameter...")

    model = TopKModelWithOut(k=3)
    example_input = (torch.randn(2, 256),)

    # Export and compile the model; also capture the eager reference result.
    with torch.no_grad():
        expected_values, _ = model(*example_input)
        aten_dialect = export(model, example_input)

        backend_dialect = to_edge_transform_and_lower(
            aten_dialect,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )

        executorch_dialect = backend_dialect.to_executorch()

    # A temporary directory removes the program file even if a step raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir) / "topk_out.pte"
        executorch_dialect.save(str(temp_path))

        # Load and execute with ExecuTorch runtime
        et_runtime = Runtime.get()
        program = et_runtime.load_program(
            temp_path,
            verification=Verification.Minimal,
        )

        forward = program.load_method("forward")
        outputs = forward.execute(example_input)

        values, indices = outputs[0], outputs[1]
        torch.testing.assert_close(values, expected_values)
        # Check indices through a gather so tie ordering cannot cause a
        # spurious mismatch against the eager indices.
        torch.testing.assert_close(
            torch.gather(example_input[0], -1, indices), values
        )

        print(f"✓ Successfully executed topk model with out parameter: {example_input[0].shape} -> {outputs[0].shape}")
        return True
135+
136+
137+
def test_larger_topk_operation():
    """Run a larger ``TopKModel`` (k=50 on a 5x1000 input) and verify it.

    The original docstring describes this as a case that would require more
    temporary memory. The model is exported, lowered with the XNNPACK
    partitioner, saved as a ``.pte`` file, loaded through ``Runtime`` and
    executed; the outputs are then compared against eager PyTorch.

    Returns:
        True when execution succeeds and the outputs match eager mode. The
        truthy return is kept because ``main`` uses it as the pass flag.

    Raises:
        AssertionError: If the runtime outputs disagree with eager PyTorch.
    """
    print("Testing larger topk operation...")

    model = TopKModel(k=50)
    example_input = (torch.randn(5, 1000),)

    # Export and compile the model; also capture the eager reference result.
    with torch.no_grad():
        expected_values, _ = model(*example_input)
        aten_dialect = export(model, example_input)

        backend_dialect = to_edge_transform_and_lower(
            aten_dialect,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
            partitioner=[XnnpackPartitioner()],
        )

        executorch_dialect = backend_dialect.to_executorch()

    # A temporary directory removes the program file even if a step raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir) / "topk_large.pte"
        executorch_dialect.save(str(temp_path))

        # Load and execute with ExecuTorch runtime
        et_runtime = Runtime.get()
        program = et_runtime.load_program(
            temp_path,
            verification=Verification.Minimal,
        )

        forward = program.load_method("forward")
        outputs = forward.execute(example_input)

        values, indices = outputs[0], outputs[1]
        torch.testing.assert_close(values, expected_values)
        # Check indices through a gather so tie ordering cannot cause a
        # spurious mismatch against the eager indices.
        torch.testing.assert_close(
            torch.gather(example_input[0], -1, indices), values
        )

        print(f"✓ Successfully executed large topk model: {example_input[0].shape} -> {outputs[0].shape}")
        return True
180+
181+
182+
def main():
    """Run every topk test in order and print a pass/fail summary.

    A test counts as passed when it returns a truthy value. A falsy return or
    any raised exception counts as a failure; exceptions are reported and do
    not stop the remaining tests.

    Returns:
        True if no test failed, otherwise False.
    """
    print("Testing temp memory allocation fix for torch.topk operations")
    print("=" * 60)

    tests = [
        test_topk_without_out_parameter,
        test_topk_with_out_parameter,
        test_larger_topk_operation,
    ]

    passed = 0
    failed = 0

    for test in tests:
        try:
            succeeded = bool(test())
        except Exception as e:
            # Top-level boundary: report the failure and keep going.
            print(f"✗ Test {test.__name__} failed with exception: {e}")
            succeeded = False

        if succeeded:
            passed += 1
        else:
            failed += 1

    print("\n" + "=" * 60)
    print(f"Test Results: {passed} passed, {failed} failed")

    if failed:
        print("✗ Some tests failed. The fix may not be working correctly.")
        return False

    print("✓ All tests passed! The temp memory allocation fix is working correctly.")
    return True
215+
216+
217+
if __name__ == "__main__":
    # Exit with status 0 when every test passed and 1 otherwise. SystemExit is
    # raised directly because the bare ``exit`` builtin is injected by the
    # ``site`` module for interactive use and is not guaranteed to exist.
    raise SystemExit(0 if main() else 1)

0 commit comments

Comments
 (0)