Skip to content

Commit aac50d3

Browse files
committed
fix lintrunner
1 parent f83142b commit aac50d3

File tree

1 file changed

+63
-50
lines changed

1 file changed

+63
-50
lines changed

test/end2end/test_temp_allocator_fix.py

Lines changed: 63 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,24 @@
77
require temporary memory allocation.
88
"""
99

10-
import torch
11-
import tempfile
1210
import os
11+
import tempfile
1312
from pathlib import Path
14-
from torch.export import export
15-
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower, EdgeCompileConfig
13+
14+
import torch
1615
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
17-
from executorch.runtime import Verification, Runtime, Program, Method
16+
from executorch.exir import (
17+
EdgeCompileConfig,
18+
EdgeProgramManager,
19+
to_edge_transform_and_lower,
20+
)
21+
from executorch.runtime import Method, Program, Runtime, Verification
22+
from torch.export import export
1823

1924

2025
class TopKModel(torch.nn.Module):
2126
"""Model that uses torch.topk operation which requires temp memory allocation."""
22-
27+
2328
def __init__(self, k=3) -> None:
2429
super().__init__()
2530
self.k = k
@@ -32,7 +37,7 @@ def forward(self, x) -> torch.Tensor:
3237

3338
class TopKModelWithOut(torch.nn.Module):
3439
"""Model that uses torch.topk with out parameter which also requires temp memory."""
35-
40+
3641
def __init__(self, k=3) -> None:
3742
super().__init__()
3843
self.k = k
@@ -47,42 +52,44 @@ def forward(self, x) -> torch.Tensor:
4752
def test_topk_without_out_parameter():
4853
"""Test torch.topk without out parameter."""
4954
print("Testing torch.topk without out parameter...")
50-
55+
5156
model = TopKModel(k=5)
5257
example_input = (torch.randn(3, 100),)
53-
58+
5459
# Export and compile the model
5560
with torch.no_grad():
5661
aten_dialect = export(model, example_input)
57-
62+
5863
backend_dialect = to_edge_transform_and_lower(
5964
aten_dialect,
6065
compile_config=EdgeCompileConfig(_check_ir_validity=False),
6166
partitioner=[XnnpackPartitioner()],
6267
)
63-
68+
6469
executorch_dialect = backend_dialect.to_executorch()
65-
70+
6671
# Save to temporary file
67-
with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
72+
with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
6873
temp_path = f.name
69-
74+
7075
try:
7176
executorch_dialect.save(temp_path)
72-
77+
7378
# Load and execute with ExecuTorch runtime
7479
et_runtime = Runtime.get()
7580
program = et_runtime.load_program(
7681
Path(temp_path),
7782
verification=Verification.Minimal,
7883
)
79-
84+
8085
forward = program.load_method("forward")
8186
outputs = forward.execute(example_input)
82-
83-
print(f"✓ Successfully executed topk model: {example_input[0].shape} -> {outputs[0].shape}")
87+
88+
print(
89+
f"✓ Successfully executed topk model: {example_input[0].shape} -> {outputs[0].shape}"
90+
)
8491
return True
85-
92+
8693
finally:
8794
# Clean up temporary file
8895
if os.path.exists(temp_path):
@@ -92,42 +99,44 @@ def test_topk_without_out_parameter():
9299
def test_topk_with_out_parameter():
93100
"""Test torch.topk with out parameter (original failing case)."""
94101
print("Testing torch.topk with out parameter...")
95-
102+
96103
model = TopKModelWithOut(k=3)
97104
example_input = (torch.randn(2, 256),)
98-
105+
99106
# Export and compile the model
100107
with torch.no_grad():
101108
aten_dialect = export(model, example_input)
102-
109+
103110
backend_dialect = to_edge_transform_and_lower(
104111
aten_dialect,
105112
compile_config=EdgeCompileConfig(_check_ir_validity=False),
106113
partitioner=[XnnpackPartitioner()],
107114
)
108-
115+
109116
executorch_dialect = backend_dialect.to_executorch()
110-
117+
111118
# Save to temporary file
112-
with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
119+
with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
113120
temp_path = f.name
114-
121+
115122
try:
116123
executorch_dialect.save(temp_path)
117-
124+
118125
# Load and execute with ExecuTorch runtime
119126
et_runtime = Runtime.get()
120127
program = et_runtime.load_program(
121128
Path(temp_path),
122129
verification=Verification.Minimal,
123130
)
124-
131+
125132
forward = program.load_method("forward")
126133
outputs = forward.execute(example_input)
127-
128-
print(f"✓ Successfully executed topk model with out parameter: {example_input[0].shape} -> {outputs[0].shape}")
134+
135+
print(
136+
f"✓ Successfully executed topk model with out parameter: {example_input[0].shape} -> {outputs[0].shape}"
137+
)
129138
return True
130-
139+
131140
finally:
132141
# Clean up temporary file
133142
if os.path.exists(temp_path):
@@ -137,42 +146,44 @@ def test_topk_with_out_parameter():
137146
def test_larger_topk_operation():
138147
"""Test larger topk operation that would require more temporary memory."""
139148
print("Testing larger topk operation...")
140-
149+
141150
model = TopKModel(k=50)
142151
example_input = (torch.randn(5, 1000),)
143-
152+
144153
# Export and compile the model
145154
with torch.no_grad():
146155
aten_dialect = export(model, example_input)
147-
156+
148157
backend_dialect = to_edge_transform_and_lower(
149158
aten_dialect,
150159
compile_config=EdgeCompileConfig(_check_ir_validity=False),
151160
partitioner=[XnnpackPartitioner()],
152161
)
153-
162+
154163
executorch_dialect = backend_dialect.to_executorch()
155-
164+
156165
# Save to temporary file
157-
with tempfile.NamedTemporaryFile(suffix='.pte', delete=False) as f:
166+
with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
158167
temp_path = f.name
159-
168+
160169
try:
161170
executorch_dialect.save(temp_path)
162-
171+
163172
# Load and execute with ExecuTorch runtime
164173
et_runtime = Runtime.get()
165174
program = et_runtime.load_program(
166175
Path(temp_path),
167176
verification=Verification.Minimal,
168177
)
169-
178+
170179
forward = program.load_method("forward")
171180
outputs = forward.execute(example_input)
172-
173-
print(f"✓ Successfully executed large topk model: {example_input[0].shape} -> {outputs[0].shape}")
181+
182+
print(
183+
f"✓ Successfully executed large topk model: {example_input[0].shape} -> {outputs[0].shape}"
184+
)
174185
return True
175-
186+
176187
finally:
177188
# Clean up temporary file
178189
if os.path.exists(temp_path):
@@ -183,16 +194,16 @@ def main():
183194
"""Run all tests to verify the temp memory allocation fix."""
184195
print("Testing temp memory allocation fix for torch.topk operations")
185196
print("=" * 60)
186-
197+
187198
tests = [
188199
test_topk_without_out_parameter,
189200
test_topk_with_out_parameter,
190201
test_larger_topk_operation,
191202
]
192-
203+
193204
passed = 0
194205
failed = 0
195-
206+
196207
for test in tests:
197208
try:
198209
if test():
@@ -202,12 +213,14 @@ def main():
202213
except Exception as e:
203214
print(f"✗ Test {test.__name__} failed with exception: {e}")
204215
failed += 1
205-
216+
206217
print("\n" + "=" * 60)
207218
print(f"Test Results: {passed} passed, {failed} failed")
208-
219+
209220
if failed == 0:
210-
print("✓ All tests passed! The temp memory allocation fix is working correctly.")
221+
print(
222+
"✓ All tests passed! The temp memory allocation fix is working correctly."
223+
)
211224
return True
212225
else:
213226
print("✗ Some tests failed. The fix may not be working correctly.")
@@ -216,4 +229,4 @@ def main():
216229

217230
if __name__ == "__main__":
218231
success = main()
219-
exit(0 if success else 1)
232+
exit(0 if success else 1)

0 commit comments

Comments
 (0)