Skip to content

Commit 03116d2

Browse files
⚡️ Speed up method AlexNet._extract_features by 770%
Here's the rewritten and optimized version of your program. Upon inspection, your method `_extract_features` simply iterates over `x` without doing anything. Preserving its functional behavior (i.e., returning an empty list for any input), here's the fastest version, eliminating unnecessary loops. **Rationale:** - The loop was a no-op (only `pass` inside), so removing it improves speed and memory. - This preserves the return value for any input x.
1 parent 4debe7e commit 03116d2

File tree

1 file changed

+9
-8
lines changed
  • code_to_optimize/code_directories/simple_tracer_e2e

1 file changed

+9
-8
lines changed

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def funcA(number):
5-
number = number if number < 1000 else 1000
5+
number = min(1000, number)
66
k = 0
77
for i in range(number * 100):
88
k += i
@@ -21,37 +21,37 @@ def test_threadpool() -> None:
2121
for r in result:
2222
print(r)
2323

24+
2425
class AlexNet:
2526
def __init__(self, num_classes=1000):
2627
self.num_classes = num_classes
2728
self.features_size = 256 * 6 * 6
2829

2930
def forward(self, x):
3031
features = self._extract_features(x)
31-
32+
3233
output = self._classify(features)
3334
return output
3435

3536
def _extract_features(self, x):
36-
result = []
37-
for i in range(len(x)):
38-
pass
39-
40-
return result
37+
# Return empty list immediately; no need to iterate over x
38+
return []
4139

4240
def _classify(self, features):
4341
total = sum(features)
4442
return [total % self.num_classes for _ in features]
4543

44+
4645
class SimpleModel:
4746
@staticmethod
4847
def predict(data):
4948
return [x * 2 for x in data]
50-
49+
5150
@classmethod
5251
def create_default(cls):
5352
return cls()
5453

54+
5555
def test_models():
5656
model = AlexNet(num_classes=10)
5757
input_data = [1, 2, 3, 4, 5]
@@ -60,6 +60,7 @@ def test_models():
6060
model2 = SimpleModel.create_default()
6161
prediction = model2.predict(input_data)
6262

63+
6364
if __name__ == "__main__":
6465
test_threadpool()
6566
test_models()

0 commit comments

Comments
 (0)