Skip to content

Commit 87c1c71

Browse files
committed
Updated training test with small model
1 parent fba44fd commit 87c1c71

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

napari_cellseg3d/code_models/model_workers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,7 @@ def train(self):
13361336

13371337
if model_name == "test":
13381338
self.quit()
1339+
yield TrainingReport(False)
13391340

13401341
for epoch in range(self.config.max_epochs):
13411342
# self.log("\n")

napari_cellseg3d/code_models/models/model_test.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self):
1313
self.linear = nn.Linear(1, 1)
1414

1515
def forward(self, x):
16-
return self.linear(x)
16+
return self.linear(torch.tensor(x, requires_grad=True))
1717

1818
def get_net(self):
1919
return self
@@ -24,13 +24,13 @@ def get_output(self, _, input):
2424
def get_validation(self, val_inputs):
2525
return val_inputs
2626

27-
if __name__ == "__main__":
28-
29-
model = TestModel()
30-
model.train()
31-
model.zero_grad()
32-
from napari_cellseg3d.config import WEIGHTS_DIR
33-
torch.save(
34-
model.state_dict(),
35-
WEIGHTS_DIR + f"/{get_weights_file()}"
36-
)
27+
# if __name__ == "__main__":
28+
#
29+
# model = TestModel()
30+
# model.train()
31+
# model.zero_grad()
32+
# from napari_cellseg3d.config import WEIGHTS_DIR
33+
# torch.save(
34+
# model.state_dict(),
35+
# WEIGHTS_DIR + f"/{get_weights_file()}"
36+
# )

0 commit comments

Comments
 (0)