Skip to content
This repository was archived by the owner on Oct 15, 2025. It is now read-only.

Commit 8fa0f93

Browse files
author
Gregory Johnson
authored
Enhancement/api tools (#168)
* Expose train and predict, train_model returns model * minor changes for readability * bugfix for multichtiffdataset" * try emptying cache after every iter to save memory * Enhancement for multi-channel support
1 parent 89c8c19 commit 8fa0f93

File tree

6 files changed

+31
-12
lines changed

6 files changed

+31
-12
lines changed

examples/download_and_train.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import argparse
22
import os
33
import json
4+
from pathlib import Path
45

56
import quilt3
67
import pandas as pd
7-
from pathlib import Path
8+
import numpy as np
89

910
from fnet.cli.init import save_default_train_options
1011

@@ -50,11 +51,17 @@
5051

5152
data_manifest = aics_pipeline["metadata.csv"]()
5253

54+
# THE ROWS OF THE MANIFEST CORRESPOND TO CELLS, WE TRIM DOWN TO UNIQUIE FOVS
55+
unique_fov_indices = np.unique(data_manifest['FOVId'], return_index=True)[1]
56+
data_manifest = data_manifest.iloc[unique_fov_indices]
57+
5358
# SELECT THE FIRST N_IMAGES_TO_DOWNLOAD
5459
data_manifest = data_manifest.iloc[0:n_images_to_download]
5560

5661
image_source_paths = data_manifest["SourceReadPath"]
5762

63+
pdb.set_trace()
64+
5865
image_target_paths = [
5966
"{}/{}".format(image_save_dir, image_source_path)
6067
for image_source_path in image_source_paths

fnet/cli/train_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def main(args: Optional[argparse.Namespace] = None):
9999

100100
if args.path_json and not args.path_json.exists():
101101
save_default_train_options(args.path_json)
102-
return
102+
return None
103103

104104
with open(args.path_json, "r") as fi:
105105
train_options = json.load(fi)
@@ -125,7 +125,7 @@ def main(args: Optional[argparse.Namespace] = None):
125125

126126
if (args.n_iter - model.count_iter) <= 0:
127127
# Stop if no more iterations needed
128-
return
128+
return model
129129

130130
# Get patch pair providers
131131
bpds_train = get_bpds_train(args)
@@ -136,12 +136,14 @@ def main(args: Optional[argparse.Namespace] = None):
136136
do_save = ((idx_iter + 1) % args.interval_save == 0) or (
137137
(idx_iter + 1) == args.n_iter
138138
)
139+
139140
loss_train = model.train_on_batch(*bpds_train.get_batch(args.batch_size))
140141
loss_val = None
141142
if do_save and bpds_val is not None:
142143
loss_val = model.test_on_iterator(
143144
[bpds_val.get_batch(args.batch_size) for _ in range(4)]
144145
)
146+
145147
fnetlogger.add(
146148
{"num_iter": idx_iter + 1, "loss_train": loss_train, "loss_val": loss_val}
147149
)

fnet/data/multichtiffdataset.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,18 @@ def __init__(
2424
# if this column is a string assume it is in "[ind_1, ind_2, ..., ind_n]" format
2525
if isinstance(self.df["channel_signal"][0], str):
2626
self.df["channel_signal"] = [
27-
np.fromstring(ch[1:-1], sep=" ").astype(int)
27+
np.fromstring(ch[1:-1], sep=", ").astype(int)
2828
for ch in self.df["channel_signal"]
2929
]
30+
else:
31+
self.df["channel_signal"] = [[int(ch)] for ch in self.df["channel_signal"]]
32+
33+
if isinstance(self.df["channel_target"][0], str):
3034
self.df["channel_target"] = [
31-
np.fromstring(ch[1:-1], sep=" ").astype(int)
35+
np.fromstring(ch[1:-1], sep=", ").astype(int)
3236
for ch in self.df["channel_target"]
3337
]
3438
else:
35-
self.df["channel_signal"] = [[int(ch)] for ch in self.df["channel_signal"]]
3639
self.df["channel_target"] = [[int(ch)] for ch in self.df["channel_target"]]
3740

3841
assert all(

fnet/fnet_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def train_on_batch(
232232
module = torch.nn.DataParallel(self.net, device_ids=self.gpu_ids)
233233
else:
234234
module = self.net
235+
235236
self.optimizer.zero_grad()
236237
y_hat_batch = module(x_batch)
237238
args = [y_hat_batch, y_batch]

fnet/nn_modules/fnet_nn_3d_params.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self, depth=4, mult_chan=32, in_channels=1, out_channels=1):
1010
self.out_channels = out_channels
1111

1212
self.net_recurse = _Net_recurse(
13-
n_in_channels=self.in_channels, mult_chan=self.mult_chan, depth=self.depth
13+
n_in_channels=self.in_channels, mult_chan=self.mult_chan, depth_parent=self.depth, depth=self.depth
1414
)
1515
self.conv_out = torch.nn.Conv3d(
1616
self.mult_chan, self.out_channels, kernel_size=3, padding=1
@@ -22,7 +22,7 @@ def forward(self, x):
2222

2323

2424
class _Net_recurse(torch.nn.Module):
25-
def __init__(self, n_in_channels, mult_chan=2, depth=0):
25+
def __init__(self, n_in_channels, mult_chan=2, depth_parent=0, depth=0):
2626
"""Class for recursive definition of U-network.p
2727
2828
Parameters
@@ -37,8 +37,14 @@ def __init__(self, n_in_channels, mult_chan=2, depth=0):
3737
3838
"""
3939
super().__init__()
40+
4041
self.depth = depth
41-
n_out_channels = n_in_channels * mult_chan
42+
43+
if self.depth == depth_parent:
44+
n_out_channels = mult_chan
45+
else:
46+
n_out_channels = n_in_channels * mult_chan
47+
4248
self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels)
4349
if depth > 0:
4450
self.sub_2conv_less = SubNet2Conv(2 * n_out_channels, n_out_channels)
@@ -52,7 +58,7 @@ def __init__(self, n_in_channels, mult_chan=2, depth=0):
5258
)
5359
self.bn1 = torch.nn.BatchNorm3d(n_out_channels)
5460
self.relu1 = torch.nn.ReLU()
55-
self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth=(depth - 1))
61+
self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth_parent=depth_parent, depth=(depth - 1))
5662

5763
def forward(self, x):
5864
if self.depth == 0:

fnet/tests/data/testlib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def create_multichtiff_data(
7070
{
7171
"dummy_id": idx,
7272
"path_tiff": path_x,
73-
"channel_signal": np.arange(0, n_ch_in),
74-
"channel_target": np.arange(0, n_ch_out) + n_ch_in,
73+
"channel_signal": list(np.arange(0, n_ch_in)),
74+
"channel_target": list(np.arange(0, n_ch_out) + n_ch_in),
7575
}
7676
)
7777

0 commit comments

Comments
 (0)