diff --git a/codeflash/process/infer.py b/codeflash/process/infer.py new file mode 100644 index 000000000..ebe3c7e88 --- /dev/null +++ b/codeflash/process/infer.py @@ -0,0 +1,24 @@ +import numpy as np + + +def sigmoid_stable(x): + # Avoid redundant computation of np.exp(x) + exp_x = np.exp(-np.abs(x)) + return np.where(x >= 0, 1 / (1 + exp_x), exp_x / (1 + exp_x)) + + +def postprocess(logits: np.array, max_detections: int = 8): + batch_size, num_queries, num_classes = logits.shape + logits_sigmoid = sigmoid_stable(logits) + processed_predictions = [] + for batch_idx in range(batch_size): + logits_flat = logits_sigmoid[batch_idx].reshape(-1) + + sorted_indices = np.argsort(-logits_flat)[:max_detections] + processed_predictions.append(sorted_indices) + return processed_predictions + + +if __name__ == "__main__": + predictions = np.random.normal(size=(8, 1000, 10)) + result = postprocess(predictions, max_detections=8)