|
import hashlib
import os
import shutil
import subprocess
import warnings
import zipfile
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import GroupKFold

from pytorch_toolbelt.utils import fs
| 16 | + |
| 17 | + |
| 18 | +__all__ = ["InriaAerialImageDataset"] |
| 19 | + |
| 20 | + |
class InriaAerialImageDataset:
    """
    Provider for the Inria Aerial Image Labeling dataset (binary building segmentation).

    Command-line usage (downloads and prints the split dataframes):

        python -m pytorch_toolbelt.datasets.providers.inria_aerial inria_dataset
    """

    TASK = "binary_segmentation"
    METRIC = ""
    ORIGIN = "https://project.inria.fr/aerialimagelabeling"
    # Locations that ship with ground-truth masks (train split).
    TRAIN_LOCATIONS = ["austin", "chicago", "kitsap", "tyrol-w", "vienna"]
    # Locations without ground truth (test split).
    TEST_LOCATIONS = ["bellingham", "bloomington", "innsbruck", "sfo", "tyrol-e"]

    # Multi-part 7z archive URLs mapped to their expected SHA-256 digests.
    urls = {
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.001": "17a7d95c78e484328fd8fe5d5afa2b505e04b8db8fceb617819f3c935d1f39ec",
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.002": "b505cb223964b157823e88fbd5b0bd041afcbf39427af3ca1ce981ff9f61aff4",
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.003": "752916faa67be6fc6693f8559531598fa2798dc01b7d197263e911718038252e",
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.004": "b3893e78f92572455fc2c811af560a558d2a57f9b92eff62fa41399b607a6f44",
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.005": "a92eb20fdc9911c5ffe3afc514490b8f1e1e5b22301a6fc55d3b4e1624d8033f",
    }

    @classmethod
    def download_and_extract(cls, data_dir: Union[str, Path]) -> bool:
        """
        Download all archive parts, verify their checksums, concatenate and extract them.

        :param data_dir: Directory the dataset is downloaded and extracted into.
        :return: True on success; False when the optional `py7zr` package is missing.
        """
        try:
            import py7zr  # documented entry point; SevenZipFile lives at package top level
        except ImportError:
            print("You need to install py7zr to extract 7z-archive: `pip install py7zr`.")
            return False

        filenames = []
        for file_url, file_hash in cls.urls.items():
            file_path = os.path.join(data_dir, os.path.basename(file_url))
            # (Re-)download a part when it is missing or its checksum mismatches.
            if not os.path.isfile(file_path) or cls.sha256digest(file_path) != file_hash:
                os.makedirs(data_dir, exist_ok=True)
                torch.hub.download_url_to_file(file_url, file_path)

            filenames.append(file_path)

        main_archive = os.path.join(data_dir, "aerialimagelabeling.7z")
        # "wb" (not append): a stale archive left by a previous failed run must not
        # corrupt the concatenation. Stream each part to keep memory usage bounded.
        with open(main_archive, "wb") as outfile:
            for fname in filenames:
                with open(fname, "rb") as infile:
                    shutil.copyfileobj(infile, outfile)

        with py7zr.SevenZipFile(main_archive, "r") as archive:
            archive.extractall(data_dir)
        os.unlink(main_archive)

        # The 7z archive contains a nested zip with the actual dataset.
        zip_archive = os.path.join(data_dir, "NEW2-AerialImageDataset.zip")
        with zipfile.ZipFile(zip_archive, "r") as zip_ref:
            zip_ref.extractall(data_dir)
        os.unlink(zip_archive)
        return True

    @classmethod
    def init_from_folder(cls, data_dir: Union[str, Path], download: bool = False):
        """
        Build a dataset instance rooted at `data_dir/AerialImageDataset`.

        :param data_dir: Directory that contains (or will contain) the dataset.
        :param download: When True, download and extract the dataset first.
        :raises RuntimeError: If the requested download/extraction fails.
        """
        data_dir = os.path.expanduser(data_dir)

        if download:
            if not cls.download_and_extract(data_dir):
                raise RuntimeError("Download and extract failed")

        return cls(os.path.join(data_dir, "AerialImageDataset"))

    @classmethod
    def sha256digest(cls, filename: str) -> str:
        """Return the hex SHA-256 digest of a file, reading it in small chunks."""
        blocksize = 4096
        sha = hashlib.sha256()
        with open(filename, "rb") as f:
            # Sentinel iter: read until an empty bytes object signals EOF.
            for chunk in iter(lambda: f.read(blocksize), b""):
                sha.update(chunk)
        return sha.hexdigest()

    @classmethod
    def read_tiff(
        cls, image_fname: str, crop_coords: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None
    ) -> np.ndarray:
        """
        Read a (possibly cropped) TIFF image via rasterio.

        :param image_fname: Path to the TIFF file.
        :param crop_coords: Optional ((row_start, row_stop), (col_start, col_stop)) window.
        :return: HWC numpy array; single-channel images are squeezed to HW.
        :raises FileNotFoundError: If `image_fname` does not exist.
        """
        import rasterio
        from rasterio.windows import Window

        window = None
        if crop_coords is not None:
            (row_start, row_stop), (col_start, col_stop) = crop_coords
            window = Window.from_slices((row_start, row_stop), (col_start, col_stop))

        if not os.path.isfile(image_fname):
            raise FileNotFoundError(image_fname)

        with warnings.catch_warnings():
            # Inria tiles trigger NotGeoreferencedWarning; georeferencing is irrelevant here.
            warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

            with rasterio.open(image_fname) as f:
                image = f.read(window=window)
                image = np.moveaxis(image, 0, -1)  # CHW->HWC
                if image.shape[2] == 1:
                    image = np.squeeze(image, axis=2)
                return image

    @classmethod
    def compress_prediction_mask(cls, predicted_mask_fname, compressed_mask_fname):
        """Compress a binary prediction mask with gdal_translate (CCITTFAX4, 1-bit)."""
        # Argument-list form with shell=False is robust against filenames with
        # spaces and avoids shell-injection via crafted paths.
        subprocess.call(
            [
                "gdal_translate",
                "--config", "GDAL_PAM_ENABLED", "NO",
                "-co", "COMPRESS=CCITTFAX4",
                "-co", "NBITS=1",
                predicted_mask_fname,
                compressed_mask_fname,
            ]
        )

    def __init__(self, root_dir: str):
        """
        :param root_dir: Path to the extracted `AerialImageDataset` folder.
        :raises FileNotFoundError: If the train or test sub-directory is missing.
        :raises RuntimeError: If the train split does not contain exactly 180 image/mask pairs.
        """
        self.root_dir = root_dir
        self.train_dir = os.path.join(root_dir, "train")
        self.test_dir = os.path.join(root_dir, "test")

        if not os.path.isdir(self.train_dir):
            raise FileNotFoundError(f"Train directory {self.train_dir} does not exist")
        if not os.path.isdir(self.test_dir):
            # Bug fix: the message previously reported self.train_dir here.
            raise FileNotFoundError(f"Test directory {self.test_dir} does not exist")

        self.train_images = fs.find_images_in_dir(os.path.join(self.train_dir, "images"))
        self.train_masks = fs.find_images_in_dir(os.path.join(self.train_dir, "gt"))

        if len(self.train_images) != 180 or len(self.train_masks) != 180:
            raise RuntimeError("Number of train images and ground-truth masks must be 180")

    def get_test_df(self) -> pd.DataFrame:
        """Return a dataframe of test images with fixed tile size and location name."""
        test_images = fs.find_images_in_dir(os.path.join(self.test_dir, "images"))
        df = pd.DataFrame.from_dict({"images": test_images})
        # All Inria tiles are 5000x5000 pixels.
        df["rows"] = 5000
        df["cols"] = 5000
        # Filenames look like "bellingham7"; stripping trailing digits yields the location.
        df["location"] = df["images"].apply(lambda x: fs.id_from_fname(x).rstrip("0123456789"))
        return df

    def get_train_val_split_train_df(self) -> pd.DataFrame:
        """
        Return the train dataframe with a "split" column ("train"/"valid").

        For validation, we remove the first five images of every location
        (e.g., austin{1-5}.tif, chicago{1-5}.tif) from the training set.
        That is the validation strategy suggested by the competition host.
        """
        valid_locations = []
        for loc in self.TRAIN_LOCATIONS:
            for i in range(1, 6):
                valid_locations.append(f"{loc}{i}")

        df = pd.DataFrame.from_dict({"images": self.train_images, "masks": self.train_masks})
        df["location_with_index"] = df["images"].apply(lambda x: fs.id_from_fname(x))
        df["location"] = df["location_with_index"].apply(lambda x: x.rstrip("0123456789"))
        df["split"] = df["location_with_index"].apply(lambda l: "valid" if l in valid_locations else "train")
        df["rows"] = 5000
        df["cols"] = 5000
        return df

    def get_kfold_split_train_df(self, num_folds: int = 5) -> pd.DataFrame:
        """
        Return the train dataframe with a "fold" column assigned by GroupKFold.

        Grouping by location ensures tiles of one city never span train and
        validation folds simultaneously.

        :param num_folds: Number of folds (default 5, one per train location).
        """
        df = pd.DataFrame.from_dict({"images": self.train_images, "masks": self.train_masks})
        df["location_with_index"] = df["images"].apply(lambda x: fs.id_from_fname(x))
        df["location"] = df["location_with_index"].apply(lambda x: x.rstrip("0123456789"))
        df["rows"] = 5000
        df["cols"] = 5000
        df["fold"] = -1
        kfold = GroupKFold(n_splits=num_folds)
        for fold, (train_index, test_index) in enumerate(kfold.split(df, df, groups=df["location"])):
            df.loc[test_index, "fold"] = fold
        return df
| 182 | + |
| 183 | + |
def download_and_extract(data_dir):
    """Download the Inria dataset into *data_dir* and print the three split dataframes."""
    dataset = InriaAerialImageDataset.init_from_folder(data_dir, download=True)
    for split_df in (
        dataset.get_test_df(),
        dataset.get_train_val_split_train_df(),
        dataset.get_kfold_split_train_df(),
    ):
        print(split_df)
| 189 | + |
| 190 | + |
if __name__ == "__main__":
    # Fire turns `download_and_extract` into a CLI, so the first positional
    # argument becomes `data_dir`:
    #   python inria_aerial.py <data_dir>
    from fire import Fire

    Fire(download_and_extract)
0 commit comments