Skip to content

Commit 8947ec9

Browse files
committed
check it in
Signed-off-by: Saurabh Misra <[email protected]>
1 parent 553a192 commit 8947ec9

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

codeflash/process/infer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
3+
4+
def sigmoid_stable(x):
5+
return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))
6+
7+
8+
def postprocess(logits: np.array, max_detections: int = 8):
9+
batch_size, num_queries, num_classes = logits.shape
10+
logits_sigmoid = sigmoid_stable(logits)
11+
processed_predictions = []
12+
for batch_idx in range(batch_size):
13+
logits_flat = logits_sigmoid[batch_idx].reshape(-1)
14+
15+
sorted_indices = np.argsort(-logits_flat)[:max_detections]
16+
processed_predictions.append(sorted_indices)
17+
return processed_predictions
18+
19+
20+
if __name__ == "__main__":
21+
predictions = np.random.normal(size=(8, 1000, 10))
22+
print(predictions.shape)
23+
result = postprocess(predictions, max_detections=8)
24+
print(len(result), result[0])

0 commit comments

Comments
 (0)