Skip to content

Commit 227131c

Browse files
code upload
1 parent 157bc49 commit 227131c

21 files changed

+1830
-0
lines changed

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,32 @@
11
# FourBi_7
2+
3+
## Setup
4+
To run this project, we used `python 3.11.7` and `pytorch 2.2.0`
5+
```bash
6+
conda create -n fourbi python=3.11.7
7+
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
8+
pip3 install opencv-python wandb pytorch-ignite
9+
```
10+
11+
## Inference
12+
To run the model on a folder with images, run with the following command
13+
```
14+
python binarize.py <path to checkpoint> --src <path to the test images folder>
15+
--dst <path to the output folder>
16+
```
17+
18+
## Training
19+
The model is trained on patches, then evaluated and tested on complete documents. We provide the code to create the patches and train the model.
20+
For example, to train on H-DIBCO12, first download the dataset from http://utopia.duth.gr/~ipratika/HDIBCO2012/benchmark/. Create a folder, then place the images in a sub-folder named "imgs" and the ground truth in a sub-folder named "gt_imgs". Then run the following command:
21+
```
22+
python create_patches.py --path_src <path to the dataset folder>
23+
--path_dst <path to the folder where the patches will be saved>
24+
--patch_size <size of the patches> --overlap_size <size of the overlap>
25+
```
26+
To launch the training, run the following command:
27+
```
28+
python train.py --datasets_paths <all datasets paths>
29+
--eval_dataset_name <name of the validation dataset>
30+
--test_dataset_name <name of the test dataset>
31+
```
232

binarize.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import argparse
import torch
from pathlib import Path
from trainer.fourbi_trainer import FourbiTrainingModule
from data.test_dataset import FolderDataset
from torchvision import transforms

if __name__ == '__main__':
    # CLI entry point: run a trained FourBi checkpoint over a folder of images
    # and write one binarized PNG per input image into the destination folder.
    parser = argparse.ArgumentParser(description='Binarize a folder of images')
    parser.add_argument('model', type=str, metavar='PATH', help='path to the model file')
    parser.add_argument('--src', type=str, required=True, help='path to the folder of input images')
    parser.add_argument('--dst', type=str, required=True, help='path to the folder of output images')
    parser.add_argument('--patch_size', type=int, default=512, help='patch size')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size when processing patches')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    output_dir = Path(args.dst)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Restore the training module from the checkpoint; no train/val loaders are needed.
    trainer = FourbiTrainingModule(config={'resume': args.model}, device=device, make_loaders=False)

    folder_dataset = FolderDataset(Path(args.src), patch_size=args.patch_size,
                                   overlap=True, transform=transforms.ToTensor())
    trainer.test_data_loader = torch.utils.data.DataLoader(folder_dataset, batch_size=1,
                                                           shuffle=False, num_workers=0)

    # Overlapping patches are re-assembled with a stride of half the patch size.
    trainer.config['test_patch_size'] = args.patch_size
    trainer.config['test_stride'] = args.patch_size // 2
    trainer.config['eval_batch_size'] = args.batch_size

    for idx, sample in enumerate(trainer.folder_test()):
        # Each yielded sample maps one source image path to (image, prediction, gt).
        key = next(iter(sample))
        img, pred, gt = sample[key]

        dst_img_path = output_dir / (Path(key).stem + '.png')
        pred.save(str(dst_img_path))
        print(f'({idx + 1}/{len(folder_dataset)}) Saving {dst_img_path}')

    print('Done.')

create_patches.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import argparse
from data.process_image import PatchImage


