-
Notifications
You must be signed in to change notification settings - Fork 29
Open
Description
I wrote the following class to create a dataset for my flower data.
Defining the dataset:
from PIL import Image
import cv2
import albumentations
import torch
import numpy as np
import io
from torch.utils.data import Dataset
class FlowerDataset(Dataset):
    """Map-style dataset of encoded flower images and integer class labels.

    Each item decodes ``image[index]`` with PIL, converts it to 3-channel RGB,
    resizes it to ``(img_height, img_width)``, applies a random
    shift/scale/rotate when training, normalizes it with ``mean``/``std`` and
    returns it as a float32 CHW tensor together with its int64 label.

    Args:
        id: Sequence of sample ids; only its length is used (``__len__``).
        classes: Integer class labels, indexed in parallel with ``image``.
        image: Encoded image bytes (anything ``PIL.Image.open`` can read),
            indexed in parallel with ``classes``.
        img_height: Output image height in pixels.
        img_width: Output image width in pixels.
        mean: Per-channel mean passed to ``albumentations.Normalize``.
        std: Per-channel std passed to ``albumentations.Normalize``.
        is_valid: Truthy for validation/inference (no random augmentation),
            falsy for training.
        return_tuple: When True, ``__getitem__`` returns ``(image, label)``
            instead of ``{'image': image, 'class': label}``. Wrappers that
            unpack samples positionally need this form (the ``cutmix``
            package's ``CutMix`` appears to be one -- verify).
    """

    def __init__(self, id, classes, image, img_height, img_width, mean, std,
                 is_valid, *, return_tuple=False):
        # Attribute name kept as-is for backward compatibility with callers.
        self.id = id
        self.classes = classes
        self.image = image
        self.return_tuple = return_tuple

        # Resize once, straight to the requested size, with bicubic
        # interpolation. No separate fixed-size cv2.resize is done in
        # __getitem__: that would ignore img_height/img_width and resample
        # every image twice. Resize and Normalize already default to p=1,
        # so always_apply=True (deprecated in recent albumentations) is
        # not needed.
        transforms = [
            albumentations.Resize(img_height, img_width,
                                  interpolation=cv2.INTER_CUBIC),
        ]
        if not is_valid:
            # Geometric augmentation runs on the uint8 image, before
            # Normalize, so normalization is always the final step.
            transforms.append(
                albumentations.ShiftScaleRotate(shift_limit=0.0625,
                                                scale_limit=0.1,
                                                rotate_limit=5,
                                                p=0.9))
        transforms.append(albumentations.Normalize(mean, std))
        self.aug = albumentations.Compose(transforms)

    def __len__(self):
        return len(self.id)

    def __getitem__(self, index):
        # Decode and force 3 channels: grayscale ("L"), RGBA or palette
        # images would otherwise break the 3-channel Normalize and the
        # HWC -> CHW transpose below.
        with Image.open(io.BytesIO(self.image[index])) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        img = self.aug(image=img)["image"]
        # HWC -> CHW, the layout torch conv layers expect.
        img = np.transpose(img, (2, 0, 1)).astype(np.float32)

        image = torch.tensor(img, dtype=torch.float)
        label = torch.tensor(self.classes[index], dtype=torch.long)
        if self.return_tuple:
            return image, label
        return {
            'image': image,
            'class': label,
        }
Then I ran a sanity check to make sure it's good to go:
# sanity check for FlowerDataset class created
train_dataset = FlowerDataset(id = train_ids, classes = train_class, image = train_images,
img_height = 128 , img_width = 128,
mean = (0.485, 0.456, 0.406),
std = (0.229, 0.224, 0.225) , is_valid = 0)
import matplotlib.pyplot as plt
%matplotlib inline
idx = 0
img = train_dataset[idx]['image']
print(train_dataset[idx]['class'])
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)))
Then I did:
# setting up the dataloader with cutmix data agumentation
!pip install git+https://github.com/ildoonet/cutmix
# setting up the train data loader
from cutmix.cutmix import CutMix
train_dataloader = CutMix(train_dataset,
num_class=104,
beta=1.0,
prob=0.5,
num_mix=2)
That cell ran successfully.
But when I ran this sanity check:
# Sanity check: pull the first item from the CutMix-wrapped training data and
# display how many elements it has (the notebook shows the last expression).
# NOTE(review): per the report, constructing CutMix succeeded but this call
# fails, i.e. the problem only surfaces once samples are actually read. The
# error text is not included; presumably CutMix cannot unpack the dict that
# FlowerDataset.__getitem__ returns as (img, label) -- confirm with the
# traceback.
batch = next(iter(train_dataloader))
len(batch)
it fails, and as a result I am unable to train the model.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels

