|
| 1 | +from pathlib import Path |
| 2 | + |
1 | 3 | from napari_cellseg3d import config |
2 | | -from napari_cellseg3d.code_plugins import plugin_model_training as train |
| 4 | +from napari_cellseg3d.code_plugins.plugin_model_training import Trainer |
| 5 | +from napari_cellseg3d._tests.fixtures import LogFixture |
3 | 6 |
|
4 | 7 |
|
5 | | -def test_check_ready(make_napari_viewer): |
6 | | - view = make_napari_viewer() |
7 | | - widget = train.Trainer(view) |
| 8 | +def test_training(make_napari_viewer, qtbot): |
| 9 | + im_path = str(Path(__file__).resolve().parent / "res/test.tif") |
| 10 | + |
| 11 | + viewer = make_napari_viewer() |
| 12 | + widget = Trainer(viewer) |
| 13 | + widget.log = LogFixture() |
| 14 | + viewer.window.add_dock_widget(widget) |
8 | 15 |
|
9 | 16 | widget.images_filepath = None |
10 | 17 | widget.labels_filepaths = None |
11 | 18 |
|
12 | | - res = widget.check_ready() |
13 | | - assert not res |
| 19 | + assert not widget.check_ready() |
| 20 | + |
| 21 | + assert widget.filetype_choice.currentText() == ".tif" |
| 22 | + |
| 23 | + widget.images_filepaths = [im_path] |
| 24 | + widget.labels_filepaths = [im_path] |
| 25 | + |
| 26 | + assert widget.check_ready() |
| 27 | + |
| 28 | + ################# |
| 29 | + # Training is too long to test properly this way. Do not use on Github |
| 30 | + ################# |
14 | 31 |
|
15 | | - # widget.images_filepath = ["C:/test/something.tif"] |
16 | | - # widget.labels_filepaths = ["C:/test/lab_something.tif"] |
17 | | - # res = widget.check_ready() |
| 32 | + # widget.start() |
| 33 | + # assert widget.worker is not None |
18 | 34 | # |
19 | | - # assert res |
| 35 | + # with qtbot.waitSignal(signal=widget.worker.finished, timeout=60000) as blocker: # wait only for 60 seconds. |
| 36 | + # blocker.connect(widget.worker.errored) |
20 | 37 |
|
21 | 38 |
|
22 | 39 | def test_update_loss_plot(make_napari_viewer): |
23 | 40 | view = make_napari_viewer() |
24 | | - widget = train.Trainer(view) |
| 41 | + widget = Trainer(view) |
25 | 42 |
|
26 | 43 | widget.worker_config = config.TrainingWorkerConfig() |
27 | 44 | widget.worker_config.validation_interval = 1 |
|
0 commit comments