def main():
    """Parse CLI arguments and split a dataset folder into training patches.

    Expects ``--path_src`` to contain ``imgs/`` and ``gt_imgs/`` sub-folders;
    the patches are written under ``--path_dst``.
    """
    parser = argparse.ArgumentParser(description='create patches')
    parser.add_argument('--path_dst', type=str, help="Destination folder path")
    # Fixed typo in the help text ("witch" -> "which") and dropped the stray
    # f-string prefix on the placeholder-free help strings.
    parser.add_argument('--path_src', type=str, help="The path which contains the images")
    parser.add_argument('--patch_size', type=int, help="Patch size", default=384)
    parser.add_argument('--overlap_size', type=int, help='Overlap size', default=192)
    args = parser.parse_args()

    # PatchImage creates imgs_<size>/ and gt_imgs_<size>/ under path_dst.
    patcher = PatchImage(patch_size=args.patch_size, overlap_size=args.overlap_size, destination_root=args.path_dst)
    patcher.create_patches(root_original=args.path_src)


if __name__ == '__main__':
    main()

data/custom_transforms.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from torchvision import transforms
2+
from torchvision.transforms import functional
3+
4+
5+
class ToTensor(transforms.ToTensor):
    """ToTensor variant operating on ``{'image', 'gt'}`` sample dicts."""

    def __call__(self, sample):
        # Bind the parent conversion once and apply it to both entries.
        to_tensor = super().__call__
        return {'image': to_tensor(sample['image']), 'gt': to_tensor(sample['gt'])}
12+
13+
14+
class ColorJitter(transforms.ColorJitter):
    """ColorJitter for sample dicts: jitters only the image, never the gt."""

    def __call__(self, sample):
        jittered = super().__call__(sample['image'])
        return {'image': jittered, 'gt': sample['gt']}
20+
21+
22+
class RandomCrop(transforms.RandomCrop):
    """RandomCrop that applies the same crop window to image and gt."""

    def __init__(self, size):
        super(RandomCrop, self).__init__(size=size)
        # Keep the raw int: get_params below receives it as (size, size).
        self.size = size

    def __call__(self, sample):
        # One random window, applied identically to both entries so they stay aligned.
        top, left, height, width = self.get_params(sample['image'], output_size=(self.size, self.size))
        return {'image': functional.crop(sample['image'], top, left, height, width),
                'gt': functional.crop(sample['gt'], top, left, height, width)}
34+
35+
36+
class RandomRotation(transforms.RandomRotation):
    """RandomRotation that applies the same random angle to image and gt.

    The image's exposed corners are filled with white. The gt is inverted
    before rotating so the rotation's default black fill turns into white
    (background) once inverted back.
    """

    def __call__(self, sample):
        angle = self.get_params(self.degrees)

        rotated_image = functional.rotate(sample['image'], angle, fill=[255, 255, 255])

        rotated_gt = functional.invert(sample['gt'])
        rotated_gt = functional.rotate(rotated_gt, angle)
        rotated_gt = functional.invert(rotated_gt)
        return {'image': rotated_image, 'gt': rotated_gt}

data/dataloaders.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
from torch.utils.data import Dataset
3+
4+
from utils.htr_logging import get_logger
5+
6+
logger = get_logger(__file__)
7+
8+
9+
def make_train_dataloader(train_dataset: Dataset, config: dict):
    """Build the training DataLoader using the kwargs in config['train_kwargs']."""
    return torch.utils.data.DataLoader(train_dataset, **config['train_kwargs'])
13+
14+
15+
def make_valid_dataloader(valid_dataset: Dataset, config: dict):
    """Build the validation DataLoader using the kwargs in config['eval_kwargs']."""
    return torch.utils.data.DataLoader(valid_dataset, **config['eval_kwargs'])
19+
20+
21+
def make_test_dataloader(test_dataset: Dataset, config: dict):
    """Build the test DataLoader using the kwargs in config['test_kwargs']."""
    return torch.utils.data.DataLoader(test_dataset, **config['test_kwargs'])

