Skip to content

Commit 37311e4

Browse files
aivanou authored and facebook-github-bot committed
Make output_path optional
Summary: Since we removed distributed sum, we need to use this example to run fb internal tests. For internal tests, we don't need the `output_path`, which introduces around ~200 MB of data on each run.

Reviewed By: kiukchung

Differential Revision: D31661378

fbshipit-source-id: 098bf9f5be9302e7d8cced672ba9cf7eaf8b32e6
1 parent 7195872 commit 37311e4

File tree

1 file changed

+16
-8
lines changed
  • torchx/examples/apps/lightning_classy_vision

1 file changed

+16
-8
lines changed

torchx/examples/apps/lightning_classy_vision/train.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import os
2222
import sys
2323
import tempfile
24-
from typing import List
24+
from typing import List, Optional
2525

2626
import pytorch_lightning as pl
2727
import torch
@@ -72,8 +72,7 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
7272
parser.add_argument(
7373
"--output_path",
7474
type=str,
75-
help="path to place checkpoints and model outputs",
76-
required=True,
75+
help="path to place checkpoints and model outputs, if not specified, checkpoints are not saved",
7776
)
7877
parser.add_argument(
7978
"--log_path",
@@ -94,6 +93,16 @@ def get_gpu_devices() -> int:
9493
return torch.cuda.device_count()
9594

9695

96+
def get_model_checkpoint(args: argparse.Namespace) -> Optional[ModelCheckpoint]:
97+
if not args.output_path:
98+
return None
99+
return ModelCheckpoint(
100+
monitor="train_loss",
101+
dirpath=args.output_path,
102+
save_last=True,
103+
)
104+
105+
97106
def main(argv: List[str]) -> None:
98107
with tempfile.TemporaryDirectory() as tmpdir:
99108
args = parse_args(argv)
@@ -117,11 +126,10 @@ def main(argv: List[str]) -> None:
117126
)
118127

119128
# Setup model checkpointing
120-
checkpoint_callback = ModelCheckpoint(
121-
monitor="train_loss",
122-
dirpath=args.output_path,
123-
save_last=True,
124-
)
129+
checkpoint_callback = get_model_checkpoint(args)
130+
callbacks = []
131+
if checkpoint_callback:
132+
callbacks.append(checkpoint_callback)
125133
if args.load_path:
126134
print(f"loading checkpoint: {args.load_path}...")
127135
model.load_from_checkpoint(checkpoint_path=args.load_path)

0 commit comments

Comments (0)