Passing the new argument in LinkNeighborLoader #8840
-
Hi, always thank you for your effort. I made a new custom dataset that contains a # Please ignore the generate_network function. It just returns the Data object.
GraphList = []
for k in range(4):
GraphList.append(generate_network(pos_set, neg_set, torch.tensor(NodeFeature[k].values)))
GraphList[k].edge_class = torch.full([len(GraphList[k].edge_label)], k)
GraphList
#[Data(edge_index=[2, 1108], edge_label=[1108], x=[19392, 4], edge_class=[1108]),
# Data(edge_index=[2, 1237], edge_label=[1237], x=[19392, 4], edge_class=[1237]),
# Data(edge_index=[2, 3826], edge_label=[3826], x=[19392, 4], edge_class=[3826]),
# Data(edge_index=[2, 7001], edge_label=[7001], x=[19392, 4], edge_class=[7001])] Then, I tried to utilize the # DataList: custom Dataset from the GraphList
for fold, (train_idx, val_idx, test_idx) in enumerate(zip(*k_fold(DataList, folds))):
print(f'FOLD {fold}')
print('-------------------------------------------')
kf_train_data = Data(edge_index = BaseNetwork.edge_index,
edge_label = DataList.edge_label[train_idx],
edge_label_index = DataList.edge_index[:, train_idx],
edge_class = DataList.edge_class[train_idx],
...)
# same process on the validation and test datasets
print(kf_train_data)
#FOLD 0
#-------------------------------------------
#Data(edge_index=[2, 7854705], edge_label=[7903], edge_label_index=[2, 7903], edge_class=[7903], num_nodes=107940)
#FOLD 1
#-------------------------------------------
#Data(edge_index=[2, 7854705], edge_label=[7902], edge_label_index=[2, 7902], edge_class=[7902], num_nodes=107940)
#FOLD 2
#-------------------------------------------
#Data(edge_index=[2, 7854705], edge_label=[7903], edge_label_index=[2, 7903], edge_class=[7903], num_nodes=107940)
#FOLD 3
#-------------------------------------------
#Data(edge_index=[2, 7854705], edge_label=[7904], edge_label_index=[2, 7904], edge_class=[7904], num_nodes=107940)
#FOLD 4
#-------------------------------------------
#Data(edge_index=[2, 7854705], edge_label=[7904], edge_label_index=[2, 7904], edge_class=[7904], num_nodes=107940) ...
train_loader = LinkNeighborLoader(kf_train_data, edge_label_index=kf_train_data.edge_label_index,
batch_size=32, shuffle=True, neg_sampling_ratio=1.0, num_neighbors=[2,2])
next(iter(train_loader))
#Data(edge_index=[2, 446], edge_label=[64], edge_label_index=[2, 64], edge_class=[7904], num_nodes=506, n_id=[506], e_id=[446], num_sampled_nodes=[3], num_sampled_edges=[2], input_id=[32]) Of course, it is the expected result that But when I tried to put Is there any solution for this case or making the custom |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 3 replies
-
For this job, I tried to mimic the original from typing import Callable, Dict, List, Optional, Tuple, Union
from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.link_loader import LinkLoader
from torch_geometric.sampler import NegativeSampling, NeighborSampler
from torch_geometric.sampler.base import SubgraphType
from torch_geometric.typing import EdgeType, InputEdges, OptTensor
class CustomLinkNeighborLoader(LinkNeighborLoader):
def __init__(
self,
data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
edge_label_index: InputEdges = None,
edge_label: OptTensor = None,
edge_label_time: OptTensor = None,
edge_class: OptTensor = None,
replace: bool = False,
subgraph_type: Union[SubgraphType, str] = 'directional',
disjoint: bool = False,
temporal_strategy: str = 'uniform',
neg_sampling: Optional[NegativeSampling] = None,
neg_sampling_ratio: Optional[Union[int, float]] = None,
time_attr: Optional[str] = None,
weight_attr: Optional[str] = None,
transform: Optional[Callable] = None,
transform_sampler_output: Optional[Callable] = None,
is_sorted: bool = False,
filter_per_worker: Optional[bool] = None,
neighbor_sampler: Optional[NeighborSampler] = None,
directed: bool = True, # Deprecated.
**kwargs,
):
self.edge_class = edge_class
super().__init__(
data=data,
num_neighbors=num_neighbors,
edge_label_index=edge_label_index,
edge_label=edge_label,
edge_label_time=edge_label_time,
replace=replace,
subgraph_type=subgraph_type,
disjoint=disjoint,
temporal_strategy=temporal_strategy,
neg_sampling=neg_sampling,
neg_sampling_ratio=neg_sampling_ratio,
time_attr=time_attr,
weight_attr=weight_attr,
transform=transform,
transform_sampler_output=transform_sampler_output,
is_sorted=is_sorted,
filter_per_worker=filter_per_worker,
neighbor_sampler=neighbor_sampler,
directed=directed,
**kwargs,
)
def custom_collate(self, batch):
collated_data = super(CustomLinkNeighborLoader, self).custom_collate(batch)
collated_data.edge_class = torch.cat([data.edge_class for data in batch]) if self.edge_class is not None else None
return collated_data
for fold, (train_idx, val_idx, test_idx) in enumerate(zip(*k_fold(DataList, folds))):
...
train_loader = CustomLinkNeighborLoader(kf_train_data, edge_label_index=kf_train_data.edge_label_index, edge_class=kf_train_data.edge_class,
batch_size=32, shuffle=True, neg_sampling_ratio=1.0, num_neighbors=[2,2]) But it still returned the same result as the initial question.
I guess that I missed some parts in the |
Beta Was this translation helpful? Give feedback.
-
You can use
|
Beta Was this translation helpful? Give feedback.
You can use
input_id
to filteredge_class
later: