-
Hello, I have a question regarding iterable/generator datasets. I receive data on the fly, as a constant stream of small graphs, and I am not sure how to do this "the correct" way. Planned pipeline:
**Idea 1: PyTorch `IterableDataset`**

My current approach is to prepare the keys for `torch_geometric.data.Data` in a PyTorch `IterableDataset`:

```python
from torch.utils.data import IterableDataset, DataLoader  # standard versions, not the PyTorch Geometric loader
from itertools import islice  # for testing a limited number of streamed graphs
import torch_geometric.data as tg_data

class TraceStreamDataset(IterableDataset):
    # ...
    def __iter__(self):
        # connects to the gRPC stream
        # on receiving a new message:
        # build dict(x=..., edge_index=...)
        yield dict_with_graph_data_keys
```

Then I use the standard PyTorch loader:

```python
iterable_dataset = TraceStreamDataset(...)
loader = DataLoader(iterable_dataset, batch_size=None)

graph_data_list = []  # for saving received graphs (test)
for graph_dict in islice(loader, 1):  # islice ends iteration after one test graph (stop receiving from the stream while testing)
    graph = tg_data.Data(x=graph_dict['x'], edge_index=graph_dict['edge_index'])
    graph_data_list.append(graph)  # not batched yet
    # if I want to batch graphs, I would have to build batched lists here on the fly in the loop
```

This works so far. But my goal is to create a simple iterable graph loader that I can use in my training loop. The standard PyTorch iterable dataset/loader combination does not allow yielding `torch_geometric.data.Data` objects directly:

```
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torch_geometric.data.data.Data'>
```
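A minimal sketch of one possible workaround for this `default_collate` error, assuming `TraceStreamDataset` were changed to yield `Data` objects directly (this is not from the original post): keep the standard PyTorch loader but replace its collate function with PyG's `Batch.from_data_list`.

```python
# Sketch, under the assumption that iterable_dataset now yields
# torch_geometric.data.Data objects instead of dicts.
from torch.utils.data import DataLoader
from torch_geometric.data import Batch

# Batch.from_data_list receives the list of sampled Data objects and
# merges them into a single Batch, so default_collate is never called.
loader = DataLoader(iterable_dataset, batch_size=2,
                    collate_fn=Batch.from_data_list)
```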
**Idea 2: `torch_geometric.loader.DataLoader`**

In part 2 of the FAQ ("Do I really need to use these dataset interfaces?") I read about `torch_geometric.loader.DataLoader` for synthetic data created on the fly. So I also tried using it to create the mini-batches. But it only accepts a dataset with a length (which my streaming dataset of course does not have):

```python
from torch_geometric.loader import DataLoader as tg_dataloader

loader = DataLoader(iterable_dataset, batch_size=None)  # standard loader from Idea 1

def generate_graph_data(loader):
    for graph_dict in loader:
        graph = tg_data.Data(x=graph_dict['x'], edge_index=graph_dict['edge_index'])
        yield graph

graph_loader = tg_dataloader(dataset=generate_graph_data(loader), batch_size=2)

batched_graph_list = []
for graph_batch in graph_loader:
    batched_graph_list.append(graph_batch)
```

This fails with:

```
TypeError: object of type 'generator' has no len()
```
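For reference, the `len()` error comes from the bare generator: `torch.utils.data.DataLoader` (which the PyG loader builds on) only skips its length-based sampler when the dataset is an `IterableDataset` instance, and a generator is not one. A minimal sketch of a wrapper that would satisfy the loader, where `GeneratorWrapper` is a hypothetical helper name, not a PyG class:

```python
# Hypothetical wrapper: turns a generator function into an IterableDataset,
# so the DataLoader never asks for len(dataset).
from torch.utils.data import IterableDataset

class GeneratorWrapper(IterableDataset):
    def __init__(self, gen_fn, *args, **kwargs):
        self.gen_fn, self.args, self.kwargs = gen_fn, args, kwargs

    def __iter__(self):
        # a fresh generator per iteration
        return self.gen_fn(*self.args, **self.kwargs)

graph_loader = tg_dataloader(dataset=GeneratorWrapper(generate_graph_data, loader),
                             batch_size=2)
```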
**Idea 3:** One more idea I have:

**Question**

So my question is: how do I do this correctly? Perhaps there is another easy way I did not even see. Thanks in advance for your help. I am really impressed by this whole ecosystem you have built. :)
-
I think following the `IterableDataset` interface is the way to go, as this fits perfectly into your desired needs. I'm not sure, though, why you say that the `IterableDataset`/`DataLoader` approach does not work. IMO, this should work just fine given that you are utilizing the `torch_geometric.loader.DataLoader` rather than the `torch.utils.data.DataLoader`. Is this not the case? This also eliminates the need to manually batch `Data` objects after loading them. Idea 2 should work as well, but it feels like you are re-implementing a lot of functionality that is already present in `IterableDataset`.
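A minimal end-to-end sketch of this suggestion: the dataset yields `Data` objects directly, and `torch_geometric.loader.DataLoader` collates them into `Batch` objects. Here `fake_stream` is a purely hypothetical stand-in for the real gRPC stream, assumed to deliver messages with `x` and `edge_index` entries.

```python
import torch
from torch.utils.data import IterableDataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # PyG loader, not torch.utils.data

class TraceStreamDataset(IterableDataset):
    def __init__(self, stream):
        self.stream = stream

    def __iter__(self):
        for msg in self.stream:  # e.g. messages arriving from the gRPC stream
            # yield Data objects directly; the PyG loader knows how to collate them
            yield Data(x=msg['x'], edge_index=msg['edge_index'])

def fake_stream(n=4):  # hypothetical stand-in for the real stream
    for _ in range(n):
        yield {'x': torch.randn(3, 8),
               'edge_index': torch.tensor([[0, 1, 2], [1, 2, 0]])}

loader = DataLoader(TraceStreamDataset(fake_stream()), batch_size=2)
for batch in loader:  # each `batch` is a torch_geometric.data.Batch of 2 graphs
    print(batch)
```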