|
5 | 5 | @email: anna.grim@alleninstitute.org |
6 | 6 |
|
7 | 7 | Code for vision models that perform image classification tasks within |
8 | | -NeuronProofreading pipelines. |
| 8 | +NeuronProofreader pipelines. |
9 | 9 |
|
10 | 10 | """ |
11 | 11 |
|
from einops import rearrange
from neurobase.finetune import finetune_model

import torch
@@ -130,6 +131,36 @@ def forward(self, x): |
130 | 131 |
|
131 | 132 |
|
132 | 133 | # --- Transformers --- |
class MAE3D(nn.Module):
    """
    Binary classifier built on a pretrained 3D masked-autoencoder (MAE)
    encoder.

    Each input sample carries two image channels; both are embedded with
    the same frozen, fine-tuned encoder, and the first latent token of
    each embedding is concatenated and scored by a feedforward head.
    """

    def __init__(self, checkpoint_path, model_config):
        """
        Instantiates the model by loading a pretrained encoder and
        attaching a fresh feedforward output head.

        Parameters
        ----------
        checkpoint_path : str
            Path to the pretrained MAE checkpoint to load.
        model_config : dict
            Configuration used to build the model from the checkpoint.
        """
        # Call parent class
        super().__init__()

        # Load pretrained model; the encoder weights stay frozen.
        full_model = finetune_model(
            checkpoint_path=checkpoint_path,
            model_config=model_config,
            task_head_config="binary_classifier",
            freeze_encoder=True,
        )

        # Instance attributes -- keep only the encoder and replace the
        # loaded task head with a new feedforward output layer.
        self.encoder = full_model.encoder
        # NOTE(review): forward() feeds the concatenation of two 384-dim
        # tokens (768 features) into this head -- confirm that
        # init_feedforward(384, 1, 2) builds a layer accepting 768 inputs.
        self.output = ml_util.init_feedforward(384, 1, 2)

    def forward(self, x):
        """
        Embeds each of the two input channels separately, then classifies
        the concatenated latent tokens.

        Parameters
        ----------
        x : torch.Tensor
            Input with (at least) two channels along dim 1; each channel
            is encoded independently as a single-channel volume.

        Returns
        -------
        torch.Tensor
            Output of the feedforward head.
        """
        # Embed each channel with the shared encoder.
        latent0 = self.encoder(x[:, 0:1, ...])
        latent1 = self.encoder(x[:, 1:2, ...])

        # Keep only the first latent token of each embedding
        # (presumably the CLS token -- TODO confirm against encoder).
        x0 = latent0["latents"][:, 0, :]
        x1 = latent1["latents"][:, 0, :]

        # Concatenate the per-channel features and classify.
        x = torch.cat((x0, x1), dim=1)
        x = self.output(x)
        return x
133 | 164 | class ViT3D(nn.Module): |
134 | 165 | """ |
135 | 166 | A class that implements a 3D Vision transformer. |
|
0 commit comments