data/datasets.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from torchvision.transforms import transforms
2+
import time
3+
from data.training_dataset import TrainingDataset
4+
from data.test_dataset import TestDataset
5+
from data.utils import get_transform
6+
from utils.htr_logging import get_logger
7+
from torch.utils.data import ConcatDataset
8+
from pathlib import Path
9+
10+
logger = get_logger(__file__)
11+
12+
13+
def make_train_dataset(config: dict):
    """Build the training set as a ConcatDataset over every configured path.

    Reads 'train_data_path', 'train_patch_size', 'train_patch_size_raw' and
    'load_data' from the config. Each path may optionally contain a 'train'
    sub-folder, which is preferred when present.
    """
    paths = config['train_data_path']
    patch_size = config['train_patch_size']
    load_data = config['load_data']

    logger.info(f"Train path: \"{paths}\" with patch size {patch_size} and load_data={load_data}")

    transform = get_transform(output_size=patch_size)

    logger.info("Loading train datasets...")
    start = time.time()
    parts = []
    for index, path in enumerate(paths):
        logger.info(f"[{index + 1}/{len(paths)}] Loading train dataset from \"{path}\"")
        base = Path(path)
        # Prefer the 'train' sub-folder when the dataset provides one.
        data_path = base / 'train' if (base / 'train').exists() else base
        parts.append(
            TrainingDataset(
                data_path=data_path,
                split_size=patch_size,
                patch_size=config['train_patch_size_raw'],
                transform=transform,
                load_data=load_data))

    logger.info(f"Loading train datasets took {time.time() - start:.2f} seconds")
    train_dataset = ConcatDataset(parts)
    logger.info(f"Training set has {len(train_dataset)} instances")

    return train_dataset
41+
42+
43+
def make_val_dataset(config: dict):
    """Build the validation set as a ConcatDataset over every configured path.

    Reads 'eval_data_path', 'eval_patch_size', 'test_stride' and 'load_data'
    from the config. Validation uses the test stride so it evaluates full
    documents, mirroring make_test_dataset.
    """
    val_data_path = config['eval_data_path']
    stride = config['test_stride']
    patch_size = config['eval_patch_size']
    load_data = config['load_data']

    transform = transforms.Compose([transforms.ToTensor()])

    logger.info("Loading validation datasets...")
    time_start = time.time()
    datasets = []
    for i, path in enumerate(val_data_path):
        # i + 1 keeps the progress counter 1-based, consistent with make_train_dataset
        # (previously logged "[0/N]" for the first dataset).
        logger.info(f"[{i + 1}/{len(val_data_path)}] Loading validation dataset from \"{path}\"")
        datasets.append(
            TestDataset(
                data_path=Path(path),
                patch_size=patch_size,
                stride=stride,
                transform=transform,
                load_data=load_data
            )
        )

    logger.info(f"Loading validation datasets took {time.time() - time_start:.2f} seconds")
    validation_dataset = ConcatDataset(datasets)
    logger.info(f"Validation set has {len(validation_dataset)} instances")

    return validation_dataset
71+
72+
73+
def make_test_dataset(config: dict):
    """Build the test set as a ConcatDataset over every configured path.

    Reads 'test_data_path', 'test_patch_size', 'test_stride' and 'load_data'
    from the config.
    """
    paths = config['test_data_path']
    patch_size = config['test_patch_size']
    stride = config['test_stride']
    load_data = config['load_data']

    transform = transforms.Compose([transforms.ToTensor()])

    logger.info("Loading test datasets...")
    start = time.time()
    parts = []

    for path in paths:
        # NOTE(review): path is passed as-is (str) here while make_val_dataset
        # wraps it in Path(...) — presumably TestDataset accepts both; verify.
        part = TestDataset(
            data_path=path,
            patch_size=patch_size,
            stride=stride,
            transform=transform,
            load_data=load_data)
        parts.append(part)
        logger.info(f'Loaded test dataset from {path} with {len(part)} instances.')

    logger.info(f"Loading test datasets took {time.time() - start:.2f} seconds")

    test_dataset = ConcatDataset(parts)

    logger.info(f"Test set has {len(test_dataset)} instances")
    return test_dataset

