Skip to content

Commit 896ab13

Browse files
authored
Merge pull request #79 from BloodAxe/feature/painless_sota
Feature/painless sota
2 parents 0d8c697 + 6abace9 commit 896ab13

File tree

21 files changed

+382
-108
lines changed

21 files changed

+382
-108
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
- name: Install dependencies
3434
run: pip install .[${{ matrix.pytorch-toolbelt-version }}]
3535
- name: Install linters
36-
run: pip install flake8==3.8.4 flake8-docstrings==1.5.0
36+
run: pip install flake8==5
3737
- name: Run PyTest
3838
run: pytest
3939
- name: Run Flake8
@@ -48,7 +48,7 @@ jobs:
4848
runs-on: ubuntu-latest
4949
strategy:
5050
matrix:
51-
python-version: [3.8]
51+
python-version: [3.8, 3.9]
5252
steps:
5353
- name: Checkout
5454
uses: actions/checkout@v2
@@ -59,6 +59,6 @@ jobs:
5959
- name: Update pip
6060
run: python -m pip install --upgrade pip
6161
- name: Install Black
62-
run: pip install black==22.3.0
62+
run: pip install black==22.10.0
6363
- name: Run Black
6464
run: black --config=black.toml --check .

.github/workflows/upload_to_pypi.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ jobs:
88
upload:
99
runs-on: ubuntu-latest
1010
steps:
11-
- uses: actions/checkout@v2
11+
- uses: actions/checkout
1212
- name: Set up Python
13-
uses: actions/setup-python@v2
13+
uses: actions/setup-python
1414
with:
1515
python-version: '3.8'
1616
- name: Install dependencies

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ var/
2020
.pytest_cache/
2121
/tests/tta_eval.csv
2222
/tests/tmp.onnx
23-
/tests/test_plot_confusion_matrix.png
23+
/tests/test_plot_confusion_matrix.png

pytorch_toolbelt/datasets/common.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,30 @@
66
"INPUT_INDEX_KEY",
77
"OUTPUT_EMBEDDINGS_KEY",
88
"OUTPUT_LOGITS_KEY",
9-
"OUTPUT_MASK_16_KEY",
10-
"OUTPUT_MASK_2_KEY",
11-
"OUTPUT_MASK_32_KEY",
12-
"OUTPUT_MASK_4_KEY",
13-
"OUTPUT_MASK_64_KEY",
14-
"OUTPUT_MASK_8_KEY",
159
"OUTPUT_MASK_KEY",
10+
"OUTPUT_MASK_KEY_STRIDE_16",
11+
"OUTPUT_MASK_KEY_STRIDE_2",
12+
"OUTPUT_MASK_KEY_STRIDE_32",
13+
"OUTPUT_MASK_KEY_STRIDE_4",
14+
"OUTPUT_MASK_KEY_STRIDE_64",
15+
"OUTPUT_MASK_KEY_STRIDE_8",
1616
"TARGET_CLASS_KEY",
1717
"TARGET_LABELS_KEY",
18-
"TARGET_MASK_16_KEY",
19-
"TARGET_MASK_2_KEY",
20-
"TARGET_MASK_32_KEY",
21-
"TARGET_MASK_4_KEY",
22-
"TARGET_MASK_64_KEY",
23-
"TARGET_MASK_8_KEY",
2418
"TARGET_MASK_KEY",
19+
"TARGET_MASK_KEY_STRIDE_16",
20+
"TARGET_MASK_KEY_STRIDE_2",
21+
"TARGET_MASK_KEY_STRIDE_32",
22+
"TARGET_MASK_KEY_STRIDE_4",
23+
"TARGET_MASK_KEY_STRIDE_64",
24+
"TARGET_MASK_KEY_STRIDE_8",
2525
"TARGET_MASK_WEIGHT_KEY",
2626
"name_for_stride",
2727
"read_image_rgb",
2828
]
2929

3030

3131
def name_for_stride(name: str, stride: int) -> str:
    """Build the dict key for *name* at a given stride, e.g. ("MASK", 4) -> "MASK_STRIDE_4"."""
    return "{}_STRIDE_{}".format(name, stride)
3333

3434

3535
INPUT_INDEX_KEY = "INPUT_INDEX_KEY"
@@ -41,20 +41,21 @@ def name_for_stride(name, stride: int):
4141
TARGET_LABELS_KEY = "TARGET_LABELS_KEY"

TARGET_MASK_KEY = "TARGET_MASK_KEY"

