Skip to content

Commit 6d56713

Browse files
authored
Fix temp memory allocation issue in torch.topk operations (#12810)
Summary: Fixes issue #8700. Changed `temp_allocator_` from a `MemoryAllocator` with a null buffer to a `MallocMemoryAllocator` that can dynamically allocate memory as needed. Test plan: To verify the fix, a new end-to-end test suite has been added (`test/end2end/test_temp_allocator_fix.py`). This suite includes tests for `torch.topk` both with and without the `out` parameter, as well as a test with a larger input to ensure the allocator can handle more significant memory requirements. These tests now pass with the implemented fix.
1 parent af420f5 commit 6d56713

File tree

2 files changed

+230
-2
lines changed

2 files changed

+230
-2
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ class Module final {
358358

359359
MallocMemoryAllocator runtime_allocator_;
360360

361-
MemoryAllocator temp_allocator_{MemoryAllocator(0, nullptr)};
361+
MallocMemoryAllocator temp_allocator_{};
362362

363363
std::vector<std::vector<uint8_t>> non_const_buffers_;
364364

@@ -1061,7 +1061,7 @@ class ProgramMemory {
10611061

10621062
MallocMemoryAllocator runtime_allocator_;
10631063

1064-
MemoryAllocator temp_allocator_{MemoryAllocator(0, nullptr)};
1064+
MallocMemoryAllocator temp_allocator_{};
10651065

10661066
std::vector<std::vector<uint8_t>> non_const_buffers_;
10671067

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Test to verify the fix for temp memory allocation issue in torch.topk operations.
4+
5+
This test specifically checks that the MallocMemoryAllocator fix in pybindings.cpp
6+
resolves the "Memory allocation failed" error when executing operations that
7+
require temporary memory allocation.
8+
"""
9+
10+
import os
11+
import tempfile
12+
from pathlib import Path
13+
14+
import torch
15+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
16+
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
17+
from executorch.runtime import Runtime, Verification
18+
from torch.export import export
19+
20+
21+
class TopKModel(torch.nn.Module):
    """Wraps torch.topk, an op whose kernel needs temporary memory at runtime."""

    def __init__(self, k: int = 3) -> None:
        super().__init__()
        self.k = k  # number of largest elements to select along the last dim

    def forward(self, x) -> "tuple[torch.Tensor, torch.Tensor]":
        # topk's kernel allocates scratch space through the temp allocator,
        # which is exactly what this regression model exercises.
        values, indices = torch.topk(x, self.k)
        return values, indices
32+
33+
34+
class TopKModelWithOut(torch.nn.Module):
    """Variant of the topk model that routes results through the `out=` kwarg."""

    def __init__(self, k: int = 3) -> None:
        super().__init__()
        self.k = k  # number of top entries to take per row

    def forward(self, x) -> "tuple[torch.Tensor, torch.Tensor]":
        batch = x.shape[0]
        # Pre-allocated destination buffers for the out= overload; the initial
        # contents are irrelevant since topk overwrites them completely.
        values = torch.ones(batch, self.k, dtype=torch.float32)
        indices = torch.ones(batch, self.k, dtype=torch.long)
        torch.topk(x.contiguous(), self.k, out=(values, indices))
        return values, indices
46+
47+
48+
def _export_and_run(model: torch.nn.Module, example_input, description: str) -> bool:
    """Export *model* to ExecuTorch, execute it, and report success.

    Shared driver for the topk regression tests below: exports the model with
    ``torch.export``, lowers it through the XNNPACK partitioner, saves the
    resulting ``.pte`` to a temporary file, then loads and executes it with
    the ExecuTorch runtime.

    Args:
        model: The eager-mode module to export and run.
        example_input: Tuple of example tensors used for both export and
            execution.
        description: Human-readable label used in the success message.

    Returns:
        True when execution succeeds; any failure surfaces as an exception
        for the caller to handle.
    """
    # Export and compile the model.
    with torch.no_grad():
        aten_dialect = export(model, example_input)

    backend_dialect = to_edge_transform_and_lower(
        aten_dialect,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
        partitioner=[XnnpackPartitioner()],
    )

    executorch_dialect = backend_dialect.to_executorch()

    # delete=False so the file can be reopened by path after the handle is
    # closed (required on platforms where an open NamedTemporaryFile cannot
    # be opened a second time).
    with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
        temp_path = f.name

    try:
        executorch_dialect.save(temp_path)

        # Load and execute with the ExecuTorch runtime.
        et_runtime = Runtime.get()
        program = et_runtime.load_program(
            Path(temp_path),
            verification=Verification.Minimal,
        )

        forward = program.load_method("forward")
        outputs = forward.execute(example_input)

        print(
            f"✓ Successfully executed {description}: {example_input[0].shape} -> {outputs[0].shape}"
        )
        return True
    finally:
        # Clean up temporary file
        if os.path.exists(temp_path):
            os.unlink(temp_path)


def test_topk_without_out_parameter():
    """Test torch.topk without out parameter."""
    print("Testing torch.topk without out parameter...")
    return _export_and_run(TopKModel(k=5), (torch.randn(3, 100),), "topk model")


def test_topk_with_out_parameter():
    """Test torch.topk with out parameter (original failing case)."""
    print("Testing torch.topk with out parameter...")
    return _export_and_run(
        TopKModelWithOut(k=3),
        (torch.randn(2, 256),),
        "topk model with out parameter",
    )


def test_larger_topk_operation():
    """Test larger topk operation that would require more temporary memory."""
    print("Testing larger topk operation...")
    return _export_and_run(TopKModel(k=50), (torch.randn(5, 1000),), "large topk model")
187+
188+
189+
def main():
    """Run all tests to verify the temp memory allocation fix.

    Returns True when every test passes, False otherwise. A test that
    raises is counted as a failure rather than aborting the run.
    """
    print("Testing temp memory allocation fix for torch.topk operations")
    print("=" * 60)

    passed = 0
    failed = 0

    for case in (
        test_topk_without_out_parameter,
        test_topk_with_out_parameter,
        test_larger_topk_operation,
    ):
        try:
            ok = case()
        except Exception as e:
            # Keep going so one failing test doesn't hide the others.
            print(f"✗ Test {case.__name__} failed with exception: {e}")
            ok = False
        if ok:
            passed += 1
        else:
            failed += 1

    print("\n" + "=" * 60)
    print(f"Test Results: {passed} passed, {failed} failed")

    if failed:
        print("✗ Some tests failed. The fix may not be working correctly.")
        return False
    print(
        "✓ All tests passed! The temp memory allocation fix is working correctly."
    )
    return True
226+
if __name__ == "__main__":
    # Raise SystemExit instead of calling the site-provided exit() helper,
    # which is intended for interactive use and is absent when Python runs
    # without the site module (e.g. `python -S`).
    raise SystemExit(0 if main() else 1)

0 commit comments

Comments
 (0)