@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 from datetime import datetime
+from random import random
 from torch.optim import lr_scheduler
 from torchvision import transforms
 from tensorboardX import SummaryWriter
@@ -17,6 +18,7 @@
 from pymic.net.net_dict_cls import TorchClsNetDict
 from pymic.transform.trans_dict import TransformDict
 from pymic.net_run.agent_abstract import NetRunAgent
+from pymic.util.general import mixup
 import warnings
 warnings.filterwarnings('ignore', '.*output shape of zoom.*')

@@ -111,16 +113,17 @@ def get_evaluation_score(self, outputs, labels):
111113 """
112114 Get evaluation score for a prediction.
113115
114- :param outputs: (tensor) Prediction obtained by a network.
115- :param labels: (tensor) The ground truth.
116+ :param outputs: (tensor) Prediction obtained by a network with size N X C .
117+ :param labels: (tensor) The ground truth with size N X C .
116118 """
117119 metrics = self .config ['training' ].get ("evaluation_metric" , "accuracy" )
118120 if (metrics != "accuracy" ): # default classification accuracy
119121 raise ValueError ("Not implemeted for metric {0:}" .format (metrics ))
120122 if (self .task_type == "cls" ):
121- _ , preds = torch .max (outputs , 1 )
122- consis = self .convert_tensor_type (preds == labels .data )
123- score = torch .mean (consis )
123+ out_argmax = torch .argmax (outputs , 1 )
124+ lab_argmax = torch .argmax (labels , 1 )
125+ consis = self .convert_tensor_type (out_argmax == lab_argmax )
126+ score = torch .mean (consis )
124127 elif (self .task_type == "cls_nexcl" ): #nonexclusive classification
125128 preds = self .convert_tensor_type (outputs > 0.5 )
126129 consis = self .convert_tensor_type (preds == labels .data )
@@ -129,6 +132,7 @@ def get_evaluation_score(self, outputs, labels):

     def training(self):
         iter_valid = self.config['training']['iter_valid']
+        mixup_prob = self.config['training'].get('mixup_probability', 0.5)
         sample_num = 0
         running_loss = 0
         running_score = 0
@@ -140,8 +144,11 @@ def training(self):
                 self.trainIter = iter(self.train_loader)
                 data = next(self.trainIter)
             inputs = self.convert_tensor_type(data['image'])
-            labels = data['label'].long()
+            labels = self.convert_tensor_type(data['label_prob'])
+            if(random() < mixup_prob):
+                inputs, labels = mixup(inputs, labels)
             inputs, labels = inputs.to(self.device), labels.to(self.device)
+
             # zero the parameter gradients
             self.optimizer.zero_grad()
             # forward + backward + optimize
@@ -174,7 +181,7 @@ def validation(self):
         self.net.eval()
         for data in validIter:
             inputs = self.convert_tensor_type(data['image'])
-            labels = data['label'].long()
+            labels = self.convert_tensor_type(data['label_prob'])
             inputs, labels = inputs.to(self.device), labels.to(self.device)
             self.optimizer.zero_grad()
             # forward + backward + optimize
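
For context: the training loop above calls mixup() imported from pymic.util.general, whose body is not part of this diff. A minimal sketch of a standard mixup augmentation that would match the call site, pairing each sample with a randomly chosen sample from the same batch and mixing both the images and the soft labels taken from data['label_prob'], could look like the code below; the Beta parameter alpha and the exact behaviour of PyMIC's real implementation are assumptions here.

# Sketch only: assumed behaviour of mixup(), not PyMIC's actual implementation.
import torch

def mixup(inputs, labels, alpha = 1.0):
    # inputs: tensor of shape N x C x ..., labels: soft labels of shape N x C
    lam  = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(inputs.shape[0])
    # convex combination of each sample with a shuffled partner from the batch
    mixed_inputs = lam * inputs + (1.0 - lam) * inputs[perm]
    mixed_labels = lam * labels + (1.0 - lam) * labels[perm]
    return mixed_inputs, mixed_labels

Because data['label_prob'] stores labels as one-hot/soft N x C vectors, the mixed labels remain valid probability distributions, which is also why get_evaluation_score now compares argmax over both outputs and labels instead of comparing against integer class indices.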