Skip to content

Commit ad445c9

Browse files
authored
Feature/phrase grounding recall filter (#139)
* Add the phrase_grounding_recall_filter op. * Update the new op; make llava_to_dj and dj_to_llava support an only_keep_caption mode. * Add a unit test for phrase grounding. * Remove HF model caches automatically for unit tests. * Download the required NLTK data when initializing the phrase_grounding_recall_filter. * Output the cleaning log only when cleaning actually happens. * Update the operator docs. * Fix some typos. * Remove HF models automatically after unit tests finish for CLIP and BLIP.
1 parent 0431f25 commit ad445c9

File tree

15 files changed

+957
-154
lines changed

15 files changed

+957
-154
lines changed

configs/config_all.yaml

Lines changed: 35 additions & 24 deletions
Large diffs are not rendered by default.

data_juicer/ops/filter/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
image_shape_filter, image_size_filter,
66
image_text_matching_filter, image_text_similarity_filter,
77
language_id_score_filter, maximum_line_length_filter,
8-
perplexity_filter, special_characters_filter,
9-
specified_field_filter, specified_numeric_field_filter,
10-
stopwords_filter, suffix_filter, text_action_filter,
11-
text_entity_dependency_filter, text_length_filter,
12-
token_num_filter, word_num_filter, word_repetition_filter)
8+
perplexity_filter, phrase_grounding_recall_filter,
9+
special_characters_filter, specified_field_filter,
10+
specified_numeric_field_filter, stopwords_filter, suffix_filter,
11+
text_action_filter, text_entity_dependency_filter,
12+
text_length_filter, token_num_filter, word_num_filter,
13+
word_repetition_filter)
1314

1415
# yapf: enable
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
from typing import List
2+
3+
import numpy as np
4+
from jsonargparse.typing import ClosedUnitInterval
5+
from loguru import logger
6+
from PIL import ImageOps
7+
8+
from data_juicer.utils.availability_utils import AvailabilityChecking
9+
from data_juicer.utils.constant import Fields, StatsKeys
10+
from data_juicer.utils.mm_utils import (SpecialTokens, iou, load_image,
11+
remove_special_tokens)
12+
from data_juicer.utils.model_utils import get_model, prepare_model
13+
14+
from ..base_op import OPERATORS, Filter
15+
from ..op_fusion import LOADED_IMAGES
16+
17+
OP_NAME = 'phrase_grounding_recall_filter'
18+
19+
with AvailabilityChecking(['torch', 'transformers', 'nltk'], OP_NAME):
20+
21+
import torch
22+
import transformers # noqa: F401
23+
24+
# avoid hanging when calling clip in multiprocessing
25+
torch.set_num_threads(1)
26+
27+
import nltk
28+
29+
30+
# NER algorithm adapted from GLIP starts
31+
# https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/engine/predictor_glip.py#L107-L127
32+
def find_noun_phrases(caption: str) -> List[str]:
    """Extract noun phrases from *caption* via NLTK POS tagging + chunking."""
    pos_tags = nltk.pos_tag(nltk.word_tokenize(caption.lower()))

    # NP: optional determiner, any number of adjectives, one or more nouns
    chunker = nltk.RegexpParser('NP: {<DT>?<JJ.*>*<NN.*>+}')
    parse_tree = chunker.parse(pos_tags)

    return [
        ' '.join(word for word, _tag in subtree.leaves())
        for subtree in parse_tree.subtrees()
        if subtree.label() == 'NP'
    ]
47+
48+
49+
def remove_punctuation(text: str) -> str:
    """Delete a fixed set of punctuation characters from *text*, then strip
    surrounding whitespace."""
    # single C-level pass instead of 25 chained str.replace calls
    chars_to_drop = '|:;@()[]{}^\'"’`?$%#!&*+,.'
    return text.translate(str.maketrans('', '', chars_to_drop)).strip()
