-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
135 lines (105 loc) · 5.15 KB
/
dataset.py
File metadata and controls
135 lines (105 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Realize the function of dataset preparation."""
import io
import os
from typing import Tuple

import lmdb
import numpy as np
from PIL import Image, ImageOps
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

import imgproc
__all__ = ["ImageDataset", "LMDBDataset"]
class ImageDataset(Dataset):
    """Customize the data set loading function and prepare low/high resolution image data in advance.

    Args:
        dataroot (str): Training data set address
        image_size (int): High resolution image size (side length of the HR crop)
        upscale_factor (int): Image magnification; LR images are
            ``image_size // upscale_factor`` pixels per side
        mode (str): Data set loading method, the training data set is for data enhancement,
            and the verification data set is not for data enhancement
    """
    def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(ImageDataset, self).__init__()
        # Sorted so dataset order is deterministic across runs/platforms.
        self.filenames = sorted(os.path.join(dataroot, x) for x in os.listdir(dataroot))
        if mode == "train":
            # Training: random rotation/crop/flip for data augmentation.
            self.hr_transforms = transforms.Compose([
                transforms.RandomRotation(90),
                transforms.RandomCrop(image_size),
                transforms.RandomHorizontalFlip(0.5),
            ])
        else:
            # Validation: deterministic center crop, no augmentation.
            self.hr_transforms = transforms.CenterCrop(image_size)
        # LR images are produced by bicubic downsampling of the HR crop.
        self.lr_transforms = transforms.Resize(image_size // upscale_factor, interpolation=IMode.BICUBIC, antialias=True)
    def __getitem__(self, batch_index: int) -> Tuple[Tensor, Tensor]:
        """Return the (low-resolution, high-resolution) tensor pair at ``batch_index``."""
        image = Image.open(self.filenames[batch_index])
        # Image to grayscale for thermal imagery (single-channel).
        # The original wrapped this in a dead `if 1:` guard; it always runs.
        image = ImageOps.grayscale(image)
        # Transform image: HR crop first, then derive the LR image from it
        # so the pair stays spatially aligned.
        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)
        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_tensor = imgproc.image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = imgproc.image2tensor(hr_image, range_norm=False, half=False)
        return lr_tensor, hr_tensor
    def __len__(self) -> int:
        return len(self.filenames)
class LMDBDataset(Dataset):
    """Load the data set as a data set in the form of LMDB.

    Both LMDB databases are read fully into memory at construction time.

    Args:
        lr_lmdb_path (str): Path to the low-resolution LMDB database.
        hr_lmdb_path (str): Path to the high-resolution LMDB database.

    Attributes:
        lr_datasets (list): Low-resolution image data in the dataset
        hr_datasets (list): High-resolution image data in the dataset
    """
    def __init__(self, lr_lmdb_path, hr_lmdb_path) -> None:
        super(LMDBDataset, self).__init__()
        # Create low/high resolution image arrays.
        self.lr_datasets = []
        self.hr_datasets = []
        # Initialize the LMDB database file addresses.
        self.lr_lmdb_path = lr_lmdb_path
        self.hr_lmdb_path = hr_lmdb_path
        # Eagerly load every image from both LMDB files into memory.
        self.read_lmdb_dataset()
    def __getitem__(self, batch_index: int) -> Tuple[Tensor, Tensor]:
        """Return an augmented (low-resolution, high-resolution) tensor pair."""
        lr_image = self.lr_datasets[batch_index]
        hr_image = self.hr_datasets[batch_index]
        # Paired augmentation: the same rotation/flip is applied to both
        # images so they stay spatially aligned.
        lr_image, hr_image = imgproc.random_rotate(lr_image, hr_image, angle=90)
        lr_image, hr_image = imgproc.random_horizontally_flip(lr_image, hr_image, p=0.5)
        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_tensor = imgproc.image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = imgproc.image2tensor(hr_image, range_norm=False, half=False)
        return lr_tensor, hr_tensor
    def __len__(self) -> int:
        return len(self.hr_datasets)
    def read_lmdb_dataset(self) -> None:
        """Populate ``self.lr_datasets`` / ``self.hr_datasets`` from the LMDB files.

        Mutates the two instance lists in place and returns nothing (the
        original annotation claimed ``[list, list]`` but nothing was returned).
        The LMDB environments are closed after reading instead of being leaked.
        """
        for lmdb_path, datasets in ((self.lr_lmdb_path, self.lr_datasets),
                                    (self.hr_lmdb_path, self.hr_datasets)):
            env = lmdb.open(lmdb_path)
            try:
                with env.begin() as txn:
                    for _, image_bytes in txn.cursor():
                        # BytesIO copies the value, so the PIL image remains
                        # valid after the environment is closed.
                        datasets.append(Image.open(io.BytesIO(image_bytes)))
            finally:
                env.close()