Instead of storing numerical node features in data.x as a PyTorch tensor, you should be able to assign each node a str attribute by passing a list of strings, one per node, to Data:

from itertools import chain

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

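# Store one string per node; num_nodes is set explicitly.
data = Data(sentence=['node1', 'node2', 'node3'], num_nodes=3)

loader = DataLoader([data, data, data], batch_size=3)

batch = next(iter(loader))
print(batch)
# The batched attribute is a list of lists, one inner list per graph.
print(batch.sentence)
# Flatten to one string per node across the whole batch.
print(list(chain.from_iterable(batch.sentence)))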

Output:

DataBatch(sentence=[3], num_nodes=9, batch=[9], ptr=[4])
[['node1', 'node2', 'node3'], ['node1', 'node2', 'node3'], ['node1', 'node2', 'node3']]
['node1', 'node2', 'node3', 'node1', 'node2', 'node3', 'node1', 'node2', 'node3']
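
If you also need to know which graph each flattened string came from, the following is a minimal sketch in plain Python. It assumes batch.sentence is a list of per-graph lists, as in the printed output above, and uses a stand-in list so it runs without torch_geometric. The helper name flatten_node_strings is hypothetical and not part of PyG.

```python
from itertools import chain


def flatten_node_strings(batched_sentences):
    """Flatten per-graph string lists into one per-node list.

    Hypothetical helper (not a PyG API). Also returns a list of graph
    indices, one per node, playing the same role as the `batch` vector.
    """
    flat = list(chain.from_iterable(batched_sentences))
    graph_index = [
        g for g, sentences in enumerate(batched_sentences) for _ in sentences
    ]
    return flat, graph_index


# Stand-in for `batch.sentence` from the example above (assumption).
batched = [['node1', 'node2', 'node3'] for _ in range(3)]

flat, graph_index = flatten_node_strings(batched)
print(flat)
print(graph_index)
```

With a real batch, the graph indices should line up with the batch vector, provided each inner list holds exactly one string per node in node order.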

Yo…

Answer selected by piyush-J