|
| 1 | +import { Processor } from "../../base/processing_utils.js"; |
| 2 | +import { AutoImageProcessor } from "../auto/image_processing_auto.js"; |
| 3 | +import { AutoTokenizer } from "../../tokenizers.js"; |
| 4 | +import { center_to_corners_format } from "../../base/image_processors_utils.js"; |
| 5 | + |
/**
 * Get token ids of phrases from posmaps and input_ids.
 *
 * The first position and the last position of `posmaps` are always treated as
 * unselected, whatever value they hold. Positions that fall outside of
 * `input_ids` are ignored, so the result never contains `undefined` entries.
 *
 * @param {import('../../utils/tensor.js').Tensor} posmaps A boolean tensor of unbatched text-thresholded logits related to the detected bounding boxes of shape `(hidden_size, )`.
 * @param {import('../../utils/tensor.js').Tensor} input_ids A tensor of token ids of shape `(sequence_length, )`.
 * @returns {Array} The entries of `input_ids` at the selected positions, in ascending position order.
 */
function get_phrases_from_posmap(posmaps, input_ids) {

    const left_idx = 0;
    const right_idx = posmaps.dims.at(-1) - 1;

    const posmaps_list = posmaps.tolist();
    const input_ids_list = input_ids.tolist();

    // Only positions strictly between the two masked boundaries can be selected.
    // Clamp the upper bound to the number of available token ids so that a
    // position map longer than `input_ids` cannot index past its end.
    const end_idx = Math.min(right_idx, input_ids_list.length);

    const token_ids = [];
    for (let i = left_idx + 1; i < end_idx; ++i) {
        if (posmaps_list[i]) {
            token_ids.push(input_ids_list[i]);
        }
    }
    return token_ids;
}
| 26 | + |
export class GroundingDinoProcessor extends Processor {
    static tokenizer_class = AutoTokenizer;
    static image_processor_class = AutoImageProcessor;

    /**
     * @typedef {import('../../utils/image.js').RawImage} RawImage
     */

    /**
     * Run the image processor and/or the tokenizer and merge their outputs.
     *
     * Either input may be omitted (falsy), in which case the corresponding
     * component is not called. If both outputs share a key, the image
     * processor's value is the one that is kept.
     *
     * @param {RawImage|RawImage[]|RawImage[][]} images The image(s) to preprocess.
     * @param {string|string[]} text The text prompt(s) to tokenize.
     * @param {Object} [options={}] Options passed unchanged to both the image processor and the tokenizer.
     * @returns {Promise<any>} A single object holding the tokenizer and image processor outputs.
     */
    async _call(images, text, options = {}) {
        let image_features = {};
        if (images) {
            image_features = await this.image_processor(images, options);
        }

        let text_features = {};
        if (text) {
            text_features = this.tokenizer(text, options);
        }

        // Text keys first, so that image keys take precedence on a collision.
        return { ...text_features, ...image_features };
    }

    /**
     * Turn raw model outputs into per-image lists of scores, boxes and decoded labels.
     *
     * @param {Object} outputs Model outputs; `logits` and `pred_boxes` are read from it.
     * @param {any} input_ids Token ids, indexed per image in the batch.
     * @param {Object} [options]
     * @param {number} [options.box_threshold=0.25] A query is kept only if its highest sigmoid score is not less than or equal to this value.
     * @param {number} [options.text_threshold=0.25] Token positions whose probability is greater than this value contribute to the label.
     * @param {Array|null} [options.target_sizes=null] Optional per-image sizes used to rescale boxes
     * (presumably `[height, width]`: index 1 scales x coordinates, index 0 scales y coordinates).
     * @returns {Array<{scores: number[], boxes: number[][], labels: string[]}>} One result object per image.
     * @throws {Error} If `target_sizes` is given and its length differs from the batch size.
     */
    post_process_grounded_object_detection(outputs, input_ids, {
        box_threshold = 0.25,
        text_threshold = 0.25,
        target_sizes = null
    } = {}) {
        const { logits, pred_boxes } = outputs;
        const [batch_size, num_queries] = logits.dims;

        const has_target_sizes = target_sizes !== null;
        if (has_target_sizes && target_sizes.length !== batch_size) {
            throw new Error("Make sure that you pass in as many target sizes as the batch dimension of the logits");
        }

        const probs = logits.sigmoid(); // (batch_size, num_queries, 256)
        const max_scores = probs.max(-1).tolist(); // (batch_size, num_queries)

        // Boxes in [x0, y0, x1, y1] layout, one list per image.
        const corner_boxes = [];
        for (const image_boxes of pred_boxes.tolist()) { // (batch_size, num_queries, 4)
            corner_boxes.push(image_boxes.map((box) => center_to_corners_format(box)));
        }

        const results = [];
        for (let image_idx = 0; image_idx < batch_size; ++image_idx) {
            const target_size = has_target_sizes ? target_sizes[image_idx] : null;

            // Rescale relative coordinates to absolute ones when a size is supplied:
            // even coordinate indices use target_size[1], odd ones use target_size[0].
            let image_boxes = corner_boxes[image_idx];
            if (target_size !== null) {
                image_boxes = image_boxes.map(
                    (box) => box.map((coord, k) => coord * target_size[(k + 1) % 2])
                );
            }

            const image_scores = max_scores[image_idx];
            const kept_scores = [];
            const kept_boxes = [];
            const kept_token_ids = [];
            for (let query_idx = 0; query_idx < num_queries; ++query_idx) {
                const score = image_scores[query_idx];
                if (score <= box_threshold) {
                    continue;
                }

                kept_scores.push(score);
                kept_boxes.push(image_boxes[query_idx]);

                // Token positions above the text threshold form the label of this box.
                const token_mask = probs[image_idx][query_idx].gt(text_threshold);
                kept_token_ids.push(get_phrases_from_posmap(token_mask, input_ids[image_idx]));
            }

            results.push({
                scores: kept_scores,
                boxes: kept_boxes,
                labels: this.batch_decode(kept_token_ids),
            });
        }
        return results;
    }
}
0 commit comments