Skip to content

Commit ef032e0

Browse files
committed
algo
Signed-off-by: Saurabh Misra <[email protected]>
1 parent 02ae60b commit ef032e0

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

codeflash/after_algo.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
def postprocess(self, predictions: tuple[np.ndarray, ...], max_detections: int):
2+
bboxes, logits = predictions
3+
batch_size, num_queries, num_classes = logits.shape
4+
logits_sigmoid = self.sigmoid_stable(logits)
5+
for batch_idx in range(batch_size):
6+
logits_flat = logits_sigmoid[batch_idx].reshape(-1)
7+
# Use argpartition for better performance when max_detections is smaller than logits_flat
8+
partition_indices = np.argpartition(-logits_flat, max_detections)[:max_detections]
9+
sorted_indices = partition_indices[np.argsort(-logits_flat[partition_indices])]

codeflash/before_algo.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
def postprocess(self, predictions: tuple[np.ndarray, ...], max_detections: int):
2+
bboxes, logits = predictions
3+
batch_size, num_queries, num_classes = logits.shape
4+
logits_sigmoid = self.sigmoid_stable(logits)
5+
for batch_idx in range(batch_size):
6+
logits_flat = logits_sigmoid[batch_idx].reshape(-1)
7+
sorted_indices = np.argsort(-logits_flat)[:max_detections]

0 commit comments

Comments
 (0)