-
I want to process a batch of graphs, where the graphs are stacked in one big block-diagonal matrix following the documentation: https://pytorch-geometric.readthedocs.io/en/latest/advanced/batching.html The differences are:
To illustrate the idea better, here is a figure: The final goal is to feed all logits at once in block-diagonal matrix form and compare them with the labels using binary cross-entropy loss. Could you help me with the `predict()` function?

```python
class GraphModel(pl.LightningModule):
    def __init__(self, **kwargs):
        ...
        # Some GNN layers like GAT
        self.encoder = ...
        self.predictor = torch.nn.Sequential(
            Linear(latent_size * 2, decoder_size),
            ReLU(),
            BatchNorm1d(decoder_size),
            Linear(decoder_size, 1),
            ReLU())

    def encode(self, x, edge_attr, edge_index):
        return self.encoder(x, edge_index, edge_attr)

    def predict(self, z, batch_index):
        # TODO: compute pair-wise node embeddings and stack the matrices
        # z: latent vectors of all nodes in the batch
        # batch_index: assignment of each node to its graph
        ...
```
-
I personally think the easiest way to do this would be to stack the labels in a separate batch dimension and utilize padding (e.g., via `to_dense_batch`). I think you need to do this anyway inside the model, so there is no good reason not to do it already during batch creation. Then, you can directly compare `z` and `y` and compute the BCE loss on top, e.g., via: