Commit 98897e4

feat: xtd-10 dataset (#134)
* add xtd10 dataset
* fixes
* fix dataset dir
* add de, fr and jp
* fix
* add log
* fix
* update readme
1 parent 0cfe9c2 commit 98897e4

File tree: 4 files changed (+153, −3 lines): .gitignore, README.md, clip_benchmark/datasets/builder.py, clip_benchmark/datasets/xtd10.py

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -53,6 +53,5 @@ features
 root
 cifar-100-*
 probe_benchmark/data
-datasets
 downloads
 .env

README.md

Lines changed: 4 additions & 0 deletions
@@ -214,6 +214,10 @@ For Flickr-8k (zero-shot retrieval)
 
 - `clip_benchmark eval --model xlm-roberta-base-ViT-B-32 --pretrained laion5b_s13b_b90k --dataset=flickr8k --output=result.json --batch_size=64 --language=<LANG>`, where `<LANG>` can be among `en` (english), `zh` (chinese).
 
+For XTD-10 (zero-shot retrieval)
+
+- `clip_benchmark eval --model xlm-roberta-base-ViT-B-32 --pretrained laion5b_s13b_b90k --dataset=xtd10 --output=result.json --batch_size=64 --language=<LANG>`, where `<LANG>` can be among `es` (spanish), `it` (italian), `jp` (japanese), `ko` (korean), `pl` (polish), `ru` (russian), `tr` (turkish), `zh` (chinese), `en` (english), `fr` (french), `de` (german).
+
 For [Crossmodal-3600](https://google.github.io/crossmodal-3600/) (zero-shot retrieval)
 
 - `clip_benchmark eval --model xlm-roberta-base-ViT-B-32 --pretrained laion5b_s13b_b90k --dataset=crossmodal3600 --output=result.json --batch_size=64 --language=<LANG>`, see supported languages [here](https://github.com/LAION-AI/CLIP_benchmark/blob/main/clip_benchmark/datasets/crossmodal3600.py#L9).
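
As a concrete example of the new command, an evaluation run on the German split might look like this (a usage sketch: the output filename is arbitrary, and the model/pretrained pair is simply reused from the README line above):

`clip_benchmark eval --model xlm-roberta-base-ViT-B-32 --pretrained laion5b_s13b_b90k --dataset=xtd10 --output=xtd10_de.json --batch_size=64 --language=de`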

clip_benchmark/datasets/builder.py

Lines changed: 12 additions & 2 deletions
@@ -283,6 +283,16 @@ def download_imagenet(r):
             crossmodal3600.create_annotation_file(root, language)
 
         ds = crossmodal3600.Crossmodal3600(root=root, ann_file=annotation_file, transform=transform, **kwargs)
+    elif dataset_name == 'xtd10':
+        from clip_benchmark.datasets import xtd10
+        if language not in xtd10.SUPPORTED_LANGUAGES:
+            raise ValueError("Unsupported language for xtd10:", language)
+
+        annotation_file = os.path.join(root, xtd10.OUTPUT_FILENAME_TEMPLATE.format(language))
+        if not os.path.exists(annotation_file):
+            xtd10.create_annotation_file(root, language)
+
+        ds = xtd10.XTD10(root=root, ann_file=annotation_file, transform=transform, **kwargs)
     elif dataset_name == 'xtd200':
         from clip_benchmark.datasets import xtd200
         if language not in xtd200.SUPPORTED_LANGUAGES:

@@ -523,15 +533,15 @@ def __len__(self):
         return 1
 
 def get_dataset_default_task(dataset):
-    if dataset in ("flickr30k", "flickr8k", "mscoco_captions", "multilingual_mscoco_captions", "flickr30k-200", "crossmodal3600", "xtd200"):
+    if dataset in ("flickr30k", "flickr8k", "mscoco_captions", "multilingual_mscoco_captions", "flickr30k-200", "crossmodal3600", "xtd10", "xtd200"):
         return "zeroshot_retrieval"
     elif dataset.startswith("sugar_crepe") or dataset == "winoground":
         return "image_caption_selection"
     else:
         return "zeroshot_classification"
 
 def get_dataset_collate_fn(dataset_name):
-    if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k", "flickr30k-200", "crossmodal3600", "xtd200", "winoground") or dataset_name.startswith("sugar_crepe"):
+    if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k", "flickr30k-200", "crossmodal3600", "xtd10", "xtd200", "winoground") or dataset_name.startswith("sugar_crepe"):
         return image_captions_collate_fn
     else:
         return default_collate
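
The new branch mirrors the existing crossmodal3600/xtd200 handling: validate the language, lazily build the annotation file, then construct the dataset. A minimal sketch of reaching it through the builder follows; the exact build_dataset signature is not shown in this diff, so the keyword names are assumptions inferred from the variables visible in the hunk (dataset_name, root, transform, language):

from torchvision import transforms
from clip_benchmark.datasets.builder import build_dataset

# Hypothetical call; verify the builder's actual signature before relying on it.
preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
ds = build_dataset(
    dataset_name="xtd10",
    root="data",          # xtd10-de.json is created here on first use
    transform=preprocess,
    language="de",        # validated against xtd10.SUPPORTED_LANGUAGES
)
print(len(ds))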

clip_benchmark/datasets/xtd10.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
import codecs
import json
import os
from subprocess import call

import requests
from PIL import Image
from torchvision.datasets import VisionDataset


GITHUB_DATA_PATH = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/XTD10"
GITHUB_MIC_DATA_PATH = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/MIC"
GITHUB_STAIR_DATA_PATH = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/STAIR"
SUPPORTED_LANGUAGES = [
    "de",
    "en",
    "es",
    "fr",
    "it",
    "jp",
    "ko",
    "pl",
    "ru",
    "tr",
    "zh",
]

IMAGE_INDEX_FILENAME = "test_image_names.txt"

CAPTIONS_FILENAME_TEMPLATE = "test_1kcaptions_{}.txt"
OUTPUT_FILENAME_TEMPLATE = "xtd10-{}.json"

IMAGES_DOWNLOAD_URL = "https://nllb-data.com/test/xtd10/images.tar.gz"


class XTD10(VisionDataset):
    def __init__(self, root, ann_file, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)
        with codecs.open(ann_file, "r", encoding="utf-8") as fp:
            data = json.load(fp)
        self.data = [
            (img_path, txt)
            for img_path, txt in zip(data["image_paths"], data["annotations"])
        ]

    def __getitem__(self, index):
        img, captions = self.data[index]

        # Image
        img = Image.open(img).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = [
            captions,
        ]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)


def _get_lines(url):
    response = requests.get(url, timeout=30)
    return response.text.splitlines()


def _download_images(out_path):
    os.makedirs(out_path, exist_ok=True)
    print("Downloading images")
    call(f"wget {IMAGES_DOWNLOAD_URL} -O images.tar.gz", shell=True)
    call(f"tar -xzf images.tar.gz -C {out_path}", shell=True)
    call("rm images.tar.gz", shell=True)


def create_annotation_file(root, lang_code):
    if lang_code not in SUPPORTED_LANGUAGES:
        raise ValueError(
            f"Language code {lang_code} not supported. Supported languages are {SUPPORTED_LANGUAGES}"
        )
    data_dir = os.path.join(root, "xtd10")
    if not os.path.exists(data_dir):
        _download_images(data_dir)
    images_dir = os.path.join(data_dir, "images")
    print("Downloading xtd10 index file")
    download_path = os.path.join(GITHUB_DATA_PATH, IMAGE_INDEX_FILENAME)
    target_images = _get_lines(download_path)

    print("Downloading xtd10 captions:", lang_code)
    captions_path = GITHUB_DATA_PATH
    match lang_code:
        case "de" | "fr":
            captions_path = GITHUB_MIC_DATA_PATH
        case "jp":
            captions_path = GITHUB_STAIR_DATA_PATH
        case _:
            captions_path = GITHUB_DATA_PATH
    download_path = os.path.join(
        captions_path, CAPTIONS_FILENAME_TEMPLATE.format(lang_code)
    )
    target_captions = _get_lines(download_path)

    number_of_missing_images = 0
    valid_images, valid_annotations, valid_indicies = [], [], []
    for i, (img, txt) in enumerate(zip(target_images, target_captions)):
        image_path = os.path.join(images_dir, img)
        if not os.path.exists(image_path):
            print("Missing image file", img)
            number_of_missing_images += 1
            continue

        valid_images.append(image_path)
        valid_annotations.append(txt)
        valid_indicies.append(i)

    if number_of_missing_images > 0:
        print(f"*** WARNING *** missing {number_of_missing_images} files.")

    with codecs.open(
        os.path.join(root, OUTPUT_FILENAME_TEMPLATE.format(lang_code)),
        "w",
        encoding="utf-8",
    ) as fp:
        json.dump(
            {
                "image_paths": valid_images,
                "annotations": valid_annotations,
                "indicies": valid_indicies,
            },
            fp,
            ensure_ascii=False,
        )
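
A small end-to-end sketch of using the module directly, assuming "data" is a writable scratch directory, a basic torchvision transform stands in for a real CLIP preprocessing pipeline, and wget/tar are on PATH (as _download_images shells out to both):

import os
from torchvision import transforms
from clip_benchmark.datasets import xtd10

root = "data"
lang = "it"
# First use downloads images.tar.gz, fetches the Italian captions from the
# XTD10 GitHub repo, and writes data/xtd10-it.json.
xtd10.create_annotation_file(root, lang)

ds = xtd10.XTD10(
    root=root,
    ann_file=os.path.join(root, xtd10.OUTPUT_FILENAME_TEMPLATE.format(lang)),
    transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
)
img, target = ds[0]   # target is a one-element list holding the caption
print(len(ds), target)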
