Hi, I would like to use text features as node features for a GNN. I need to compute all node embeddings in batches in order to get all candidate embeddings for a downstream task. This is similar to #2765, but I need to make the language model trainable (which makes it challenging, since I can't load all the text into BERT at once). What would be a reasonable approach to this? This is my pseudocode:
```python
# in a forward call
def forward():
    avg_text_emb = []
    # can't load all text into BERT at once, so encode in batches of 8
    for text in all_text:
        # BERT is unfrozen (trainable)
        token_emb = bert(text).last_hidden_state  # (8, seq_len, hidden)
        # average the token embeddings into one embedding per text
        avg_text_emb.append(token_emb.mean(dim=1))  # (8, hidden)
    # all candidate text embeddings
    # (I have 20K nodes and 200K edges)
    all_avg_text_emb = torch.cat(avg_text_emb, dim=0)  # (num_nodes, hidden)
    # all candidate embeddings from the GNN for a downstream task
    all_node_features = gnn(all_avg_text_emb, edge_idx)
    # downstream task
    score = some_task(all_node_features, node_features)
    return score
```

Any suggestion would be helpful. Thanks!
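One caveat with the loop above: because BERT stays trainable, autograd keeps the activations of every per-batch forward pass alive until the backward pass, so memory grows with the number of text batches rather than the batch size. As a hedged, minimal mitigation (this uses Hugging Face's standard gradient-checkpointing switch; the checkpoint name is a placeholder):

```python
from transformers import AutoModel

bert = AutoModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint
# recompute intermediate activations during backward instead of storing them,
# trading extra compute for a much smaller per-batch memory footprint
bert.gradient_checkpointing_enable()
```

This shrinks the stored activations per pass considerably, but some state for every batch is still retained until backward, so it only pushes the limit rather than removing it.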
I think your code looks reasonable, but I don't think it will be possible to train a large language model end-to-end in combination with a GNN on a graph of around 20K nodes. Alternatives include
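Purely as an illustration of one such alternative (not spelled out in the reply): sample a small subgraph per training step with PyTorch Geometric's NeighborLoader and run only the sampled nodes' texts through the trainable BERT, so gradients still flow end-to-end but never for all 20K nodes at once. A minimal sketch, assuming PyG and Hugging Face transformers (the toy graph, checkpoint name, and layer sizes are placeholders):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased")  # stays trainable
gnn = SAGEConv(768, 256)

# toy graph: node_texts[i] is the raw text attached to node i
num_nodes = 100
node_texts = [f"text of node {i}" for i in range(num_nodes)]
data = Data(edge_index=torch.randint(0, num_nodes, (2, 400)),
            num_nodes=num_nodes)

# one-hop neighbor sampling around 8 seed nodes per step
loader = NeighborLoader(data, num_neighbors=[10], batch_size=8, shuffle=True)

for batch in loader:
    # batch.n_id maps the sampled subgraph's local nodes to global indices,
    # so only these texts are encoded (with gradients) on this step
    texts = [node_texts[i] for i in batch.n_id.tolist()]
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    text_emb = bert(**enc).last_hidden_state.mean(dim=1)  # (n_sampled, 768)
    node_emb = gnn(text_emb, batch.edge_index)
    # the first batch.batch_size rows are the seed nodes; compute the
    # downstream loss on these and backprop through both GNN and BERT
    seed_emb = node_emb[:batch.batch_size]
```

If the downstream task really needs all 20K candidate embeddings at scoring time, they can be refreshed periodically with a `torch.no_grad()` pass over all texts, while the gradient updates come only from the sampled mini-batches.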