Skip to content

Commit 94541d8

Browse files
committed
Added checks for commandline parameters
1 parent b93ee66 commit 94541d8

File tree

4 files changed

+25
-5
lines changed

4 files changed

+25
-5
lines changed

utils/arg_parser.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ def get_args():
4444
"--modelname",
4545
type=str,
4646
default="MagnusModel",
47-
choices=["MagnusModel", "ChristianModel", "SolveigModel", "JanModel"],
47+
choices=["MagnusModel", "ChristianModel", "SolveigModel", "JanModel", "JohanModel"],
4848
help="Model which to be trained on",
4949
)
5050
parser.add_argument(
5151
"--dataset",
5252
type=str,
5353
default="svhn",
54-
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
54+
choices=["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"],
5555
help="Which dataset to train the model on.",
5656
)
5757

@@ -95,4 +95,17 @@ def get_args():
9595
action="store_true",
9696
help="If true, the code will not run the training loop.",
9797
)
98-
return parser.parse_args()
98+
args = parser.parse_args()
99+
100+
101+
assert args.device in ["cuda", "cpu", "mps"], "Device should be either 'cuda' or 'cpu' or 'mps'."
102+
assert args.epoch > 0, "Epoch should be a positive integer."
103+
assert args.learning_rate > 0, "Learning rate should be a positive float."
104+
assert args.batchsize > 0, "Batch size should be a positive integer."
105+
assert args.dataset in ["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"], "Dataset should be either 'svhn', 'usps_0-6', 'usps_7-9', 'mnist_0-3' or 'mnist_4-9'."
106+
assert args.modelname in ["MagnusModel", "ChristianModel", "SolveigModel", "JanModel", "JohanModel"], "Model name should be either 'MagnusModel', 'ChristianModel', 'SolveigModel', 'JanModel', or 'JohanModel."
107+
assert all([metric in ["entropy", "f1", "recall", "precision", "accuracy"] for metric in args.metric]), "Metric should be either 'entropy', 'f1', 'recall', 'precision', or 'accuracy'."
108+
109+
110+
111+
return args

utils/load_data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,9 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
4040
return MNISTDataset0_3(*args, **kwargs)
4141
case "usps_7-9":
4242
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
43+
case "svhn":
44+
raise NotImplementedError("SVHN dataset not yet implemented.")
45+
case "mnist_4-9":
46+
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
4347
case _:
4448
raise NotImplementedError(f"Dataset: {dataset} not implemented.")

utils/load_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from .models import ChristianModel, JanModel, MagnusModel, SolveigModel
3+
from .models import ChristianModel, JanModel, MagnusModel, SolveigModel, JohanModel
44

55

66
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
@@ -44,6 +44,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
4444
return JanModel(*args, **kwargs)
4545
case "solveigmodel":
4646
return SolveigModel(*args, **kwargs)
47+
case "johanmodel":
48+
return JohanModel(*args, **kwargs)
4749
case _:
4850
errmsg = (
4951
f"Model: {modelname} not implemented. "

utils/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel"]
1+
__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel", "JohanModel"]
22

33
from .christian_model import ChristianModel
44
from .jan_model import JanModel
55
from .magnus_model import MagnusModel
66
from .solveig_model import SolveigModel
7+
from .johan_model import JohanModel

0 commit comments

Comments
 (0)