Skip to content

Commit da0a3e6

Browse files
anna-grimanna-grim
andauthored
bug: batch size (#656)
* bug: batch size * bug: batch size --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent 22d866e commit da0a3e6

File tree

4 files changed

+6
-20
lines changed

4 files changed

+6
-20
lines changed

src/neuron_proofreader/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,5 @@ def save(self, dir_path):
179179
dir_path : str
180180
Path to directory to save JSON file.
181181
"""
182-
183182
self.graph.save(os.path.join(dir_path, "metadata_graph.json"))
184183
self.ml.save(os.path.join(dir_path, "metadata_ml.json"))

src/neuron_proofreader/machine_learning/gnn_models.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,8 @@ def forward(self, input_dict):
9191
x_dict["proposal"] = torch.cat((x_dict["proposal"], x_img), dim=1)
9292

9393
# Message passing
94-
try:
95-
x_dict = self.gat1(x_dict, edge_index_dict)
96-
x_dict = self.gat2(x_dict, edge_index_dict)
97-
except:
98-
print("Before...")
99-
print("\n".join(before))
100-
print("After...")
101-
for key, x in x_dict.items():
102-
print(key, x.size())
103-
stop
94+
x_dict = self.gat1(x_dict, edge_index_dict)
95+
x_dict = self.gat2(x_dict, edge_index_dict)
10496
return self.output(x_dict["proposal"])
10597

10698

src/neuron_proofreader/machine_learning/subgraph_sampler.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def __iter__(self):
107107
self.populate_via_bfs(subgraph, root)
108108

109109
# Yield batch
110-
self.populate_attributes(subgraph)
111110
yield subgraph
112111

113112
def populate_via_bfs(self, subgraph, root):
@@ -178,10 +177,6 @@ def visit_flagged_proposal(self, subgraph, queue, visited, proposal):
178177
if not (v in visited and v in nodes_added):
179178
queue.append((v, 0))
180179

181-
def populate_attributes(self, subgraph):
182-
# TO DO
183-
pass
184-
185180
# --- Helpers ---
186181
def init_subgraph(self):
187182
"""
@@ -204,10 +199,10 @@ def is_subgraph_full(self, subgraph):
204199

205200
class SeededSubgraphSampler(SubgraphSampler):
206201

207-
def __init__(self, graph, max_proposals=200, gnn_depth=2):
202+
def __init__(self, graph, gnn_depth=2, max_proposals=64):
208203
# Call parent class
209204
super(SeededSubgraphSampler, self).__init__(
210-
graph, max_proposals, gnn_depth
205+
graph, gnn_depth, max_proposals
211206
)
212207

213208
# --- Batch Generation ---

src/neuron_proofreader/machine_learning/vision_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
"""
1111

12-
#from neurobase.finetune import finetune_model
12+
# from neurobase.finetune import finetune_model
1313
from einops import rearrange
1414

1515
import torch
@@ -147,7 +147,7 @@ def __init__(self, checkpoint_path, model_config):
147147

148148
# Instance attributes
149149
self.encoder = full_model.encoder
150-
self.output = ml_util.init_feedforward(384, 1, 2)
150+
self.output = FeedForwardNet(384, 1, 3)
151151

152152
def forward(self, x):
153153
latent0 = self.encoder(x[:, 0:1, ...])

0 commit comments

Comments
 (0)