57+
58+
59+
def run_ner(caption):
    """Extract the unique, punctuation-free noun phrases ("ners") of a caption.

    :param caption: input text to mine for noun phrases.
    :return: list of unique noun phrases, in first-appearance order.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
    # dict.fromkeys dedups while preserving first-appearance order; the
    # previous list(set(...)) produced a different order on every run due to
    # string hash randomization, making the downstream stats non-reproducible
    return list(dict.fromkeys(noun_phrases))
65+
66+
67+
# NER algorithm adapted from GLIP ends
68+
69+
70+
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class PhraseGroundingRecallFilter(Filter):
    """Filter to keep samples whose locating recalls of phrases extracted
    from text in the images are within a specified range."""

    def __init__(self,
                 hf_owlvit='google/owlvit-base-patch32',
                 min_recall: ClosedUnitInterval = 0.1,
                 max_recall: ClosedUnitInterval = 1.0,
                 horizontal_flip: bool = False,
                 vertical_flip: bool = False,
                 any_or_all: str = 'any',
                 reduce_mode: str = 'avg',
                 iou_thr: ClosedUnitInterval = 0.5,
                 large_area_ratio_thr: ClosedUnitInterval = 0.95,
                 conf_thr: ClosedUnitInterval = 0.0,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_owlvit: Owl-ViT model name on huggingface to locate the
            phrases extracted from the text.
        :param min_recall: The min phrase grounding recall to keep samples.
        :param max_recall: The max phrase grounding recall to keep samples.
        :param horizontal_flip: Flip image horizontally (left to right).
        :param vertical_flip: Flip image vertically (top to bottom).
        :param any_or_all: keep this sample with 'any' or 'all' strategy of
            all images. 'any': keep this sample if any images meet the
            condition. 'all': keep this sample only if all images meet the
            condition.
        :param reduce_mode: reduce mode when one text corresponds to
            multiple images in a chunk.
            'avg': Take the average of multiple values
            'max': Take the max of multiple values
            'min': Take the min of multiple values
        :param iou_thr: the IoU threshold for NMS-like post-process. If two
            predicted bboxes are overlap with an IoU larger than this
            threshold, the bbox with less confidence will be removed. Default:
            0.5.
        :param large_area_ratio_thr: the area ratio threshold for filtering out
            those large predicted bboxes. If the area of a predicted bbox
            accounts for more than this ratio threshold of the whole image
            area, this bbox will be removed. Default: 0.95.
        :param conf_thr: the confidence score threshold for removing
            low-confidence bboxes. If the confidence score of a predicted bbox
            is lower than the threshold, this bbox will be removed. Default: 0.
        :param args: extra args
        :param kwargs: extra args
        :raises ValueError: if reduce_mode or any_or_all is not a supported
            value.
        """
        super().__init__(*args, **kwargs)
        self.min_recall = min_recall
        self.max_recall = max_recall
        if reduce_mode not in ['avg', 'max', 'min']:
            raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. '
                             f'Can only be one of ["avg", "max", "min"].')
        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        self.any = (any_or_all == 'any')
        self.model_type = 'hf_owlvit'
        self.model_key = prepare_model(model_type=self.model_type,
                                       model_key=hf_owlvit)
        self.reduce_mode = reduce_mode
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip

        self.iou_thr = iou_thr
        self.large_area_ratio_thr = large_area_ratio_thr
        self.conf_thr = conf_thr

        # the NER helpers above require these nltk corpora at runtime, so
        # fetch them eagerly at construction time instead of failing mid-run
        requires_nltk_data = ['punkt', 'averaged_perceptron_tagger']
        logger.info(f'Downloading nltk data of {requires_nltk_data}...')
        for nltk_data_pkg in requires_nltk_data:
            nltk.download(nltk_data_pkg)

    def compute_stats(self, sample, context=False):
        """Compute the phrase grounding recall stat for each chunk of the
        sample's text, writing the result into sample[Fields.stats]."""
        # check if it's computed already
        if StatsKeys.phrase_grounding_recall in sample[Fields.stats]:
            return sample

        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            sample[Fields.stats][StatsKeys.phrase_grounding_recall] = np.array(
                [], dtype=np.float64)
            return sample

        # load images, deduplicated by key and optionally cached in context
        loaded_image_keys = sample[self.image_key]
        images = {}
        for loaded_image_key in loaded_image_keys:
            if context and loaded_image_key in sample[Fields.context]:
                # load from context
                images[loaded_image_key] = sample[
                    Fields.context][loaded_image_key]
            else:
                if loaded_image_key not in images:
                    # avoid load the same images
                    image = load_image(loaded_image_key)
                    images[loaded_image_key] = image
                    if context:
                        # store the image data into context
                        sample[Fields.context][loaded_image_key] = image

        text = sample[self.text_key]
        offset = 0  # index of the first image belonging to the current chunk
        recalls = []
        model, processor = get_model(self.model_key,
                                     model_type=self.model_type)

        for chunk in text.split(SpecialTokens.eoc):
            count = chunk.count(SpecialTokens.image)

            # no image or no text
            if count == 0 or len(chunk) == 0:
                continue
            else:
                text_this_chunk = remove_special_tokens(chunk)
                ners_this_chunk = run_ner(text_this_chunk)
                num_ners = len(ners_this_chunk)
                if num_ners <= 0:
                    # no ners found, just skip this chunk.
                    # BUGFIX: the offset must still be advanced past this
                    # chunk's images, otherwise all following chunks would be
                    # paired with the wrong image slices.
                    recalls.append(1.0)
                    offset += count
                    continue
                images_this_chunk = []
                for image_key in loaded_image_keys[offset:offset + count]:
                    image = images[image_key]
                    if self.horizontal_flip:
                        image = ImageOps.mirror(image)
                    if self.vertical_flip:
                        image = ImageOps.flip(image)
                    images_this_chunk.append(image)

                # ground every extracted phrase in every image of this chunk
                ners_batch = [ners_this_chunk] * len(images_this_chunk)
                inputs = processor(text=ners_batch,
                                   images=images_this_chunk,
                                   return_tensors='pt',
                                   padding=True,
                                   truncation=True)

                with torch.no_grad():
                    outputs = model(**inputs)
                    target_sizes = torch.tensor(
                        [img.size[::-1] for img in images_this_chunk])
                    results = processor.post_process_object_detection(
                        outputs,
                        threshold=self.conf_thr,
                        target_sizes=target_sizes)

                image_recalls = []
                for idx, result in enumerate(results):
                    scores = result['scores']
                    labels = result['labels']
                    boxes = result['boxes']

                    # sort by the confidence scores
                    # and only keep the first num_ners predictions
                    order_idx = scores.argsort(descending=True)
                    scores = scores[order_idx].tolist()[:num_ners]
                    labels = labels[order_idx].tolist()[:num_ners]
                    boxes = boxes[order_idx].tolist()[:num_ners]

                    image_area = target_sizes[idx].prod()
                    hit = {}
                    for box, label, score in zip(boxes, labels, scores):
                        # this ner is already hit
                        if ners_this_chunk[label] in hit:
                            continue
                        # skip boxes nearly cover the whole image
                        xmin, ymin, xmax, ymax = box
                        box_area = (xmax - xmin) * (ymax - ymin)
                        if 1.0 * box_area / image_area > \
                                self.large_area_ratio_thr:
                            continue
                        # skip overlapped boxes with nms-like method
                        suppressed = False
                        for ner in hit:
                            if iou(box, hit[ner][0]) > self.iou_thr:
                                suppressed = True
                                break
                        if suppressed:
                            continue

                        # record the new hit box
                        hit[ners_this_chunk[label]] = (box, score)

                    # recall = fraction of phrases grounded in this image
                    recall = 1.0 * len(hit) / num_ners
                    image_recalls.append(recall)

                # reduce per-image recalls to one value for this chunk
                if self.reduce_mode == 'avg':
                    image_recall = sum(image_recalls) / len(image_recalls)
                elif self.reduce_mode == 'max':
                    image_recall = max(image_recalls)
                else:
                    image_recall = min(image_recalls)

                recalls.append(image_recall)
                offset += count
        sample[Fields.stats][StatsKeys.phrase_grounding_recall] = recalls

        return sample

    def process(self, sample):
        """Decide whether to keep the sample based on the computed recalls."""
        recalls = sample[Fields.stats][StatsKeys.phrase_grounding_recall]
        # text-only samples (no recalls) are always kept
        if len(recalls) <= 0:
            return True

        keep_bools = np.array([
            self.min_recall <= recall <= self.max_recall for recall in recalls
        ])

        # different strategies
        if self.any:
            return keep_bools.any()
        else:
            return keep_bools.all()

data_juicer/utils/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class StatsKeysConstant(object):
129129
# multimodal
130130
image_text_similarity = 'image_text_similarity'
131131
image_text_matching_score = 'image_text_matching_score'
132+
phrase_grounding_recall = 'phrase_grounding_recall'
132133

133134

134135
class StatsKeys(object, metaclass=StatsKeysMeta):

data_juicer/utils/mm_utils.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,30 +26,21 @@ def get_special_tokens():
2626

2727
def remove_special_tokens(text):
    """Remove every multimodal special token from *text*, stripping the
    surrounding whitespace after each removal."""
    for token in get_special_tokens().values():
        text = text.replace(token, '').strip()
    return text
3131

3232

33+
# Image
3334
def load_images(paths):
    """Load every image path in *paths* with load_image."""
    return list(map(load_image, paths))
3536

3637

37-
def load_audios(paths):
38-
return [load_audio(path) for path in paths]
39-
40-
4138
def load_image(path):
    """Decode the image at *path* via the datasets Image feature and return
    the resulting PIL image."""
    feature = Image()
    return feature.decode_example(feature.encode_example(path))
4542

4643

47-
def load_audio(path, sampling_rate=None):
48-
aud_feature = Audio(sampling_rate)
49-
aud = aud_feature.decode_example(aud_feature.encode_example(path))
50-
return (aud['array'], aud['sampling_rate'])
51-
52-
5344
def pil_to_opencv(pil_image):
5445
if pil_image.mode != 'RGB':
5546
pil_image = pil_image.convert('RGB')
@@ -64,6 +55,32 @@ def get_image_size(path, ):
6455
return os.path.getsize(path)
6556

