|
| 1 | +import codecs |
| 2 | +import json |
| 3 | +import os |
| 4 | +from subprocess import call |
| 5 | + |
| 6 | +import requests |
| 7 | +from PIL import Image |
| 8 | +from torchvision.datasets import VisionDataset |
| 9 | + |
| 10 | + |
# Raw-file bases of the official XTD10 GitHub repository. Most languages'
# captions live under XTD10/; German and French under MIC/; Japanese under
# STAIR/ (see the dispatch in create_annotation_file).
GITHUB_DATA_PATH = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/XTD10"
GITHUB_MIC_DATA_PATH = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/MIC"
GITHUB_STAIR_DATA_PATH = "https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/STAIR"
# Language codes accepted by create_annotation_file.
SUPPORTED_LANGUAGES = [
    "de",
    "en",
    "es",
    "fr",
    "it",
    "jp",
    "ko",
    "pl",
    "ru",
    "tr",
    "zh",
]

# Repo file listing the test image filenames, one per line.
IMAGE_INDEX_FILENAME = "test_image_names.txt"

# Per-language caption file name in the repo, and local output JSON name.
CAPTIONS_FILENAME_TEMPLATE = "test_1kcaptions_{}.txt"
OUTPUT_FILENAME_TEMPLATE = "xtd10-{}.json"

# Tarball containing the test images; fetched by _download_images.
IMAGES_DOWNLOAD_URL = "https://nllb-data.com/test/xtd10/images.tar.gz"
| 34 | + |
| 35 | + |
class XTD10(VisionDataset):
    """Image-caption retrieval dataset for the XTD10 cross-lingual benchmark.

    Reads an annotation JSON produced by ``create_annotation_file``, which
    contains parallel ``image_paths`` and ``annotations`` lists.

    Args:
        root: Dataset root directory (forwarded to ``VisionDataset``).
        ann_file: Path to the annotation JSON; ``~`` is expanded.
        transform: Optional transform applied to each decoded PIL image.
        target_transform: Optional transform applied to the caption list.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)
        # Fix: open the expanded path (self.ann_file), not the raw argument;
        # previously the expanduser() result was never used, so paths like
        # "~/xtd10-en.json" failed to open.
        with codecs.open(self.ann_file, "r", encoding="utf-8") as fp:
            data = json.load(fp)
        # Pair each image path with its caption; zip already yields tuples,
        # so no comprehension is needed.
        self.data = list(zip(data["image_paths"], data["annotations"]))

    def __getitem__(self, index):
        img_path, captions = self.data[index]

        # Image: force RGB so grayscale/palette files also yield 3 channels.
        img = Image.open(img_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions: wrapped in a single-element list to mirror the interface
        # of multi-caption datasets.
        target = [captions]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)
| 66 | + |
| 67 | + |
def _get_lines(url):
    """Fetch *url* and return the response body split into lines.

    Raises:
        requests.HTTPError: if the server answers with an error status.
            Previously a 404/500 error page's text was silently returned
            and treated as real index/caption data.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return response.text.splitlines()
| 71 | + |
| 72 | + |
def _download_images(out_path):
    """Download the XTD10 image tarball and extract it into *out_path*."""
    os.makedirs(out_path, exist_ok=True)
    print("Downloading images")
    archive = "images.tar.gz"
    # Pass argument lists with shell=False (the default): avoids shell
    # quoting/injection issues if out_path ever contains spaces or
    # shell metacharacters.
    call(["wget", IMAGES_DOWNLOAD_URL, "-O", archive])
    call(["tar", "-xzf", archive, "-C", out_path])
    # Remove the archive directly instead of shelling out to `rm`.
    os.remove(archive)
| 79 | + |
| 80 | + |
| 81 | +def create_annotation_file(root, lang_code): |
| 82 | + if lang_code not in SUPPORTED_LANGUAGES: |
| 83 | + raise ValueError( |
| 84 | + f"Language code {lang_code} not supported. Supported languages are {SUPPORTED_LANGUAGES}" |
| 85 | + ) |
| 86 | + data_dir = os.path.join(root, "xtd10") |
| 87 | + if not os.path.exists(data_dir): |
| 88 | + _download_images(data_dir) |
| 89 | + images_dir = os.path.join(data_dir, "images") |
| 90 | + print("Downloading xtd10 index file") |
| 91 | + download_path = os.path.join(GITHUB_DATA_PATH, IMAGE_INDEX_FILENAME) |
| 92 | + target_images = _get_lines(download_path) |
| 93 | + |
| 94 | + print("Downloading xtd10 captions:", lang_code) |
| 95 | + captions_path = GITHUB_DATA_PATH |
| 96 | + match lang_code: |
| 97 | + case "de" | "fr": |
| 98 | + captions_path = GITHUB_MIC_DATA_PATH |
| 99 | + case "jp": |
| 100 | + captions_path = GITHUB_STAIR_DATA_PATH |
| 101 | + case _: |
| 102 | + captions_path = GITHUB_DATA_PATH |
| 103 | + download_path = os.path.join( |
| 104 | + captions_path, CAPTIONS_FILENAME_TEMPLATE.format(lang_code) |
| 105 | + ) |
| 106 | + target_captions = _get_lines(download_path) |
| 107 | + |
| 108 | + number_of_missing_images = 0 |
| 109 | + valid_images, valid_annotations, valid_indicies = [], [], [] |
| 110 | + for i, (img, txt) in enumerate(zip(target_images, target_captions)): |
| 111 | + image_path = os.path.join(images_dir, img) |
| 112 | + if not os.path.exists(image_path): |
| 113 | + print("Missing image file", img) |
| 114 | + number_of_missing_images += 1 |
| 115 | + continue |
| 116 | + |
| 117 | + valid_images.append(image_path) |
| 118 | + valid_annotations.append(txt) |
| 119 | + valid_indicies.append(i) |
| 120 | + |
| 121 | + if number_of_missing_images > 0: |
| 122 | + print(f"*** WARNING *** missing {number_of_missing_images} files.") |
| 123 | + |
| 124 | + with codecs.open( |
| 125 | + os.path.join(root, OUTPUT_FILENAME_TEMPLATE.format(lang_code)), |
| 126 | + "w", |
| 127 | + encoding="utf-8", |
| 128 | + ) as fp: |
| 129 | + json.dump( |
| 130 | + { |
| 131 | + "image_paths": valid_images, |
| 132 | + "annotations": valid_annotations, |
| 133 | + "indicies": valid_indicies, |
| 134 | + }, |
| 135 | + fp, |
| 136 | + ensure_ascii=False, |
| 137 | + ) |
0 commit comments