File tree Expand file tree Collapse file tree 2 files changed +18
-5
lines changed
Expand file tree Collapse file tree 2 files changed +18
-5
lines changed Original file line number Diff line number Diff line change @@ -108,7 +108,7 @@ def main():
108108 parser .add_argument (
109109 "--device" ,
110110 type = str ,
111- default = "cuda " ,
111+ default = "cpu " ,
112112 choices = ["cuda" , "cpu" , "mps" ],
113113 help = "Which device to run the training on." ,
114114 )
@@ -124,10 +124,6 @@ def main():
124124
125125 device = args .device
126126
127- # load model
128- model = load_model (args .modelname )
129- model .to (device )
130-
131127 metrics = MetricWrapper (* args .metric )
132128
133129 # Dataset
@@ -143,6 +139,20 @@ def main():
143139 data_path = args .datafolder ,
144140 )
145141
142+ # Find number of channels in the dataset
143+ if len (traindata [0 ][0 ].shape ) == 2 :
144+ channels = 1
145+ else :
146+ channels = traindata [0 ][0 ].shape [0 ]
147+
148+ # load model
149+ model = load_model (
150+ args .modelname ,
151+ in_channels = channels ,
152+ num_classes = traindata .num_classes ,
153+ )
154+ model .to (device )
155+
146156 trainloader = DataLoader (traindata ,
147157 batch_size = args .batchsize ,
148158 shuffle = True ,
Original file line number Diff line number Diff line change @@ -36,6 +36,8 @@ class USPSDataset0_6(Dataset):
3636 A function/transform that takes in a sample and returns a transformed version.
3737 idx : numpy.ndarray
3838 Indices of samples with labels 0-6.
39+ num_classes : int
40+ Number of classes in the dataset
3941
4042 Methods
4143 -------
@@ -71,6 +73,7 @@ def __init__(
7173 super ().__init__ ()
7274 self .path = list (data_path .glob ("*.h5" ))[0 ]
7375 self .transform = transform
76+ self .num_classes = 7
7477
7578 if download :
7679 raise NotImplementedError ("Download functionality not implemented." )
You can’t perform that action at this time.
0 commit comments