from typing import Tuple, List

import ldm_patched.modules.model_management as model_management
from ldm_patched.modules.model_patcher import ModelPatcher
from modules.config import path_inpaint
from modules.model_loader import load_file_from_url

import numpy as np
import supervision as sv
import torch
from groundingdino.util.inference import Model, load_model, preprocess_caption, get_phrases_from_posmap


class GroundingDinoModel(Model):
    def __init__(self):
        # The checkpoint is downloaded and loaded lazily on the first call to
        # predict_with_caption; until then both devices default to CPU.
        self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
        self.model = None
        self.load_device = torch.device('cpu')
        self.offload_device = torch.device('cpu')

    @torch.no_grad()
    @torch.inference_mode()
    def predict_with_caption(
            self,
            image: np.ndarray,
            caption: str,
            box_threshold: float = 0.35,
            text_threshold: float = 0.25
    ) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
        if self.model is None:
            # Download the checkpoint on first use and wrap the model in a
            # ModelPatcher so model_management can move it between devices.
            filename = load_file_from_url(
                url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
                file_name='groundingdino_swint_ogc.pth',
                model_dir=path_inpaint)
            model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename)

            self.load_device = model_management.text_encoder_device()
            self.offload_device = model_management.text_encoder_offload_device()

            model.to(self.offload_device)

            self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)

        # Ensure the (possibly offloaded) model is resident on the load device for this call.
        model_management.load_model_gpu(self.model)

        processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=processed_image,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.load_device)
        source_h, source_w, _ = image.shape
        detections = GroundingDinoModel.post_process_result(
            source_h=source_h,
            source_w=source_w,
            boxes=boxes,
            logits=logits)
        return detections, boxes, logits, phrases


def predict(
        model,
        image: torch.Tensor,
        caption: str,
        box_threshold: float,
        text_threshold: float,
        device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    caption = preprocess_caption(caption=caption)

    # local override of groundingdino.util.inference.predict:
    # unwrap the underlying module from the ModelPatcher before running it
    model = model.model.to(device)
    image = image.to(device)

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)

    # keep only queries whose best token score clears the box threshold
    mask = prediction_logits.max(dim=1)[0] > box_threshold
    logits = prediction_logits[mask]  # logits.shape = (n, 256)
    boxes = prediction_boxes[mask]  # boxes.shape = (n, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    phrases = [
        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
        for logit
        in logits
    ]

    return boxes, logits.max(dim=1)[0], phrases


default_groundingdino = GroundingDinoModel().predict_with_caption
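
# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of how the exported callable might be invoked, assuming a
# BGR image loaded with OpenCV; the file name and caption below are made up.
#
#   import cv2
#
#   image_bgr = cv2.imread('example.png')        # H x W x 3, BGR uint8
#   detections, boxes, logits, phrases = default_groundingdino(
#       image=image_bgr,
#       caption='cat . dog',                     # GroundingDINO expects '.'-separated phrases
#       box_threshold=0.35,
#       text_threshold=0.25)
#   # detections: sv.Detections with absolute xyxy boxes; phrases: matched text per box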