Skip to content

Commit 55d2edc

Browse files
committed
update classification agent
1 parent 04211f4 commit 55d2edc

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

pymic/loss/seg/ce.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ class CrossEntropyLoss(AbstractSegLoss):
1818
"""
1919
def __init__(self, params = None):
2020
super(CrossEntropyLoss, self).__init__(params)
21-
2221

2322
def forward(self, loss_input_dict):
2423
predict = loss_input_dict['prediction']

pymic/net_run/agent_cls.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ def training(self):
145145
self.optimizer.zero_grad()
146146
# forward + backward + optimize
147147
outputs = self.net(inputs)
148-
loss = self.get_loss_value(data, inputs, outputs, labels)
148+
149+
loss = self.get_loss_value(data, outputs, labels)
149150
loss.backward()
150151
self.optimizer.step()
151152
self.scheduler.step()
@@ -175,7 +176,7 @@ def validation(self):
175176
self.optimizer.zero_grad()
176177
# forward + backward + optimize
177178
outputs = self.net(inputs)
178-
loss = self.get_loss_value(data, inputs, outputs, labels)
179+
loss = self.get_loss_value(data, outputs, labels)
179180

180181
# statistics
181182
sample_num += labels.size(0)
@@ -243,10 +244,11 @@ def train_valid(self):
243244
logging.info("{0:} training start".format(str(datetime.now())[:-7]))
244245
self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir'])
245246
for it in range(iter_start, iter_max, iter_valid):
247+
lr_value = self.optimizer.param_groups[0]['lr']
246248
train_scalars = self.training()
247249
valid_scalars = self.validation()
248250
glob_it = it + iter_valid
249-
self.write_scalars(train_scalars, valid_scalars, glob_it)
251+
self.write_scalars(train_scalars, valid_scalars, lr_value, glob_it)
250252

251253
if(valid_scalars[metrics] > self.max_val_score):
252254
self.max_val_score = valid_scalars[metrics]

0 commit comments

Comments
 (0)