How to get the main graph index of each batch_data #10128
streetcorner asked this question in Q&A
Hello, after using ClusterLoader to partition my graph data, I want to add an attribute to each batch_data for model training. How can I find out, for each node in a batch_data, its index in the entire (original) graph?
My torch version is 1.12.1, and there is no n_id or node_idx among the batch_data attributes.
My code example is as follows:
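from torch_geometric.loader import ClusterData, ClusterLoader

class CustomClusterLoader(ClusterLoader):
    def __init__(self, data, total_attr, batch_size, shuffle=True):
        # Partition the full graph into 100 clusters before batching.
        cluster_data = ClusterData(data, num_parts=100)
        super().__init__(cluster_data, batch_size=batch_size, shuffle=shuffle)
        self.total_attr = total_attr

    def __iter__(self):
        # Inspect each mini-batch produced by the parent ClusterLoader.
        for batch_data in super().__iter__():
            print(f"batch_data keys: {batch_data.keys()}")
            print(f"batch_data attributes: {dir(batch_data)}")
            yield batch_data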
curent_attr is the information that needs to be extracted from the main graph based on the node index, so I need to know the index in the main graph of each node in batch_data.
Thanks!
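A commonly used workaround, sketched here under assumptions rather than confirmed for your exact PyG version: store each node's global index as an ordinary node-level attribute before partitioning, e.g. `data.orig_idx = torch.arange(data.num_nodes)` (`orig_idx` is a hypothetical name, chosen so it does not clash with anything PyG sets itself). As far as I know, ClusterData reorders every tensor whose first dimension equals `num_nodes`, and ClusterLoader slices those tensors per cluster, so each `batch_data.orig_idx` should then hold the original indices. Inside the loop you could then do something like `batch_data.curent_attr = self.total_attr[batch_data.orig_idx]`, assuming `total_attr` is indexed by node along its first dimension. The stdlib-only model below does not use PyG at all; it imitates that permute-then-slice behaviour with hypothetical helpers `cluster` and `load_batch` to show why the carried index stays aligned with the features.

```python
# Stdlib-only model of how a node-level index attribute survives clustering.
# Assumption: clustering reorders every node-level attribute by a permutation
# `perm` (nodes of one part become contiguous) with part boundaries in
# `partptr`, and the loader concatenates the slices of the sampled parts.
# `cluster` and `load_batch` are hypothetical stand-ins, NOT the PyG API.
from typing import Dict, List, Sequence


def cluster(node_attrs: Dict[str, list], perm: Sequence[int]) -> Dict[str, list]:
    """Reorder every node-level attribute by `perm` (like ClusterData)."""
    return {key: [values[p] for p in perm] for key, values in node_attrs.items()}


def load_batch(permuted: Dict[str, list], partptr: Sequence[int],
               parts: Sequence[int]) -> Dict[str, List]:
    """Concatenate the node slices of the chosen parts (like ClusterLoader)."""
    batch: Dict[str, List] = {key: [] for key in permuted}
    for part in parts:
        start, end = partptr[part], partptr[part + 1]
        for key, values in permuted.items():
            batch[key].extend(values[start:end])
    return batch


# Toy graph with 6 nodes; `total_attr` lives outside the graph object.
x = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
total_attr = ["a", "b", "c", "d", "e", "f"]

# The trick: store each node's global index as a node attribute *before*
# clustering, so it is permuted and sliced together with `x`.
node_attrs = {"x": x, "orig_idx": list(range(len(x)))}

perm = [4, 1, 5, 0, 3, 2]  # hypothetical partition ordering
partptr = [0, 2, 4, 6]     # three parts of two nodes each
permuted = cluster(node_attrs, perm)

# A shuffled mini-batch made of parts 2 and 0.
batch = load_batch(permuted, partptr, parts=[2, 0])

# Look up the external per-node information by global index.
curent_attr = [total_attr[i] for i in batch["orig_idx"]]
print(batch["orig_idx"], curent_attr)
```

Running it prints `[3, 2, 4, 1] ['d', 'c', 'e', 'b']`: the batch contains parts 2 and 0 in shuffled order, yet each `orig_idx` entry still points at the matching row of `total_attr` and of the original `x`.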