Skip to content

Commit 3f79234

Browse files
authored
Merge pull request #26 from SFI-Visual-Intelligence/fix-main
finds number of channels based on dataset. Adds num_classes to dataset
2 parents e5aafb0 + fc787c2 commit 3f79234

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

main.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def main():
108108
parser.add_argument(
109109
"--device",
110110
type=str,
111-
default="cuda",
111+
default="cpu",
112112
choices=["cuda", "cpu", "mps"],
113113
help="Which device to run the training on.",
114114
)
@@ -124,10 +124,6 @@ def main():
124124

125125
device = args.device
126126

127-
# load model
128-
model = load_model(args.modelname)
129-
model.to(device)
130-
131127
metrics = MetricWrapper(*args.metric)
132128

133129
# Dataset
@@ -143,6 +139,20 @@ def main():
143139
data_path=args.datafolder,
144140
)
145141

142+
# Find number of channels in the dataset
143+
if len(traindata[0][0].shape) == 2:
144+
channels = 1
145+
else:
146+
channels = traindata[0][0].shape[0]
147+
148+
# load model
149+
model = load_model(
150+
args.modelname,
151+
in_channels=channels,
152+
num_classes=traindata.num_classes,
153+
)
154+
model.to(device)
155+
146156
trainloader = DataLoader(traindata,
147157
batch_size=args.batchsize,
148158
shuffle=True,

utils/dataloaders/usps_0_6.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class USPSDataset0_6(Dataset):
3636
A function/transform that takes in a sample and returns a transformed version.
3737
idx : numpy.ndarray
3838
Indices of samples with labels 0-6.
39+
num_classes : int
40+
Number of classes in the dataset
3941
4042
Methods
4143
-------
@@ -71,6 +73,7 @@ def __init__(
7173
super().__init__()
7274
self.path = list(data_path.glob("*.h5"))[0]
7375
self.transform = transform
76+
self.num_classes = 7
7477

7578
if download:
7679
raise NotImplementedError("Download functionality not implemented.")

0 commit comments

Comments
 (0)