I'm building a simple GraphSAGE model. The purpose of the model is to create an embedding for each node in the graph and store the embedding in a database that is keyed on the name of the node (strings -- Alice, Bob, Charlene, Donald), the region, and the dttm. So I want to yield a pandas.DataFrame alongside the batch data to hold these keys. I'm finding it challenging to verify that I'm actually retrieving the names correctly. The names are a list, each element giving the name of a node. The attributes region and dttm are scalar values and are properties of the whole graph, not of individual nodes.
(Background: My project simply is not useful if I can't relate the region, dttm, and name data to the graph embeddings -- storing node indices is not enough. I won't know these metadata elements in advance of loading the graph object, so I can't simply make a single lookup table for a graph's indices and refer back to it later -- I have to dispatch my metadata with the minibatch data.)
My work so far: my metadata (region, dttm, names) is stored as attributes in each torch_geometric.data.Data object. Then I have to map backwards from the local indexing scheme used by the minibatch to the global indexing scheme used by the Data object. (I also only want the embeddings for the "positive" edges, which is why there's an if in the second list comprehension.)
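If I understand the sampled-batch attributes correctly, the inverse map is just an index lookup: batch.n_id holds the global index of each local node. A minimal sketch of that single step (n_id is the documented attribute; my_names is the name list from the example below):

# batch.n_id[i] is the global index of local node i in the sampled subgraph,
# so a node's name is recovered with one lookup into the graph-level list:
local_ndx = 0                              # any local index within the batch
global_ndx = batch.n_id[local_ndx].item()  # local -> global
name = my_names[global_ndx]                # global -> name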
This is my attempt at a minimal, reproducible example.
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.loader import LinkNeighborLoader
class MetaData(Data):
    """Data subclass that carries graph-level metadata (names, region, dttm)
    in private attributes, where the loader won't try to collate them."""

    def __init__(self, x, edge_index, y, domain2ndx, ndx2domain, region_dttm, **kwargs):
        super().__init__(
            x=x, edge_index=edge_index, y=y, edge_attr=None, pos=None, time=None
        )
        self._domain2ndx = domain2ndx    # name -> global index (dict)
        self._ndx2domain = ndx2domain    # global index -> name (list)
        self._region_dttm = region_dttm  # graph-level (region, dttm) pair

    @property
    def region_dttm(self):
        return self._region_dttm

    def get_domain(self, ndx):
        # global index -> name
        return self._ndx2domain[ndx]

    def get_ndx(self, domain):
        # name -> global index
        return self._domain2ndx[domain]
def metadata_iterator(queue):
    """Yield a (metadata DataFrame, batch) pair for each minibatch."""
    graph_names = queue.data._ndx2domain
    region, dttm = queue.data.region_dttm
    for batch in queue:
        # batch.n_id maps each local node index in the sampled subgraph back
        # to its global index, so local node i is named graph_names[batch.n_id[i]]
        names_batch = [graph_names[i.item()] for i in batch.n_id]
        # keep only the source endpoint of each positive (label == 1) link
        names_pos = [
            names_batch[batch.edge_label_index[0, i]]
            for i, j in enumerate(batch.edge_label)
            if int(j) == 1
        ]
        df = pd.DataFrame({"name": names_pos})
        df["region"] = region
        df["dttm"] = dttm
        yield df, batch
if __name__ == "__main__":
    # A simple path graph on five nodes
    e_index = torch.LongTensor(
        [
            [0, 1],
            [1, 0],
            [1, 2],
            [2, 1],
            [2, 3],
            [3, 2],
            [3, 4],
            [4, 3],
        ]
    )
    my_names = ["alice", "bob", "charlie", "donald", "erik"]
    my_names_dict = {name: ndx for ndx, name in enumerate(my_names)}
    y = torch.LongTensor([0, 1, 0, 1, 0])
    prng = np.random.default_rng(1234)
    features = torch.FloatTensor(prng.normal(size=(len(my_names), 3)))
    data = MetaData(
        x=features,
        edge_index=e_index.t().contiguous(),
        y=y,
        region_dttm=("XYZ", "2025030100"),
        ndx2domain=my_names,       # index -> name
        domain2ndx=my_names_dict,  # name -> index
    )
    region, dttm = data.region_dttm
    print(data)
    loader = LinkNeighborLoader(
        data, batch_size=1, num_neighbors=[2, 2], neg_sampling_ratio=1.0, shuffle=False
    )
    print(loader.__dict__)
    print(loader.data)
    for batch_df, batch in metadata_iterator(loader):
        print(batch_df)
The purpose of MetaData is to allow me to store my metadata in a Data object without raising errors on iteration. If we just shove the list my_names into the Data object directly, loader iteration fails (sketched below for reference). And torch.Tensor can't store strings, so converting the names to a tensor attribute isn't an option either.
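For reference, a sketch of the failing variant just described -- a plain Data object with the raw string list attached directly (the exact exception may vary by version; the guard just reports it):

bad = Data(
    x=features,
    edge_index=e_index.t().contiguous(),
    y=y,
    names=my_names,  # list of str attached straight onto the Data object
)
bad_loader = LinkNeighborLoader(
    bad, batch_size=1, num_neighbors=[2, 2], neg_sampling_ratio=1.0
)
try:
    next(iter(bad_loader))  # fails on iteration, as described above
except Exception as exc:
    print(f"iteration failed with {type(exc).__name__}")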
This code appears to work, in the sense that I don't get any IndexErrors. But I could have a nasty semantic error lurking in here if I've misunderstood how these indices work. (Some names are repeated in the output; that's because their indices appear more than once in batch.edge_label_index -- I assume this is simply fixed by swapping in NeighborLoader for LinkNeighborLoader at inference time, i.e. I'll need to use different loaders for training and testing, which is fine. The main thing is that I want to know how to make sure that I'm getting the metadata to match the model outputs -- see the sanity check sketched below.)
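Here is the sanity check I've come up with, appended to the example above. It leans on what I believe is the loader's invariant -- node-level attributes of a sampled batch are the global attributes re-indexed by batch.n_id -- so if these assertions hold, the emitted names line up with the rows of batch.x:

for batch_df, batch in metadata_iterator(loader):
    # Invariant: the batch's features are the global features re-indexed by n_id.
    assert torch.equal(batch.x, data.x[batch.n_id])
    # Cross-check the emitted names against a direct global lookup for the
    # source endpoints of the positive links.
    pos = batch.edge_label == 1
    src_global = batch.n_id[batch.edge_label_index[0, pos]]
    assert batch_df["name"].tolist() == [my_names[g] for g in src_global.tolist()]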
My question: Is this the right way to retrieve metadata information for a minibatch? And is there a simpler way or built-in method that I'm missing? I feel like I must be missing something obvious and am at risk of reinventing the wheel here.
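One built-in hook I've found that might replace the wrapper generator: the loader accepts a transform callable that is applied to every sampled batch, so the name lookup could live there instead (a sketch; I haven't verified this is the intended pattern):

def attach_names(batch):
    # Runs after sampling; annotate the batch with the recovered global names.
    batch.names = [my_names[g] for g in batch.n_id.tolist()]
    return batch

loader = LinkNeighborLoader(
    data,
    batch_size=1,
    num_neighbors=[2, 2],
    neg_sampling_ratio=1.0,
    shuffle=False,
    transform=attach_names,  # documented post-sampling hook
)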
I'm using torch-geometric==2.6.1 because it's the latest stable release and means I don't have to build from the main branch.
PS -- I'm very pleased with PyG so far. It's working very well for me overall, and has been easy for me (a long-time PyTorch user) to get underway with the PyG extensions to PyTorch.