Skip to content

Commit 8606c33

Browse files
⚡️ Speed up function sigmoid_stable by 26%
Here is an optimized version of your `sigmoid_stable` function. The performance bottleneck is due to repeated calls to `np.exp(x)` within the `np.where` function, causing unnecessary recomputation over potentially large arrays. We'll precompute `exp_x = np.exp(x)` and `exp_neg_x = np.exp(-x)` **outside** of `np.where` to avoid recomputation and improve cache use. This significantly reduces redundant computation for both branches of the `np.where`. **Explanation of Changes:** - Precompute both `exp_neg_x` and `exp_x` out of `np.where` to avoid duplicate calculations. - This reduces two extra expensive `np.exp` calls down to one each, regardless of input. This will make the function significantly faster, especially on large arrays. The output is mathematically identical.
1 parent 8947ec9 commit 8606c33

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

codeflash/process/infer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33

44
def sigmoid_stable(x):
5-
return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))
5+
exp_neg_x = np.exp(-x)
6+
exp_x = np.exp(x)
7+
# Use precomputed exponentials to avoid redundant calculation
8+
return np.where(x >= 0, 1 / (1 + exp_neg_x), exp_x / (1 + exp_x))
69

710

811
def postprocess(logits: np.array, max_detections: int = 8):

0 commit comments

Comments
 (0)