Skip to content
This repository was archived by the owner on Oct 15, 2025. It is now read-only.

Commit a01b75e

Browse files
author
Gregory Johnson
authored
Enhancement/api tools (#164)
* Expose train and predict, train_model returns model * minor changes for readability * bugfix for multichtiffdataset
1 parent e4444c5 commit a01b75e

File tree

5 files changed

+44
-17
lines changed

5 files changed

+44
-17
lines changed

fnet/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,10 @@
11
from fnet import models
22
from fnet.fnetlogger import FnetLogger
33

4+
# Clean these up later - GRJ 2020-02-04
5+
from fnet.cli.train_model import train_model as train
6+
from fnet.cli.predict import main as predict
7+
48
__author__ = "Gregory R. Johnson"
59
__email__ = "[email protected]"
610
__version__ = "0.2.0"

fnet/cli/main.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,12 @@ def main() -> None:
2828
init.add_parser_arguments(parser_init)
2929
train_model.add_parser_arguments(parser_train)
3030
predict.add_parser_arguments(parser_predict)
31+
3132
parser_init.set_defaults(func=init.main)
3233
parser_train.set_defaults(func=train_model.main)
3334
parser_predict.set_defaults(func=predict.main)
3435
args = parser.parse_args()
36+
3537
# Remove 'func' from args so it is not passed to target script
3638
func = args.func
3739
delattr(args, "func")

fnet/cli/train_model.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -86,26 +86,32 @@ def add_parser_arguments(parser) -> None:
8686
parser.add_argument("--gpu_ids", nargs="+", default=[0], type=int, help="gpu_id(s)")
8787

8888

89-
def main(args: Optional[argparse.Namespace] = None) -> None:
89+
def main(args: Optional[argparse.Namespace] = None):
9090
"""Trains a model."""
9191
time_start = time.time()
92+
9293
if args is None:
9394
parser = argparse.ArgumentParser()
9495
add_parser_arguments(parser)
9596
args = parser.parse_args()
96-
if args.json and not args.json.exists():
97-
save_default_train_options(args.json)
97+
98+
args.path_json = Path(args.json)
99+
100+
if args.path_json and not args.path_json.exists():
101+
save_default_train_options(args.path_json)
98102
return
99-
with open(args.json, "r") as fi:
103+
104+
with open(args.path_json, "r") as fi:
100105
train_options = json.load(fi)
106+
101107
args.__dict__.update(train_options)
102108
add_logging_file_handler(Path(args.path_save_dir, "train_model.log"))
103109
logger.info(f"Started training at: {datetime.datetime.now()}")
104110

105111
set_seeds(args.seed)
106112
log_training_options(vars(args))
107113
path_model = os.path.join(args.path_save_dir, "model.p")
108-
model = fnet.models.load_or_init_model(path_model, args.json)
114+
model = fnet.models.load_or_init_model(path_model, args.path_json)
109115
init_cuda(args.gpu_ids[0])
110116
model.to_gpu(args.gpu_ids)
111117
logger.info(model)
@@ -124,6 +130,8 @@ def main(args: Optional[argparse.Namespace] = None) -> None:
124130
# Get patch pair providers
125131
bpds_train = get_bpds_train(args)
126132
bpds_val = get_bpds_val(args)
133+
134+
# MAIN LOOP
127135
for idx_iter in range(model.count_iter, args.n_iter):
128136
do_save = ((idx_iter + 1) % args.interval_save == 0) or (
129137
(idx_iter + 1) == args.n_iter
@@ -164,6 +172,8 @@ def main(args: Optional[argparse.Namespace] = None) -> None:
164172
path_save=os.path.join(args.path_save_dir, "loss_curves.png"),
165173
)
166174

175+
return model
176+
167177

168178
def train_model(
169179
batch_size: int = 28,
@@ -182,8 +192,9 @@ def train_model(
182192
seed: Optional[int] = None,
183193
json: Optional[str] = None,
184194
gpu_ids: Optional[List[int]] = None,
185-
) -> None:
195+
):
186196
"""Python API for training."""
197+
187198
bpds_kwargs = bpds_kwargs or {
188199
"buffer_size": 16,
189200
"buffer_switch_interval": 2800, # every 100 updates
@@ -201,7 +212,8 @@ def train_model(
201212
}
202213
iter_checkpoint = iter_checkpoint or []
203214
gpu_ids = gpu_ids or [0]
204-
json = json or str(Path(path_save_dir, "train_options.json"))
215+
216+
json = json or f"{path_save_dir}train_options.json"
205217

206218
pnames, _, _, locs = inspect.getargvalues(inspect.currentframe())
207219
train_options = {k: locs[k] for k in pnames}
@@ -214,10 +226,11 @@ def train_model(
214226
path_json.parent.mkdir(parents=True)
215227

216228
json = globals()["json"] # retrieve global module
217-
with path_json.open("w") as fo:
218-
json.dump(train_options, fo, indent=4, sort_keys=True)
229+
with path_json.open("w") as f:
230+
json.dump(train_options, f, indent=4, sort_keys=True)
219231
logger.info(f"Saved: {path_json}")
220232

221233
args = argparse.Namespace()
222234
args.__dict__.update(train_options)
223-
main(args)
235+
236+
return main(args)

fnet/data/multichtiffdataset.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,8 @@
77

88

99
class MultiChTiffDataset(FnetDataset):
10-
"""Dataset for multi-channel tiff files.
11-
12-
Currently assumes that images are loaded in STCZYX format
13-
10+
"""
11+
Dataset for multi-channel tiff files.
1412
"""
1513

1614
def __init__(
@@ -43,6 +41,16 @@ def __init__(
4341
)
4442

4543
def __getitem__(self, index):
44+
"""
45+
Parameters
46+
----------
47+
index: integer
48+
49+
Returns
50+
-------
51+
C by <spatial dimensions> torch.Tensor
52+
"""
53+
4654
element = self.df.iloc[index, :]
4755
has_target = not np.any(np.isnan(element["channel_target"]))
4856

@@ -67,7 +75,7 @@ def __getitem__(self, index):
6775
im_out = [torch.from_numpy(im.astype(float)).float() for im in im_out]
6876

6977
# unsqueeze to make the first dimension be the channel dimension
70-
im_out = [torch.unsqueeze(im, 0) for im in im_out]
78+
# im_out = [torch.unsqueeze(im, 0) for im in im_out]
7179

7280
return tuple(im_out)
7381

fnet/tests/test_multichtiffdataset.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,5 +28,5 @@ def test_MultiTiffDataset(tmp_path, n_ch_in, n_ch_out, dims_zyx):
2828
len_data = 2
2929
assert len(data) == len_data
3030

31-
assert tuple(data[0].shape) == (1,) + (n_ch_in,) + dims_zyx
32-
assert tuple(data[1].shape) == (1,) + (n_ch_out,) + dims_zyx
31+
assert tuple(data[0].shape) == (n_ch_in,) + dims_zyx
32+
assert tuple(data[1].shape) == (n_ch_out,) + dims_zyx

0 commit comments

Comments (0)