How to get multiple augmentations of the same graph in a single batch #3561
-
My question pertains to PyTorch Geometric. How do I get several different on-the-fly augmentations of my graph (which the library calls transforms) for the same element within a single batch? Is there a clean way to do this, or do I need to hack my way through by calling the transform function myself instead of going through the library, which might well result in a loss of performance?
-
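Mh, interesting problem. How do you solve this in regular PyTorch? One thing I could think of is to write a `WrapperDataset` that applies different transforms to a `data` object:

```python
import copy

import torch
from torch_geometric.loader import DataLoader


class WrapperDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform1=None, transform2=None):
        self.dataset = dataset
        self.transform1 = transform1
        self.transform2 = transform2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        # Shallow copies: attributes set by one transform do not leak into the
        # other view. Use copy.deepcopy if a transform edits tensors in place.
        data1, data2 = copy.copy(data), copy.copy(data)
        if self.transform1 is not None:
            data1 = self.transform1(data1)
        if self.transform2 is not None:
            data2 = self.transform2(data2)
        return data1, data2


dataset = WrapperDataset(dataset, transform1, transform2)
loader = DataLoader(dataset, ...)

# Each position of the returned tuple is collated separately (one Batch per view).
for batch1, batch2 in loader:
    print(batch1)
    print(batch2)
```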
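If you need more than two augmentations, the same idea generalizes to a list of transforms. Below is a minimal, framework-free sketch of that pattern, assuming toy `dict` "graphs" in place of PyG `Data` objects; the names `MultiViewDataset`, `collate_views`, `add_one` and `negate` are hypothetical and only for illustration. It also shows the collation step: a batch of per-item tuples is transposed into one list per view. As far as I know, PyG's `DataLoader` does this transposition itself when `__getitem__` returns a tuple and then merges each list into a `Batch`. Since the transforms run inside `__getitem__` (which is also where a PyG dataset applies its own `transform` argument), calling them manually like this should not by itself cost performance; they still execute in the `DataLoader` workers.

```python
import copy
from typing import Any, Callable, List, Sequence


class MultiViewDataset:
    """Hypothetical generalization of WrapperDataset: returns one
    independently transformed view per transform for each item."""

    def __init__(self, dataset: Sequence[Any],
                 transforms: Sequence[Callable[[Any], Any]]):
        self.dataset = dataset
        self.transforms = list(transforms)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> tuple:
        data = self.dataset[idx]
        # Each transform gets its own shallow copy, so reassigning an
        # attribute in one view does not affect the others or the source.
        return tuple(t(copy.copy(data)) for t in self.transforms)


def collate_views(batch: List[tuple]) -> List[list]:
    """Transpose [(v1, v2), (v1, v2), ...] into [[v1, v1, ...], [v2, v2, ...]].

    In PyG, the loader would additionally merge each list into a Batch.
    """
    return [list(views) for views in zip(*batch)]


# Toy "augmentations" (hypothetical): they reassign "x" instead of mutating it.
def add_one(g: dict) -> dict:
    g["x"] = [v + 1 for v in g["x"]]
    return g


def negate(g: dict) -> dict:
    g["x"] = [-v for v in g["x"]]
    return g


graphs = [{"x": [0, 1]}, {"x": [10]}]
ds = MultiViewDataset(graphs, [add_one, negate])

# Simulate one loader step with batch_size=2.
batch = [ds[i] for i in range(len(ds))]
view1, view2 = collate_views(batch)
print(view1)  # [{'x': [1, 2]}, {'x': [11]}]
print(view2)  # [{'x': [0, -1]}, {'x': [-10]}]
```

Keeping the views as separate batches (rather than interleaving them into one) makes it straightforward to pair view *k* of graph *i* with view *k'* of the same graph, e.g. for a contrastive loss.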