Skip to content

Commit 7baeb7d

Browse files
⚡️ Speed up function sigmoid_stable by 236%
Here is a faster and more memory-efficient version of the function. The bottleneck in the original implementation is calling `np.exp(x)` twice for `x < 0` branches. We can compute `exp_x = np.exp(x)` once, then reuse it for both cases. This version calls `np.exp` only once and chooses the correct formulation based on `x`, which results in a significant speedup and less redundant computation.
1 parent 8947ec9 commit 7baeb7d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

codeflash/process/infer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33

44
def sigmoid_stable(x):
5-
return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))
5+
exp_x = np.exp(-np.abs(x))
6+
return np.where(x >= 0, 1 / (1 + exp_x), exp_x / (1 + exp_x))
67

78

89
def postprocess(logits: np.array, max_detections: int = 8):

0 commit comments

Comments
 (0)