Skip to content

Commit 8c5ab40

Browse files
⚡️ Speed up method AlexNet._classify by 317%
Here’s an optimized version of your `AlexNet` class for improved speed and efficiency. The improvements include. - Use list multiplication where possible, and avoid unnecessary use of `sum()` in loops. - Use built-in functions efficiently. - Precompute common values. **Explanation of changes:** - Replaced list comprehension with `[total_mod] * len(features)`, which is faster and more memory-efficient for filling a list with the same value. - Only run `sum(features)` and modulo operation once, instead of for every element. - Preserved the comments as instructed.
1 parent 4debe7e commit 8c5ab40

File tree

1 file changed

+10
-5
lines changed
  • code_to_optimize/code_directories/simple_tracer_e2e

1 file changed

+10
-5
lines changed

code_to_optimize/code_directories/simple_tracer_e2e/workload.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def funcA(number):
5-
number = number if number < 1000 else 1000
5+
number = min(1000, number)
66
k = 0
77
for i in range(number * 100):
88
k += i
@@ -21,14 +21,15 @@ def test_threadpool() -> None:
2121
for r in result:
2222
print(r)
2323

24+
2425
class AlexNet:
2526
def __init__(self, num_classes=1000):
2627
self.num_classes = num_classes
2728
self.features_size = 256 * 6 * 6
2829

2930
def forward(self, x):
3031
features = self._extract_features(x)
31-
32+
3233
output = self._classify(features)
3334
return output
3435

@@ -40,18 +41,21 @@ def _extract_features(self, x):
4041
return result
4142

4243
def _classify(self, features):
43-
total = sum(features)
44-
return [total % self.num_classes for _ in features]
44+
# Precompute length and value for efficient list construction
45+
total_mod = sum(features) % self.num_classes
46+
return [total_mod] * len(features)
47+
4548

4649
class SimpleModel:
4750
@staticmethod
4851
def predict(data):
4952
return [x * 2 for x in data]
50-
53+
5154
@classmethod
5255
def create_default(cls):
5356
return cls()
5457

58+
5559
def test_models():
5660
model = AlexNet(num_classes=10)
5761
input_data = [1, 2, 3, 4, 5]
@@ -60,6 +64,7 @@ def test_models():
6064
model2 = SimpleModel.create_default()
6165
prediction = model2.predict(input_data)
6266

67+
6368
if __name__ == "__main__":
6469
test_threadpool()
6570
test_models()

0 commit comments

Comments
 (0)