data/process_image.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import logging
2+
import cv2
3+
import numpy as np
4+
from pathlib import Path
5+
6+
7+
class PatchImage:
    """Splits document images and their ground truth into square training patches.

    Patches are sampled on a regular grid with stride ``overlap_size``; windows
    that extend past the image border are padded with white (255). Output is
    written as sequentially numbered PNGs into ``imgs_<patch_size>/`` and
    ``gt_imgs_<patch_size>/`` under the destination root.
    """

    def __init__(self, patch_size: int, overlap_size: int, destination_root: str):
        logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
        destination_root = Path(destination_root)
        self.train_folder = destination_root / f'imgs_{patch_size}/'
        self.train_gt_folder = destination_root / f'gt_imgs_{patch_size}/'
        self.train_folder.mkdir(parents=True, exist_ok=True)
        self.train_gt_folder.mkdir(parents=True, exist_ok=True)

        self.patch_size = patch_size
        self.overlap_size = overlap_size
        self.number_image = 1  # running index used to name the output patch files
        self.image_name = ""

        logging.info(f"Using Patch size: {self.patch_size} - Overlapping: {self.overlap_size}")

    def create_patches(self, root_original: str):
        """Patch every image under ``<root>/imgs`` paired with ``<root>/gt_imgs``.

        A failing image/gt pair is reported and skipped rather than aborting
        the whole run.
        """
        logging.info("Start process ...")
        root_original = Path(root_original)
        gt = root_original / 'gt_imgs'
        imgs = root_original / 'imgs'

        path_imgs = [p for p in imgs.glob('*') if p.suffix in {".png", ".jpg", ".bmp", ".tif"}]
        for img in path_imgs:
            or_img = cv2.imread(str(img))
            # The gt may use a different extension; fall back to a .png sibling.
            gt_path = gt / img.name
            gt_path = gt_path if gt_path.exists() else gt / (img.stem + '.png')
            gt_img = cv2.imread(str(gt_path))
            try:
                self._split_train_images(or_img, gt_img)
            except Exception as e:
                print(f'Error: {e} - {img}')

    def _split_train_images(self, or_img: np.ndarray, gt_img: np.ndarray):
        """Write aligned white-padded patches for an image/gt pair.

        NOTE(review): assumes gt_img has the same height/width as or_img —
        mismatched pairs raise and are reported by create_patches. The previous
        version duplicated the gt assignment in the bottom-right-corner branch
        and built border patches as float64; both issues are fixed here.
        """
        step = self.overlap_size
        for i in range(0, or_img.shape[0], step):
            for j in range(0, or_img.shape[1], step):
                dg_patch = self._padded_patch(or_img, i, j)
                gt_patch = self._padded_patch(gt_img, i, j)

                cv2.imwrite(str(self.train_folder / (str(self.number_image) + '.png')), dg_patch)
                cv2.imwrite(str(self.train_gt_folder / (str(self.number_image) + '.png')), gt_patch)
                self.number_image += 1
                print(self.number_image, end='\r')

    def _padded_patch(self, img: np.ndarray, i: int, j: int) -> np.ndarray:
        """Return the patch_size x patch_size window at (i, j), white-padded past the border."""
        patch_size = self.patch_size
        crop = img[i:i + patch_size, j:j + patch_size, :]
        if crop.shape[0] == patch_size and crop.shape[1] == patch_size:
            return crop
        # Border window: embed the partial crop into a white canvas of the
        # source dtype (previously float64, relying on cv2.imwrite to convert).
        patch = np.full((patch_size, patch_size, 3), 255, dtype=img.dtype)
        patch[:crop.shape[0], :crop.shape[1], :] = crop
        return patch

    def _create_name(self, folder: str, i: int, j: int):
        """Build a coordinate-based patch filename. Currently unused."""
        return folder + self.image_name.split('.')[0] + '_' + str(i) + '_' + str(j) + '.png'

0 commit comments

Comments
 (0)