|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +Created on Sat Feb 13 19:24:17 2021 |
| 4 | +
|
| 5 | +@author: trabz |
| 6 | +""" |
| 7 | +import torch |
| 8 | + |
| 9 | + |
| 10 | +from xmlParser import Parser |
| 11 | +import glob |
| 12 | +import numpy as np |
| 13 | +from matplotlib import pyplot as plt |
| 14 | +import cv2 |
| 15 | +import torchvision.transforms as transforms |
| 16 | +class dataloader(): |
| 17 | + |
| 18 | + def __init__(self,path,transform=None): |
| 19 | + |
| 20 | + self.paths=glob.glob(path) |
| 21 | + self.transform=transform |
| 22 | + |
| 23 | + |
| 24 | + def __len__(self): |
| 25 | + return len(self.paths) |
| 26 | + |
| 27 | + def __getitem__(self,idx): |
| 28 | + |
| 29 | + annotation=(Parser.myType(self.paths[idx],idx,classes=['bird','zebra'])) |
| 30 | + image=plt.imread(self.paths[annotation['image_id']].replace('xml','jpg')) |
| 31 | + |
| 32 | + if self.transform is not None: |
| 33 | + |
| 34 | + image_aug=self.transform(image) |
| 35 | + |
| 36 | + return image,image_aug,annotation |
| 37 | + |
| 38 | + |
| 39 | +def collate_fn(batch): |
| 40 | + return tuple(zip(*batch)) |
| 41 | + |
| 42 | + |
| 43 | +data_transform=transforms.Compose([ |
| 44 | + transforms.ToTensor(), |
| 45 | + #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), |
| 46 | + transforms.ToPILImage(), |
| 47 | + transforms.RandomApply([transforms.CenterCrop(222), transforms.RandomCrop(170)],p=0.01), |
| 48 | + |
| 49 | + |
| 50 | + transforms.RandomAffine((-15,15)), |
| 51 | + |
| 52 | + |
| 53 | + |
| 54 | + #transforms.Grayscale(), |
| 55 | + transforms.RandomGrayscale(p=0.1), |
| 56 | + transforms.RandomHorizontalFlip(p=0.1), |
| 57 | + transforms.RandomVerticalFlip(p=0.1), |
| 58 | + transforms.RandomApply([transforms.ColorJitter(brightness=15,contrast=12,hue=0.2)],p=0.1), |
| 59 | + # transforms.RandomResizedCrop((200,200)), |
| 60 | + transforms.RandomRotation((-90,90)), |
| 61 | + transforms.RandomPerspective(p=0.1), |
| 62 | + transforms.ToTensor(), |
| 63 | + # transforms.RandomCrop(10), |
| 64 | + transforms.RandomErasing(p=0.1), |
| 65 | + |
| 66 | + |
| 67 | + # |
| 68 | + ]) |
| 69 | + |
| 70 | + |
| 71 | + |
| 72 | + |
| 73 | + |
| 74 | + |
| 75 | +path=path= 'Image A*/train/*.xml' |
| 76 | + |
| 77 | +dataset = dataloader(path,data_transform) |
| 78 | +data_loader = torch.utils.data.DataLoader( |
| 79 | + dataset, batch_size=1, collate_fn=collate_fn) |
| 80 | + |
| 81 | + |
| 82 | + |
| 83 | + |
| 84 | +plt.figure(figsize=(12,12)) |
| 85 | +# for idx,(image,imgO,result) in enumerate(data_loader): |
| 86 | + |
| 87 | +# #plt.figure() |
| 88 | +# #plt.imshow(result[0].permute(1,2,0).numpy(),cmap='gray') |
| 89 | +# img=imgO[0].clone().permute(1,2,0).numpy() |
| 90 | +# image=image[0] |
| 91 | +# sz,wd,_=np.array(image.shape)-np.array((img).shape) |
| 92 | +# img2=img |
| 93 | +# img=np.pad((img),((sz//2,sz//2),(wd//2,wd//2),(0,0))) |
| 94 | + |
| 95 | +# try: |
| 96 | +# plt.imsave(f'PyTorch/images/{idx}.png',np.hstack(((img*255).astype('uint8'),image))) |
| 97 | +# except ValueError: |
| 98 | +# img=np.dstack((img,img,img)) |
| 99 | +# plt.imsave(f'PyTorch/images/{idx}.png',np.hstack(((img*255).astype('uint8'),image))) |
| 100 | +# plt.axis('off') |
| 101 | +# plt.tight_layout() |
| 102 | +# #plt.imshow(np.dstack((img[:,:,2],img[:,:,1],img[:,:,0]))) |
| 103 | +# plt.imshow(image) |
| 104 | + |
| 105 | +#♣plt.savefig('TFOrig.png',dpi=110,transparent=True) |
| 106 | + |
| 107 | +for idx,(image,imgO,result) in enumerate(data_loader): |
| 108 | + |
| 109 | + #plt.figure() |
| 110 | + #plt.imshow(result[0].permute(1,2,0).numpy(),cmap='gray') |
| 111 | + plt.subplot(6,5,idx+1) |
| 112 | + |
| 113 | + plt.axis('off') |
| 114 | + plt.tight_layout() |
| 115 | + #plt.imshow(np.dstack((img[:,:,2],img[:,:,1],img[:,:,0]))) |
| 116 | + plt.imshow(imgO[0].permute(1,2,0).numpy()) |
| 117 | + |
| 118 | +plt.savefig('PyTorchAugmented2.png',dpi=110,transparent=True) |
0 commit comments