Skip to content

Commit d045a2a

Browse files
committed
Add option for setting device to mps (for mac) and a dry_run parameter
- The mps option is necessary to accelerate gpu ops for mac - --dry_run now checks that models/datasets/metrics are loaded before starting training
1 parent f7c2058 commit d045a2a

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

main.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,22 @@
1111

1212

1313
def main():
14-
'''
15-
14+
"""
15+
1616
Parameters
1717
----------
18-
18+
1919
Returns
2020
-------
21-
21+
2222
Raises
2323
------
24-
25-
'''
24+
25+
"""
2626
parser = argparse.ArgumentParser(
27-
prog='',
28-
description='',
29-
epilog='',
27+
prog="",
28+
description="",
29+
epilog="",
3030
)
3131
# Structuture related values
3232
parser.add_argument(
@@ -105,15 +105,27 @@ def main():
105105
default=64,
106106
help="Amount of training images loaded in one go",
107107
)
108+
parser.add_argument(
109+
"--device",
110+
type=str,
111+
default="cuda",
112+
choices=["cuda", "cpu", "mps"],
113+
help="Which device to run the training on.",
114+
)
115+
parser.add_argument(
116+
"--dry_run",
117+
action="store_true",
118+
help="If true, the code will not run the training loop.",
119+
)
108120

109121
args = parser.parse_args()
110122

111123
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
112124

113-
device = 'cuda' if th.cuda.is_available() else 'cpu'
125+
device = args.device
114126

115127
# load model
116-
model = load_model()
128+
model = load_model(args.modelname)
117129
model.to(device)
118130

119131
metrics = MetricWrapper(*args.metric)
@@ -144,6 +156,11 @@ def main():
144156
criterion = nn.CrossEntropyLoss()
145157
optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
146158

159+
# This allows us to load all the components without running the training loop
160+
if args.dry_run:
161+
print("Dry run completed")
162+
exit(0)
163+
147164
wandb.init(project='',
148165
tags=[])
149166
wandb.watch(model)

0 commit comments

Comments
 (0)