Skip to content

Commit de66b05

Browse files
anna-grim and anna-grim authored
feat: added mae3d (#652)
Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent 251e813 commit de66b05

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

src/neuron_proofreader/machine_learning/vision_models.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
@email: anna.grim@alleninstitute.org
66
77
Code for vision models that perform image classification tasks within
8-
NeuronProofreading pipelines.
8+
NeuronProofreader pipelines.
99
1010
"""
1111

12+
#from neurobase.finetune import finetune_model
1213
from einops import rearrange
1314

1415
import torch
@@ -130,6 +131,36 @@ def forward(self, x):
130131

131132

132133
# --- Transformers ---
134+
class MAE3D(nn.Module):
    """
    Binary classifier built on top of a pretrained 3D masked-autoencoder
    (MAE) encoder.

    The input batch is expected to carry two image channels. Each channel is
    encoded independently by the shared (frozen) encoder; the first latent
    token of each encoding is extracted, the two feature vectors are
    concatenated, and a feedforward head produces the classification output.
    """

    def __init__(self, checkpoint_path, model_config):
        """
        Builds the model by loading a pretrained encoder from a checkpoint.

        Parameters
        ----------
        checkpoint_path : str
            Path to the pretrained MAE checkpoint.
        model_config : dict
            Configuration passed through to the model loader.
        """
        # Initialize the parent nn.Module machinery.
        super().__init__()

        # NOTE(review): "finetune_model" comes from a commented-out import at
        # the top of this file ("from neurobase.finetune import
        # finetune_model"); as committed this raises NameError at runtime —
        # confirm the import should be re-enabled.
        pretrained = finetune_model(
            checkpoint_path=checkpoint_path,
            model_config=model_config,
            task_head_config="binary_classifier",
            freeze_encoder=True
        )

        # Keep only the encoder; a fresh feedforward head replaces the task
        # head supplied by the loader.
        # NOTE(review): forward() concatenates two per-channel latent vectors
        # before this head — verify that 384 matches the concatenated width
        # expected by ml_util.init_feedforward.
        self.encoder = pretrained.encoder
        self.output = ml_util.init_feedforward(384, 1, 2)

    def forward(self, x):
        """
        Runs a forward pass over a two-channel batch.

        Parameters
        ----------
        x : torch.Tensor
            Input with two channels on axis 1 — assumes shape
            (batch, 2, D, H, W); TODO confirm against callers.

        Returns
        -------
        torch.Tensor
            Output of the feedforward head.
        """
        # Encode each channel separately, keeping the first latent token
        # (presumably a CLS-style summary — confirm) from each encoding.
        features = [
            self.encoder(x[:, i:i + 1, ...])["latents"][:, 0, :]
            for i in range(2)
        ]

        # Fuse the per-channel features and classify.
        return self.output(torch.cat(features, dim=1))
162+
163+
133164
class ViT3D(nn.Module):
134165
"""
135166
A class that implements a 3D Vision transformer.

0 commit comments

Comments
 (0)