I see. The problem is that the Batch class does not have access to the PairData.__inc__ method, and I currently do not see a way to allow that. I think the most elegant fix is to convert your Batch back to a PairData object:

from torch_geometric.data import Data, DataLoader, Batch
import torch
from copy import deepcopy


class PairData(Data):
    def __init__(self, **kwargs):  # We allow arbitrary arguments
        super(PairData, self).__init__(**kwargs)

    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
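        else:
            # Defer to the default increment for every other attribute.
            return super(PairData, self).__inc__(key, value)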
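
For context, `__inc__` tells the batching logic how far to shift an attribute's values for each successive graph, which is why `edge_index_s` is incremented by the number of source nodes rather than by a single shared node count. The following is only a rough sketch of that idea in plain Python lists, with no torch_geometric dependency; `batch_edge_indices` is a hypothetical helper written for illustration and is not how the library implements collation internally.

```python
# Hypothetical helper, not part of torch_geometric: mimics how per-graph
# increments accumulate when edge indices are concatenated into one batch.
def batch_edge_indices(edge_indices, num_nodes):
    """Concatenate edge lists, shifting each graph by the nodes seen so far.

    edge_indices: list of graphs, each a pair [sources, targets] of int lists.
    num_nodes: node count per graph; plays the role of the __inc__ value.
    """
    batched = [[], []]
    offset = 0
    for (src, dst), n in zip(edge_indices, num_nodes):
        batched[0].extend(s + offset for s in src)
        batched[1].extend(d + offset for d in dst)
        offset += n  # the next graph's indices start after this graph's nodes
    return batched


# Two source graphs with 3 and 2 nodes respectively.
edges_s = [[[0, 1], [1, 2]], [[0], [1]]]
print(batch_edge_indices(edges_s, num_nodes=[3, 2]))
# [[0, 1, 3], [1, 2, 4]]
```

If the increment is lost, as happens when the object doing the batching cannot reach PairData.__inc__, the second graph's edges would keep pointing at indices 0 and 1 and so collide with the first graph's nodes.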

Answer selected by yangysc