Skip to content

Commit 2b69608

Browse files
anna-grimanna-grim
andauthored
Refactor split inference (#648)
* refactor: updated split inference pipeline * messy and in progress * refactor: split inference almost working * refactor: infernce is working * refactor: split inference works * refactor: updated saving preds * bug: empty branch profiles * remove test block * updated empty edge index handling * doc * debugging gnn * refactor: split dataset creation * doc: split datasets * updated split train * bug: subgraph sampling * bug: subgraph sampling * debugging * debugging printouts * bug: debugging --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent 4dae5dc commit 2b69608

File tree

1 file changed

+9
-17
lines changed

1 file changed

+9
-17
lines changed

src/neuron_proofreader/machine_learning/gnn_models.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -80,39 +80,31 @@ def forward(self, input_dict):
8080
x_img = input_dict["img"]
8181
edge_index_dict = input_dict["edge_index_dict"]
8282

83+
before = list()
84+
for key, x in x_dict.items():
85+
before.append(f"{key}: {x.size()}")
86+
8387
# Node embeddings
8488
x_img = self.patch_embedding(x_img)
8589
for key, f in self.node_embedding.items():
8690
x_dict[key] = f(x_dict[key])
87-
88-
if x_img.dim() != 2:
89-
print("gnn_models - 89:", x_img.shape)
90-
x_img = x_img.unsqueeze(0)
91-
if x_dict["proposal"].dim() != 2:
92-
print("gnn_models - 92:", x_dict["proposal"].shape)
93-
x_dict["proposal"] = x_dict["proposal"].unsqueeze(0)
94-
9591
x_dict["proposal"] = torch.cat((x_dict["proposal"], x_img), dim=1)
9692

9793
# Message passing
9894
try:
9995
x_dict = self.gat1(x_dict, edge_index_dict)
10096
x_dict = self.gat2(x_dict, edge_index_dict)
10197
except:
102-
print("x_dict:", x_dict)
103-
print("edge_index_dict:", edge_index_dict)
98+
print("Before...")
99+
print("\n".join(before))
100+
print("After...")
101+
for key, x in x_dict.items():
102+
print(key, x.size())
104103
stop
105104
return self.output(x_dict["proposal"])
106105

107106

108107
# --- Helpers ---
109-
def _filter_empty(edge_index_dict):
110-
return {
111-
k: v for k, v in edge_index_dict.items()
112-
if v.numel() > 0
113-
}
114-
115-
116108
def init_gat_same(hidden_dim, edge_dim, heads):
117109
gat = nn_geometric.GATv2Conv(
118110
-1, hidden_dim, dropout=0.1, edge_dim=edge_dim, heads=heads

0 commit comments

Comments
 (0)