# Ground-truth mask keys at each decoder stride (derived via name_for_stride).
TARGET_MASK_KEY_STRIDE_2 = name_for_stride(TARGET_MASK_KEY, 2)
TARGET_MASK_KEY_STRIDE_4 = name_for_stride(TARGET_MASK_KEY, 4)
TARGET_MASK_KEY_STRIDE_8 = name_for_stride(TARGET_MASK_KEY, 8)
TARGET_MASK_KEY_STRIDE_16 = name_for_stride(TARGET_MASK_KEY, 16)
TARGET_MASK_KEY_STRIDE_32 = name_for_stride(TARGET_MASK_KEY, 32)
TARGET_MASK_KEY_STRIDE_64 = name_for_stride(TARGET_MASK_KEY, 64)

OUTPUT_MASK_KEY = "OUTPUT_MASK_KEY"
# Predicted mask keys at each decoder stride.
OUTPUT_MASK_KEY_STRIDE_2 = name_for_stride(OUTPUT_MASK_KEY, 2)
OUTPUT_MASK_KEY_STRIDE_4 = name_for_stride(OUTPUT_MASK_KEY, 4)
OUTPUT_MASK_KEY_STRIDE_8 = name_for_stride(OUTPUT_MASK_KEY, 8)
OUTPUT_MASK_KEY_STRIDE_16 = name_for_stride(OUTPUT_MASK_KEY, 16)
OUTPUT_MASK_KEY_STRIDE_32 = name_for_stride(OUTPUT_MASK_KEY, 32)
OUTPUT_MASK_KEY_STRIDE_64 = name_for_stride(OUTPUT_MASK_KEY, 64)

OUTPUT_LOGITS_KEY = "OUTPUT_LOGITS_KEY"
OUTPUT_EMBEDDINGS_KEY = "OUTPUT_EMBEDDINGS_KEY"

pytorch_toolbelt/datasets/mean_std.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,32 @@
55

66

77
class DatasetMeanStdCalculator:
8-
__slots__ = ["global_mean", "global_var", "n_items", "num_channels", "global_max", "global_min"]
8+
__slots__ = ["global_mean", "global_var", "n_items", "num_channels", "global_max", "global_min", "dtype"]
99

1010
"""
1111
Class to calculate running mean and std of the dataset. It helps when whole dataset does not fit entirely in RAM.
1212
"""
1313

14-
    def __init__(self, num_channels: int = 3, dtype=np.float64):
        """
        Create a new instance of DatasetMeanStdCalculator

        Args:
            num_channels: Number of channels in the image. Default value is 3
            dtype: Numpy dtype used for the running accumulators (mean/var/min/max).
        """
        super().__init__()
        self.num_channels = num_channels
        self.global_mean = None
        self.global_var = None
        self.global_max = None
        self.global_min = None
        self.n_items = 0
        self.dtype = dtype
        # reset() allocates the accumulator arrays from num_channels and dtype.
        self.reset()
2930

3031
    def reset(self):
        """Reset accumulators; min/max start at +/-inf so any first sample updates them."""
        self.global_mean = np.zeros(self.num_channels, dtype=self.dtype)
        self.global_var = np.zeros(self.num_channels, dtype=self.dtype)
        self.global_max = np.ones_like(self.global_mean) * float("-inf")
        self.global_min = np.ones_like(self.global_mean) * float("+inf")
        self.n_items = 0
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .inria_aerial import *
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import os
2+
import subprocess
3+
import warnings
4+
from pathlib import Path
5+
from typing import Union, Optional, Tuple
6+
import hashlib
7+
8+
import numpy as np
9+
import pandas as pd
10+
import torch
11+
import zipfile
12+
13+
from sklearn.model_selection import GroupKFold
14+
15+
from pytorch_toolbelt.utils import fs
16+
17+
18+
__all__ = ["InriaAerialImageDataset"]
19+
20+
21+
class InriaAerialImageDataset:
    """
    Provider for the INRIA Aerial Image Labeling dataset (binary building segmentation).

    Command-line usage:
        python -m pytorch_toolbelt.datasets.providers.inria_aerial inria_dataset
    """

    TASK = "binary_segmentation"
    METRIC = ""
    ORIGIN = "https://project.inria.fr/aerialimagelabeling"
    TRAIN_LOCATIONS = ["austin", "chicago", "kitsap", "tyrol-w", "vienna"]
    TEST_LOCATIONS = ["bellingham", "bloomington", "innsbruck", "sfo", "tyrol-e"]

    # Multi-part 7z archive: part URL -> expected SHA-256 digest of that part.
    urls = {
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.001": "17a7d95c78e484328fd8fe5d5afa2b505e04b8db8fceb617819f3c935d1f39ec",
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.002": "b505cb223964b157823e88fbd5b0bd041afcbf39427af3ca1ce981ff9f61aff4",
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.003": "752916faa67be6fc6693f8559531598fa2798dc01b7d197263e911718038252e",
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.004": "b3893e78f92572455fc2c811af560a558d2a57f9b92eff62fa41399b607a6f44",
        "https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.005": "a92eb20fdc9911c5ffe3afc514490b8f1e1e5b22301a6fc55d3b4e1624d8033f",
    }

    @classmethod
    def download_and_extract(cls, data_dir: Union[str, Path]) -> bool:
        """
        Download all archive parts (verifying SHA-256), join them and extract the dataset.

        Args:
            data_dir: Destination directory; created if missing.

        Returns:
            True on success, False if the optional py7zr dependency is not installed.
        """
        try:
            import py7zr  # idiomatic import; the package exposes SevenZipFile at top level
        except ImportError:
            print("You need to install py7zr to extract 7z-archive: `pip install py7zr`.")
            return False

        filenames = []
        for file_url, file_hash in cls.urls.items():
            file_path = os.path.join(data_dir, os.path.basename(file_url))
            # (Re-)download when the part is absent or its checksum does not match.
            if not os.path.isfile(file_path) or cls.sha256digest(file_path) != file_hash:
                os.makedirs(data_dir, exist_ok=True)
                torch.hub.download_url_to_file(file_url, file_path)

            filenames.append(file_path)

        main_archive = os.path.join(data_dir, "aerialimagelabeling.7z")
        # "wb" (not "ab"): appending to a leftover archive from an interrupted
        # previous run would produce a corrupted, un-extractable file.
        with open(main_archive, "wb") as outfile:
            for fname in filenames:
                with open(fname, "rb") as infile:
                    outfile.write(infile.read())

        with py7zr.SevenZipFile(main_archive, "r") as archive:
            archive.extractall(data_dir)
        os.unlink(main_archive)

        # The 7z archive contains a nested zip with the actual dataset.
        zip_archive = os.path.join(data_dir, "NEW2-AerialImageDataset.zip")
        with zipfile.ZipFile(zip_archive, "r") as zip_ref:
            zip_ref.extractall(data_dir)
        os.unlink(zip_archive)
        return True

    @classmethod
    def init_from_folder(cls, data_dir: Union[str, Path], download: bool = False):
        """
        Create the dataset from *data_dir*, optionally downloading it first.

        Raises:
            RuntimeError: If download was requested but failed.
        """
        data_dir = os.path.expanduser(data_dir)

        if download:
            if not cls.download_and_extract(data_dir):
                raise RuntimeError("Download and extract failed")

        return cls(os.path.join(data_dir, "AerialImageDataset"))

    @classmethod
    def sha256digest(cls, filename: str) -> str:
        """Return the hex SHA-256 digest of *filename*, read in 4 KiB chunks."""
        blocksize = 4096
        sha = hashlib.sha256()
        with open(filename, "rb") as f:
            while chunk := f.read(blocksize):
                sha.update(chunk)
        return sha.hexdigest()

    @classmethod
    def read_tiff(
        cls, image_fname: str, crop_coords: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None
    ) -> np.ndarray:
        """
        Read a (possibly windowed) TIFF image as an HWC numpy array.

        Args:
            image_fname: Path to the TIFF file.
            crop_coords: Optional ((row_start, row_stop), (col_start, col_stop)) crop window.

        Returns:
            Image as an HWC array; single-channel images are squeezed to HW.

        Raises:
            FileNotFoundError: If image_fname does not exist.
        """
        import rasterio
        from rasterio.windows import Window

        window = None
        if crop_coords is not None:
            (row_start, row_stop), (col_start, col_stop) = crop_coords
            window = Window.from_slices((row_start, row_stop), (col_start, col_stop))

        if not os.path.isfile(image_fname):
            raise FileNotFoundError(image_fname)

        with warnings.catch_warnings():
            # Inria tiles carry no georeferencing; silence rasterio's warning about it.
            warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

            with rasterio.open(image_fname) as f:
                image = f.read(window=window)
                image = np.moveaxis(image, 0, -1)  # CHW->HWC
                if image.shape[2] == 1:
                    image = np.squeeze(image, axis=2)
        return image

    @classmethod
    def compress_prediction_mask(cls, predicted_mask_fname, compressed_mask_fname):
        """Compress a predicted mask with gdal_translate (CCITTFAX4, 1-bit) for submission."""
        # List-form argv with shell=False is robust to spaces in paths and
        # avoids shell injection from attacker-controlled filenames.
        command = [
            "gdal_translate",
            "--config", "GDAL_PAM_ENABLED", "NO",
            "-co", "COMPRESS=CCITTFAX4",
            "-co", "NBITS=1",
            predicted_mask_fname,
            compressed_mask_fname,
        ]
        subprocess.call(command)

    def __init__(self, root_dir: str):
        """
        Args:
            root_dir: Path to the extracted "AerialImageDataset" folder.

        Raises:
            FileNotFoundError: If the train or test directory is missing.
            RuntimeError: If the number of train images or masks is not 180.
        """
        self.root_dir = root_dir
        self.train_dir = os.path.join(root_dir, "train")
        self.test_dir = os.path.join(root_dir, "test")

        if not os.path.isdir(self.train_dir):
            raise FileNotFoundError(f"Train directory {self.train_dir} does not exist")
        if not os.path.isdir(self.test_dir):
            # Fixed copy-paste bug: message previously reported self.train_dir.
            raise FileNotFoundError(f"Test directory {self.test_dir} does not exist")

        self.train_images = fs.find_images_in_dir(os.path.join(self.train_dir, "images"))
        self.train_masks = fs.find_images_in_dir(os.path.join(self.train_dir, "gt"))

        if len(self.train_images) != 180 or len(self.train_masks) != 180:
            raise RuntimeError("Number of train images and ground-truth masks must be 180")

    def get_test_df(self) -> pd.DataFrame:
        """Return a DataFrame of test images with location and the fixed 5000x5000 tile size."""
        test_images = fs.find_images_in_dir(os.path.join(self.test_dir, "images"))
        df = pd.DataFrame.from_dict({"images": test_images})
        df["rows"] = 5000
        df["cols"] = 5000
        # Strip the trailing tile index, e.g. "bellingham12" -> "bellingham".
        df["location"] = df["images"].apply(lambda x: fs.id_from_fname(x).rstrip("0123456789"))
        return df

    def get_train_val_split_train_df(self) -> pd.DataFrame:
        """
        Return the train DataFrame with a "split" column.

        Uses the validation protocol suggested by the competition host: the first
        five tiles of every location (e.g. austin{1-5}.tif) go to validation.
        """
        valid_locations = [f"{loc}{i}" for loc in self.TRAIN_LOCATIONS for i in range(1, 6)]

        df = pd.DataFrame.from_dict({"images": self.train_images, "masks": self.train_masks})
        df["location_with_index"] = df["images"].apply(lambda x: fs.id_from_fname(x))
        df["location"] = df["location_with_index"].apply(lambda x: x.rstrip("0123456789"))
        df["split"] = df["location_with_index"].apply(lambda l: "valid" if l in valid_locations else "train")
        df["rows"] = 5000
        df["cols"] = 5000
        return df

    def get_kfold_split_train_df(self, num_folds: int = 5) -> pd.DataFrame:
        """Return the train DataFrame with a "fold" column from a location-grouped K-fold split."""
        df = pd.DataFrame.from_dict({"images": self.train_images, "masks": self.train_masks})
        df["location_with_index"] = df["images"].apply(lambda x: fs.id_from_fname(x))
        df["location"] = df["location_with_index"].apply(lambda x: x.rstrip("0123456789"))
        df["rows"] = 5000
        df["cols"] = 5000
        df["fold"] = -1
        # Group by location so all tiles of one city land in the same fold.
        kfold = GroupKFold(n_splits=num_folds)
        for fold, (train_index, test_index) in enumerate(kfold.split(df, df, groups=df["location"])):
            df.loc[test_index, "fold"] = fold
        return df
182+
183+
184+
def download_and_extract(data_dir):
    """CLI helper: download/extract the Inria dataset into *data_dir*, then print the split dataframes."""
    dataset = InriaAerialImageDataset.init_from_folder(data_dir, download=True)
    for frame in (
        dataset.get_test_df(),
        dataset.get_train_val_split_train_df(),
        dataset.get_kfold_split_train_df(),
    ):
        print(frame)
189+
190+
191+
if __name__ == "__main__":
    # python-fire turns the function into a CLI: first positional arg is data_dir.
    from fire import Fire

    Fire(download_and_extract)

0 commit comments

Comments
 (0)