|
| 1 | +from typing import List |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from jsonargparse.typing import ClosedUnitInterval |
| 5 | +from loguru import logger |
| 6 | +from PIL import ImageOps |
| 7 | + |
| 8 | +from data_juicer.utils.availability_utils import AvailabilityChecking |
| 9 | +from data_juicer.utils.constant import Fields, StatsKeys |
| 10 | +from data_juicer.utils.mm_utils import (SpecialTokens, iou, load_image, |
| 11 | + remove_special_tokens) |
| 12 | +from data_juicer.utils.model_utils import get_model, prepare_model |
| 13 | + |
| 14 | +from ..base_op import OPERATORS, Filter |
| 15 | +from ..op_fusion import LOADED_IMAGES |
| 16 | + |
# Registry key under which this operator is registered.
OP_NAME = 'phrase_grounding_recall_filter'

# Import heavy optional dependencies lazily and report a clear error if any
# of them is missing for this operator.
with AvailabilityChecking(['torch', 'transformers', 'nltk'], OP_NAME):

    import torch
    import transformers  # noqa: F401

    # avoid hanging when calling clip in multiprocessing
    torch.set_num_threads(1)

    import nltk
| 28 | + |
| 29 | + |
# NER algorithm adapted from GLIP starts
# https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/engine/predictor_glip.py#L107-L127
def find_noun_phrases(caption: str) -> List[str]:
    """Extract noun phrases from *caption* with an NLTK regexp chunker.

    The caption is lower-cased, tokenized and POS-tagged, then chunked with
    a simple grammar (optional determiner, any adjectives, one or more
    nouns). The surface text of each matched chunk is returned.
    """
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(caption.lower()))

    chunker = nltk.RegexpParser('NP: {<DT>?<JJ.*>*<NN.*>+}')
    parse_tree = chunker.parse(tagged_tokens)

    return [
        ' '.join(word for word, _tag in subtree.leaves())
        for subtree in parse_tree.subtrees()
        if subtree.label() == 'NP'
    ]
| 47 | + |
| 48 | + |
# Punctuation characters stripped from noun phrases before they are used as
# grounding queries. Built once so the translation table is not rebuilt on
# every call.
_PUNCT_TABLE = str.maketrans('', '', '|:;@()[]{}^\'"’`?$%#!&*+,.')


def remove_punctuation(text: str) -> str:
    """Remove punctuation characters from *text* and strip whitespace.

    Uses a single C-level ``str.translate`` pass instead of one
    ``str.replace`` call per punctuation character.

    :param text: input string, e.g. a noun phrase.
    :return: the cleaned, stripped string (may be empty).
    """
    return text.translate(_PUNCT_TABLE).strip()
