How to improve the speed of a for loop in a custom transforms function? #3937
-
I have written my own data augmentation method, mainly to compress the data over randomly sized intervals down to a specified size, but when using it in training the for loop inside the transform becomes a severe bottleneck. How can I speed it up?

Main:

```python
import torch
from monai.utils import misc
from monai.data import CacheDataset, ThreadDataLoader
from monai.transforms import Compose, EnsureTyped, ToDeviced, RandSpatialCropSamplesd


def _dataloader(_dataset):
    return ThreadDataLoader(_dataset,
                            num_workers=0,
                            batch_size=1,
                            shuffle=True,
                            drop_last=True)


train_transforms = [
    EnsureTyped(keys=['image', 'label']),
    ToDeviced(keys=['image', 'label'], device='cuda:0'),
    RandSpatialCropSamplesd(keys=['image'], num_samples=8, roi_size=(2000, 6), random_size=False),
    RandResize(keys=['image'], prob=1.),
]

train_dataset = CacheDataset(
    data=[{'image': torch.rand(1, 4500, 6), 'label': torch.rand(1)} for i in range(128)],
    transform=Compose(train_transforms),
    cache_rate=1.0,
    copy_cache=False)

_loader = misc.first(_dataloader(train_dataset))
x, y = _loader['image'], _loader['label']
```

Transforms Function:

```python
import torch
from monai.transforms.transform import RandomizableTransform, MapTransform
from monai.transforms.inverse import InvertibleTransform
from torch.nn import functional as F


class RandResize(RandomizableTransform, MapTransform, InvertibleTransform):
    def __init__(self, keys, prob=0.1):
        RandomizableTransform.__init__(self, prob)
        MapTransform.__init__(self, keys)

    def randomize(self):
        super().randomize(None)
        if not self._do_transform:
            return None

    def getRange(self, index, overlap, _size):
        # window of 10 rows starting at `index`, widened by `overlap`
        # on each side while staying inside the tensor
        if index - overlap >= 0:
            first_i = index - overlap
        else:
            first_i = index
        if index + 10 + overlap < _size:
            last_i = index + 10 + overlap
        else:
            last_i = _size - 1
        return first_i, last_i

    def resize_method_switch(self, i, img, first_i, last_i):
        # pick one of six reductions over the current window
        switcher = {
            0: torch.mean(img[first_i:last_i], dim=0),
            1: torch.std(img[first_i:last_i], dim=0),
            2: torch.max(img[first_i:last_i], dim=0)[0],
            3: torch.min(img[first_i:last_i], dim=0)[0],
            4: torch.sum(img[first_i:last_i], dim=0),
            5: torch.logsumexp(img[first_i:last_i], dim=0)
        }
        return switcher.get(i, "Number Error")

    def _resize(self, img, overlap, choice):
        ### !!! The problem might be here !!! ###
        for index in range(0, img.size(0), 10):
            first_i, last_i = self.getRange(index, overlap, img.size(0))
            if index == 0:
                _results = torch.moveaxis(
                    self.resize_method_switch(choice, img, first_i, last_i).reshape(-1, 1), 0, 1)
            else:
                _m = torch.moveaxis(
                    self.resize_method_switch(choice, img, first_i, last_i).reshape(-1, 1), 0, 1)
                _results = torch.cat((_results, _m), dim=0)
        return _results

    def __call__(self, data):
        d = dict(data)
        self.randomize()
        if not self._do_transform:
            # transform skipped: plain 10x downsample along the first spatial dim
            for key in self.key_iterator(d):
                img = F.interpolate(d[key].unsqueeze(0), scale_factor=(0.1, 1),
                                    mode='bilinear', align_corners=False,
                                    recompute_scale_factor=True)
                d[key] = torch.squeeze(img, 0)
                self.push_transform(d, key)
            return d
        ### Using this creates a bottleneck
        for key in self.key_iterator(d):
            choice = self.R.randint(0, 6)
            overlap = self.R.randint(5, 21)
            img = self._resize(torch.squeeze(d[key]), overlap, choice)
            d[key] = torch.unsqueeze(img, 0)
            self.push_transform(d, key)
        return d
```
Replies: 2 comments
-
I think you should double-check whether building `switcher` is the real cost: as written, every entry of the dict is evaluated eagerly on each call, so all six reductions run even though only one result is used. Wrapping each entry in a lambda defers the computation to the single chosen branch:

```python
switcher = {
    0: lambda x: torch.mean(x, dim=0),
    1: lambda x: torch.std(x, dim=0),
    2: lambda x: torch.max(x, dim=0)[0],
    3: lambda x: torch.min(x, dim=0)[0],
    4: lambda x: torch.sum(x, dim=0),
    5: lambda x: torch.logsumexp(x, dim=0)
}
return switcher[i](img[first_i:last_i])
```
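Beyond the lambda fix, the per-window loop itself can usually be eliminated. The following is only a sketch (not from the original reply), and it assumes the edge clamping in `getRange` can be approximated by replicate padding: `Tensor.unfold` builds every overlapping window as one strided view, so a single batched reduction replaces the Python loop and the repeated `torch.cat` re-allocations.

```python
import torch
import torch.nn.functional as F

def resize_vectorized(img, overlap, choice, step=10):
    # img: (N, C), e.g. the (2000, 6) crop from the question
    window = step + 2 * overlap
    # Replicate-pad the row dim so every window has full width; this only
    # approximates the original boundary handling in getRange.
    padded = F.pad(img.t().unsqueeze(0), (overlap, overlap),
                   mode='replicate').squeeze(0).t()   # (N + 2*overlap, C)
    # one strided view holding all windows: (N // step, C, window)
    windows = padded.unfold(0, window, step)
    reducers = {
        0: lambda x: x.mean(dim=-1),
        1: lambda x: x.std(dim=-1),
        2: lambda x: x.max(dim=-1)[0],
        3: lambda x: x.min(dim=-1)[0],
        4: lambda x: x.sum(dim=-1),
        5: lambda x: x.logsumexp(dim=-1),
    }
    return reducers[choice](windows)                  # (N // step, C)
```

On a `(2000, 6)` crop this turns roughly 200 tiny kernel launches plus 200 growing `torch.cat` copies into one windowed reduction, which should remove most of the overhead, especially on GPU.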
-
Hi, @wyli,
I finally gave up on using random intervals to resize, and instead compress a fixed range directly with a batched reshape, e.g. reducing each `(2000, 6)` crop to `(200, 6)`:

```python
x = torch.rand(2000, 6)            # stand-in for the cropped sample
x.reshape(200, 10, 6).mean(dim=1)  # -> (200, 6), fixed 10x compression
```
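For reference, the same reshape trick extends to the other reductions when no overlap is needed; a minimal sketch (the `step` argument and the reducer table are assumptions carried over from the code above, not something posted in the thread):

```python
import torch

def fixed_resize(img, choice, step=10):
    # img: (N, C) with N divisible by step; group rows into
    # (N // step, step, C) blocks and reduce each block in one call
    blocks = img.reshape(-1, step, img.size(-1))
    reducers = {
        0: lambda x: x.mean(dim=1),
        1: lambda x: x.std(dim=1),
        2: lambda x: x.max(dim=1)[0],
        3: lambda x: x.min(dim=1)[0],
        4: lambda x: x.sum(dim=1),
        5: lambda x: x.logsumexp(dim=1),
    }
    return reducers[choice](blocks)   # (N // step, C)

# usage: fixed_resize(torch.rand(2000, 6), choice=0) -> shape (200, 6)
```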