Commit 474b6ea

⚡️ Speed up function postprocess by 210%
Here's an optimized version of your code with the following improvements.

- **Avoid repeated computation**: `np.exp(logits)` was computed more than once per value in `sigmoid_stable`. Compute it once and reuse the result.
- **Avoid flattening with reshape**: Use `.ravel()`, which returns a view whenever the memory layout allows, when you don't need a copy.
- **Vectorized selection**: Use `np.argpartition` for O(n) partial selection instead of a full O(n log n) sort (`np.argsort`) when only the top K are needed; sort only those K afterward for correct order.
- **Preallocate output**: Preallocate a fixed-size output list (one slot per batch item) instead of growing it.

The improved code is in the diff below.

**Notes:**

- `sigmoid_stable` does not call `np.exp(x)` and `np.exp(-x)` separately for each value. It uses `np.exp(-np.abs(x))` instead, so the argument to `exp` is never positive; this avoids overflow and saves one exponential per element.
- Uses `np.argpartition(..., k)` to efficiently get the top K indices. Only these are then sorted by value.
- `.ravel()` instead of `.reshape(-1)` for flattening. Both return a view for contiguous input, so any gain from this change is likely small.
- Output structure and function signatures are preserved.
- All comments are kept unless they relate to changed code.

This should noticeably speed up use on large arrays or large batch sizes.
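The partial-selection idea in the notes above can be checked in isolation. The following is a minimal sketch, not the committed code: the helper names `topk_full_sort` and `topk_partition` are hypothetical, and the input is built from a permutation so that all scores are distinct (with ties, the two approaches may legitimately return tied indices in a different order).

```python
import numpy as np


def topk_full_sort(scores: np.ndarray, k: int) -> np.ndarray:
    """Reference approach: sort everything descending, keep the first k indices."""
    return np.argsort(-scores)[:k]


def topk_partition(scores: np.ndarray, k: int) -> np.ndarray:
    """Partial selection: partition out the k largest, then sort only those k."""
    if k <= 0:
        return np.empty(0, dtype=np.intp)
    if scores.size <= k:
        # Nothing to discard, so a full sort is the whole job.
        return np.argsort(-scores)
    # After partitioning at position k - 1, the first k slots hold the indices
    # of the k largest scores, in no particular order.
    candidates = np.argpartition(-scores, k - 1)[:k]
    # Order just those k candidates by descending score.
    order = np.argsort(-scores[candidates])
    return candidates[order]


rng = np.random.default_rng(0)
scores = rng.permutation(10_000) / 10_000.0  # distinct values, so no ties

assert np.array_equal(topk_full_sort(scores, 8), topk_partition(scores, 8))
```

The partition step is linear in the number of scores on average, and the final sort touches only `k` elements, which is where the saving comes from when `k` is much smaller than the array.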
1 parent 102b4b6 commit 474b6ea

1 file changed, +35 -0 lines changed


codeflash/process/infer.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import numpy as np
+
+
+def sigmoid_stable(x):
+    # Avoid repeated computation of exp(x)
+    ex = np.exp(-np.abs(x))
+    return np.where(x >= 0, 1 / (1 + ex), ex / (1 + ex))
+
+
+def postprocess(logits: np.array, max_detections: int = 8):
+    batch_size, num_queries, num_classes = logits.shape
+    logits_sigmoid = sigmoid_stable(logits)
+    # Preallocate output as an array for efficiency
+    processed_predictions = [None] * batch_size
+    for batch_idx in range(batch_size):
+        logits_flat = logits_sigmoid[batch_idx].ravel()
+        if logits_flat.size <= max_detections:
+            # If there are fewer elements than max_detections, just argsort all
+            sorted_indices = np.argsort(-logits_flat)
+        else:
+            # Partial sort for top max_detections
+            partition_indices = np.argpartition(-logits_flat, max_detections - 1)[:max_detections]
+            top_scores = logits_flat[partition_indices]
+            # Now sort these to get actual order
+            sorted_order = np.argsort(-top_scores)
+            sorted_indices = partition_indices[sorted_order]
+        processed_predictions[batch_idx] = sorted_indices
+    return processed_predictions
+
+
+if __name__ == "__main__":
+    predictions = np.random.normal(size=(8, 1000, 10))
+    print(predictions.shape)
+    result = postprocess(predictions, max_detections=8)
+    print(len(result), result[0])
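To illustrate why the `np.exp(-np.abs(x))` formulation in `sigmoid_stable` is safe for extreme logits, here is a small sketch that compares it against the textbook form. `sigmoid_naive` is a hypothetical helper added only for comparison; `sigmoid_stable` is copied from the diff so the snippet runs on its own.

```python
import numpy as np


def sigmoid_naive(x):
    # Textbook form: np.exp(-x) overflows to inf for large negative x.
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_stable(x):
    # Same formulation as in the diff: the exp argument is never positive.
    ex = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + ex), ex / (1 + ex))


x = np.array([-1000.0, -5.0, 0.0, 5.0, 1000.0])

# Treat overflow as an error: the stable version should pass without raising.
with np.errstate(over="raise"):
    stable = sigmoid_stable(x)

# The naive version overflows on x = -1000; suppress the warning to compare values.
with np.errstate(over="ignore"):
    naive = sigmoid_naive(x)

assert np.allclose(stable, naive)
```

Both forms agree numerically here, but only the stable one gets there without an overflow inside `exp`.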
