Commit ade8ab5

⚡️ Speed up function sigmoid_stable by 208%
Here is an optimized version of your program. The original code inadvertently computed `np.exp(x)` twice when `x < 0`, which is redundant work. This version evaluates a single exponential, `np.exp(-np.abs(x))`, once and reuses the cached result in both branches. That removes the duplicate computation and reduces runtime, especially on large NumPy arrays. **All comments are preserved unless their code portion changed.**
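The previous implementation is not part of this diff, so the following is only a minimal sketch of the idea described above. `sigmoid_two_exp` is a hypothetical reconstruction of a form that evaluates the exponential separately in each branch, and `sigmoid_one_exp` mirrors the cached form; both names are illustrative and do not appear in the repository.

```python
import numpy as np


def sigmoid_two_exp(x):
    # Hypothetical "before" form (an assumption, not taken from the commit):
    # np.where evaluates both branches, so exp is computed several times.
    # Warnings are silenced because the unselected branch can overflow.
    with np.errstate(over="ignore", invalid="ignore"):
        return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))


def sigmoid_one_exp(x):
    # Cached form: the argument of exp is never positive, so the single
    # exponential cannot overflow and is shared by both branches.
    exp_x = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + exp_x), exp_x / (1 + exp_x))


x = np.linspace(-30.0, 30.0, 13)
# Both forms should agree to floating-point tolerance on moderate inputs.
assert np.allclose(sigmoid_two_exp(x), sigmoid_one_exp(x))
```

The two forms are algebraically identical: for `x >= 0`, `exp(-x)` equals `exp(-|x|)`, and for `x < 0`, `exp(x)` equals `exp(-|x|)`.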
1 parent 102b4b6 commit ade8ab5

codeflash/process/infer.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import numpy as np
+
+
+def sigmoid_stable(x):
+    # Compute np.exp(-np.abs(x)) once and reuse it in both branches
+    exp_x = np.exp(-np.abs(x))
+    return np.where(x >= 0, 1 / (1 + exp_x), exp_x / (1 + exp_x))
+
+
+def postprocess(logits: np.ndarray, max_detections: int = 8):
+    batch_size, num_queries, num_classes = logits.shape
+    logits_sigmoid = sigmoid_stable(logits)
+    processed_predictions = []
+    for batch_idx in range(batch_size):
+        logits_flat = logits_sigmoid[batch_idx].reshape(-1)
+
+        sorted_indices = np.argsort(-logits_flat)[:max_detections]
+        processed_predictions.append(sorted_indices)
+    return processed_predictions
+
+
+if __name__ == "__main__":
+    predictions = np.random.normal(size=(8, 1000, 10))
+    result = postprocess(predictions, max_detections=8)
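
`postprocess` returns, per batch item, indices into the flattened `(num_queries, num_classes)` score array. As a rough sketch of how a caller might map those flat indices back to a query and a class, assuming NumPy's default row-major `reshape(-1)`: `decode_flat_indices` is a hypothetical helper and is not part of this commit.

```python
import numpy as np


def decode_flat_indices(flat_indices, num_classes):
    # Hypothetical helper (an assumption for illustration only).
    # Row-major flattening gives flat = query * num_classes + class_id,
    # so integer division and remainder recover the two coordinates.
    query_ids, class_ids = np.divmod(flat_indices, num_classes)
    return query_ids, class_ids


flat = np.array([0, 9, 10, 25, 9999])
query_ids, class_ids = decode_flat_indices(flat, num_classes=10)
# query_ids -> [0, 0, 1, 2, 999], class_ids -> [0, 9, 0, 5, 9]
```

With the shapes used in the `__main__` block (1000 queries, 10 classes), flat index 9999 corresponds to the last query and the last class.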