6657

58+
def iou(box1, box2):
    """Compute the Intersection-over-Union of two axis-aligned boxes.

    :param box1: first box as (xmin, ymin, xmax, ymax).
    :param box2: second box as (xmin, ymin, xmax, ymax).
    :return: IoU in [0, 1]; 0.0 when the union area is not positive.
    """
    x1_min, y1_min, x1_max, y1_max = box1
    x2_min, y2_min, x2_max, y2_max = box2
    area1 = (x1_max - x1_min) * (y1_max - y1_min)
    area2 = (x2_max - x2_min) * (y2_max - y2_min)
    ix_min = max(x1_min, x2_min)
    ix_max = min(x1_max, x2_max)
    iy_min = max(y1_min, y2_min)
    iy_max = min(y1_max, y2_max)
    # BUGFIX: clamp each side length separately. Clamping only the product,
    # as before, reports a positive intersection for boxes disjoint on BOTH
    # axes (two negative side lengths multiply to a positive area).
    intersection = max(0, ix_max - ix_min) * max(0, iy_max - iy_min)
    union = area1 + area2 - intersection
    # degenerate (zero-area) boxes: avoid division by zero
    if union <= 0:
        return 0.0
    return 1.0 * intersection / union
70+
71+
72+
# Audio
73+
def load_audios(paths):
    """Load every audio path in *paths* with load_audio."""
    return list(map(load_audio, paths))
76+
77+
def load_audio(path, sampling_rate=None):
    """Decode the audio at *path* via the datasets Audio feature.

    Returns a (samples, sampling_rate) pair.
    """
    feature = Audio(sampling_rate)
    decoded = feature.decode_example(feature.encode_example(path))
    return decoded['array'], decoded['sampling_rate']
81+
82+
83+
# Others
6784
def size_to_bytes(size):
6885
alphabets_list = [char for char in size if char.isalpha()]
6986
numbers_list = [char for char in size if char.isdigit()]

0 commit comments

Comments (0)