Why is my collate_fn not executed? #9568
Unanswered
liangaomng asked this question in Q&A
Replies: 1 comment
-
I believe this behavior is expected since the …

By the way, the following code can become very slow for large `n`:

```python
%%time
n=10000
edge_index = torch.tensor([[i, j] for i in range(n) for j in range(n) if i != j], dtype=torch.long).t()
```

You might want to try this more efficient approach:

```python
from torch_geometric.utils import to_undirected, remove_self_loops
```

```python
%%time
n=10000
edge_index = torch.combinations(torch.arange(n)).T
edge_index = to_undirected(edge_index)
# edge_index, _ = remove_self_loops(edge_index)
```

Hope that helps!
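If you would rather not depend on `torch_geometric.utils`, here is a minimal sketch that stays in plain PyTorch (the helper name `fully_connected_edge_index` is hypothetical). It builds every ordered pair with `torch.meshgrid` and masks out the diagonal, so there is no Python-level double loop. It emits the pairs in the same row-major order as the list comprehension, which lets you check the two against each other on a small `n`.

```python
import torch

def fully_connected_edge_index(n: int) -> torch.Tensor:
    # Hypothetical helper: all ordered pairs (i, j) with 0 <= i, j < n, laid out row-major.
    row, col = torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij")
    # Drop self-loops (i == j); boolean indexing flattens in row-major order.
    mask = row != col
    return torch.stack([row[mask], col[mask]], dim=0)

# Cross-check against the original list-comprehension version on a small graph.
n = 5
reference = torch.tensor(
    [[i, j] for i in range(n) for j in range(n) if i != j], dtype=torch.long
).t()
edge_index = fully_connected_edge_index(n)
print(torch.equal(edge_index, reference))
```

Memory still grows quadratically in `n` with this approach. That cost is unavoidable, because a fully connected graph has n(n-1) directed edges. Only the slow Python loop is removed.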
-
```python
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader  # older, deprecated alias: torch_geometric.data.DataLoader

# Define a function that creates a graph
def create_graph(n, feature_scale):
    # Create node features scaled by a feature_scale factor
    x = torch.tensor([[i] for i in range(1, n + 1)], dtype=torch.float) * feature_scale
    # Create edges (fully connected graph minus self-loops)
    edge_index = torch.tensor([[i, j] for i in range(n) for j in range(n) if i != j], dtype=torch.long).t()
    return Data(x=x, edge_index=edge_index)

# Custom collate function for batching graphs
def cus_fn(batch):
    print("Executing cus_fn")
    return Batch.from_data_list(batch)

# Create graph datasets
graph1 = create_graph(3, 1)  # 3 nodes with features [1, 2, 3]
graph2 = create_graph(2, 2)  # 2 nodes with features [2, 4]

# Dataset list
dataset = [graph1, graph2]

# DataLoader with a batch size of 2, using the custom collate function
loader = DataLoader(dataset, batch_size=2, collate_fn=cus_fn)

# Iterate through the DataLoader and print batch information
for batch in loader:
    print("Batch node features:", batch.x)
    print("Batch edge indices:", batch.edge_index)
    print("Batch mapping to original graphs:", batch.batch)
```
I expected the collate_fn to print "Executing cus_fn", but that message never appeared. Why is it not being called?
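Here is a minimal sketch of a workaround. It assumes that PyG's `DataLoader` ignores a user-supplied `collate_fn`. Recent PyG releases appear to discard that keyword and always install their own collater, but please confirm this against the source of your installed version. Plain `torch.utils.data.DataLoader` does honor `collate_fn`. To keep the example runnable without PyG, the graphs are stored as plain dicts, and the hypothetical `collate_graphs` stands in for `Batch.from_data_list`. It concatenates node features, shifts each graph's `edge_index` by the number of nodes that precede it, and builds the `batch` vector.

```python
import torch
from torch.utils.data import DataLoader

calls = []  # records each invocation so we can confirm the collate function ran

def make_graph(n, feature_scale):
    # Same graphs as create_graph, but as a plain dict instead of a PyG Data object
    x = torch.arange(1, n + 1, dtype=torch.float).unsqueeze(1) * feature_scale
    edge_index = torch.tensor(
        [[i, j] for i in range(n) for j in range(n) if i != j], dtype=torch.long
    ).t()
    return {"x": x, "edge_index": edge_index}

def collate_graphs(graphs):
    # Hypothetical stand-in for Batch.from_data_list
    print("Executing collate_graphs")
    calls.append(len(graphs))
    xs, edges, batch_vec = [], [], []
    offset = 0
    for idx, g in enumerate(graphs):
        num_nodes = g["x"].size(0)
        xs.append(g["x"])
        edges.append(g["edge_index"] + offset)  # renumber nodes into the batched graph
        batch_vec.append(torch.full((num_nodes,), idx, dtype=torch.long))
        offset += num_nodes
    return {
        "x": torch.cat(xs, dim=0),
        "edge_index": torch.cat(edges, dim=1),
        "batch": torch.cat(batch_vec),
    }

dataset = [make_graph(3, 1), make_graph(2, 2)]
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_graphs)

for batch in loader:
    print("Batch node features:", batch["x"].squeeze(1))
    print("Batch edge indices:", batch["edge_index"])
    print("Batch mapping to original graphs:", batch["batch"])
```

In the original snippet, importing `DataLoader` from `torch.utils.data` instead and keeping `collate_fn=cus_fn` should likewise make "Executing cus_fn" appear. This works because `cus_fn` already returns `Batch.from_data_list(batch)`.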