| 57 | + |
| 58 | + |
def run_ner(caption):
    """Extract unique, cleaned noun phrases ("ners") from *caption*.

    :param caption: the text to extract noun phrases from.
    :return: a list of unique noun phrases in first-occurrence order.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
    # dict.fromkeys removes duplicates while keeping first-seen order.
    # The original list(set(...)) ordering varies across interpreter runs
    # (hash randomization), which made the phrase list — and therefore the
    # downstream detection labels and recall stats — non-reproducible.
    return list(dict.fromkeys(noun_phrases))


# NER algorithm adapted from GLIP ends
| 68 | + |
| 69 | + |
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class PhraseGroundingRecallFilter(Filter):
    """Filter to keep samples whose locating recalls of phrases extracted
    from text in the images are within a specified range."""

    def __init__(self,
                 hf_owlvit='google/owlvit-base-patch32',
                 min_recall: ClosedUnitInterval = 0.1,
                 max_recall: ClosedUnitInterval = 1.0,
                 horizontal_flip: bool = False,
                 vertical_flip: bool = False,
                 any_or_all: str = 'any',
                 reduce_mode: str = 'avg',
                 iou_thr: ClosedUnitInterval = 0.5,
                 large_area_ratio_thr: ClosedUnitInterval = 0.95,
                 conf_thr: ClosedUnitInterval = 0.0,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_owlvit: Owl-ViT model name on huggingface to locate the
            phrases extracted from the text.
        :param min_recall: The min phrase grounding recall to keep samples.
        :param max_recall: The max phrase grounding recall to keep samples.
        :param horizontal_flip: Flip image horizontally (left to right).
        :param vertical_flip: Flip image vertically (top to bottom).
        :param any_or_all: keep this sample with 'any' or 'all' strategy of
            all images. 'any': keep this sample if any images meet the
            condition. 'all': keep this sample only if all images meet the
            condition.
        :param reduce_mode: reduce mode when one text corresponds to
            multiple images in a chunk.
            'avg': Take the average of multiple values
            'max': Take the max of multiple values
            'min': Take the min of multiple values
        :param iou_thr: the IoU threshold for NMS-like post-process. If two
            predicted bboxes are overlap with an IoU larger than this
            threshold, the bbox with less confidence will be removed. Default:
            0.5.
        :param large_area_ratio_thr: the area ratio threshold for filtering out
            those large predicted bboxes. If the area of a predicted bbox
            accounts for more than this ratio threshold of the whole image
            area, this bbox will be removed. Default: 0.95.
        :param conf_thr: the confidence score threshold for removing
            low-confidence bboxes. If the confidence score of a predicted bbox
            is lower than the threshold, this bbox will be removed. Default: 0.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.min_recall = min_recall
        self.max_recall = max_recall
        # validate enum-like string parameters early, with explicit errors
        if reduce_mode not in ['avg', 'max', 'min']:
            raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. '
                             f'Can only be one of ["avg", "max", "min"].')
        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        self.any = (any_or_all == 'any')
        self.model_type = 'hf_owlvit'
        self.model_key = prepare_model(model_type=self.model_type,
                                       model_key=hf_owlvit)
        self.reduce_mode = reduce_mode
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip

        self.iou_thr = iou_thr
        self.large_area_ratio_thr = large_area_ratio_thr
        self.conf_thr = conf_thr

        # nltk corpora required by the NER helpers (tokenizer + POS tagger)
        requires_nltk_data = ['punkt', 'averaged_perceptron_tagger']
        logger.info(f'Downloading nltk data of {requires_nltk_data}...')
        for nltk_data_pkg in requires_nltk_data:
            nltk.download(nltk_data_pkg)

    def compute_stats(self, sample, context=False):
        """Compute the phrase grounding recall for each text chunk of the
        sample and store the list under the sample's stats field.

        :param sample: a sample dict containing text and image keys.
        :param context: whether to cache loaded images in the sample's
            context field for reuse by other fused OPs.
        :return: the sample with stats filled in.
        """
        # check if it's computed already
        if StatsKeys.phrase_grounding_recall in sample[Fields.stats]:
            return sample

        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            sample[Fields.stats][StatsKeys.phrase_grounding_recall] = np.array(
                [], dtype=np.float64)
            return sample

        # load images
        loaded_image_keys = sample[self.image_key]
        images = {}
        for loaded_image_key in loaded_image_keys:
            if context and loaded_image_key in sample[Fields.context]:
                # load from context
                images[loaded_image_key] = sample[
                    Fields.context][loaded_image_key]
            else:
                if loaded_image_key not in images:
                    # avoid load the same images
                    image = load_image(loaded_image_key)
                    images[loaded_image_key] = image
                    if context:
                        # store the image data into context
                        sample[Fields.context][loaded_image_key] = image

        text = sample[self.text_key]
        # offset tracks how many image tokens (hence images) have been
        # consumed by previous chunks
        offset = 0
        recalls = []
        model, processor = get_model(self.model_key,
                                     model_type=self.model_type)

        for chunk in text.split(SpecialTokens.eoc):
            count = chunk.count(SpecialTokens.image)

            # no image or no text
            if count == 0 or len(chunk) == 0:
                continue
            else:
                text_this_chunk = remove_special_tokens(chunk)
                ners_this_chunk = run_ner(text_this_chunk)
                num_ners = len(ners_this_chunk)
                if num_ners <= 0:
                    # no ners found, just skip this chunk.
                    # NOTE: still advance the image offset — this chunk owns
                    # `count` images; without this, every later chunk would
                    # read the wrong slice of loaded_image_keys.
                    recalls.append(1.0)
                    offset += count
                    continue
                images_this_chunk = []
                for image_key in loaded_image_keys[offset:offset + count]:
                    image = images[image_key]
                    if self.horizontal_flip:
                        image = ImageOps.mirror(image)
                    if self.vertical_flip:
                        image = ImageOps.flip(image)
                    images_this_chunk.append(image)

                # one copy of the phrase list per image in the chunk
                ners_batch = [ners_this_chunk] * len(images_this_chunk)
                inputs = processor(text=ners_batch,
                                   images=images_this_chunk,
                                   return_tensors='pt',
                                   padding=True,
                                   truncation=True)

                with torch.no_grad():
                    outputs = model(**inputs)
                    target_sizes = torch.tensor(
                        [img.size[::-1] for img in images_this_chunk])
                    results = processor.post_process_object_detection(
                        outputs,
                        threshold=self.conf_thr,
                        target_sizes=target_sizes)

                image_recalls = []
                for idx, result in enumerate(results):
                    scores = result['scores']
                    labels = result['labels']
                    boxes = result['boxes']

                    # sort by the confidence scores
                    # and only keep the first num_ners predictions
                    order_idx = scores.argsort(descending=True)
                    scores = scores[order_idx].tolist()[:num_ners]
                    labels = labels[order_idx].tolist()[:num_ners]
                    boxes = boxes[order_idx].tolist()[:num_ners]

                    image_area = target_sizes[idx].prod()
                    # hit maps a grounded phrase -> (box, score) of its
                    # highest-confidence surviving detection
                    hit = {}
                    for box, label, score in zip(boxes, labels, scores):
                        # this ner is already hit
                        if ners_this_chunk[label] in hit:
                            continue
                        # skip boxes nearly cover the whole image
                        xmin, ymin, xmax, ymax = box
                        box_area = (xmax - xmin) * (ymax - ymin)
                        if 1.0 * box_area / image_area > \
                                self.large_area_ratio_thr:
                            continue
                        # skip overlapped boxes with nms-like method
                        suppressed = False
                        for ner in hit:
                            if iou(box, hit[ner][0]) > self.iou_thr:
                                suppressed = True
                                break
                        if suppressed:
                            continue

                        # record the new hit box
                        hit[ners_this_chunk[label]] = (box, score)

                    # fraction of phrases that were grounded in this image
                    recall = 1.0 * len(hit) / num_ners
                    image_recalls.append(recall)

                # reduce the per-image recalls of this chunk to one value
                if self.reduce_mode == 'avg':
                    image_recall = sum(image_recalls) / len(image_recalls)
                elif self.reduce_mode == 'max':
                    image_recall = max(image_recalls)
                else:
                    image_recall = min(image_recalls)

                recalls.append(image_recall)
            offset += count
        sample[Fields.stats][StatsKeys.phrase_grounding_recall] = recalls

        return sample

    def process(self, sample):
        """Decide whether to keep the sample based on its computed recalls.

        :param sample: a sample dict whose stats were filled by
            ``compute_stats``.
        :return: True to keep the sample, False to drop it.
        """
        recalls = sample[Fields.stats][StatsKeys.phrase_grounding_recall]
        # samples without images are kept
        if len(recalls) <= 0:
            return True

        keep_bools = np.array([
            self.min_recall <= recall <= self.max_recall for recall in recalls
        ])

        # different strategies
        if self.any:
            return keep_bools.any()
        else:
            return keep_bools.all()