File tree Expand file tree Collapse file tree 1 file changed +24
-0
lines changed
Expand file tree Collapse file tree 1 file changed +24
-0
lines changed Original file line number Diff line number Diff line change 1+ import numpy as np
2+
3+
4+ def sigmoid_stable (x ):
5+ return np .where (x >= 0 , 1 / (1 + np .exp (- x )), np .exp (x ) / (1 + np .exp (x )))
6+
7+
8+ def postprocess (logits : np .array , max_detections : int = 8 ):
9+ batch_size , num_queries , num_classes = logits .shape
10+ logits_sigmoid = sigmoid_stable (logits )
11+ processed_predictions = []
12+ for batch_idx in range (batch_size ):
13+ logits_flat = logits_sigmoid [batch_idx ].reshape (- 1 )
14+
15+ sorted_indices = np .argsort (- logits_flat )[:max_detections ]
16+ processed_predictions .append (sorted_indices )
17+ return processed_predictions
18+
19+
20+ if __name__ == "__main__" :
21+ predictions = np .random .normal (size = (8 , 1000 , 10 ))
22+ print (predictions .shape )
23+ result = postprocess (predictions , max_detections = 8 )
24+ print (len (result ), result [0 ])
You can’t perform that action at this time.
0 commit comments