torchx/examples/apps/lightning_classy_vision: 1 file changed, +13 -13 lines changed
Columns: original file line number, diff line number, diff line change
2727import torch
2828from pytorch_lightning .callbacks import ModelCheckpoint
2929from pytorch_lightning .loggers import TensorBoardLogger
30-
31- # ensure data and module are on the path
32- sys .path .append ("." )
33-
3430from torchx .examples .apps .lightning_classy_vision .data import (
3531 TinyImageNetDataModule ,
36- download_data ,
3732 create_random_data ,
33+ download_data ,
3834)
3935from torchx .examples .apps .lightning_classy_vision .model import (
4036 TinyImageNetModel ,
4137 export_inference_model ,
4238)
43- from torchx .examples .apps .lightning_classy_vision .profiler import (
44- SimpleLoggingProfiler ,
45- )
39+ from torchx .examples .apps .lightning_classy_vision .profiler import SimpleLoggingProfiler
40+
41+
42+ # ensure data and module are on the path
43+ sys .path .append ("." )
4644
4745
4846def parse_args (argv : List [str ]) -> argparse .Namespace :
@@ -84,10 +82,6 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
8482 return parser .parse_args (argv )
8583
8684
87- def get_gpu_devices () -> int :
88- return torch .cuda .device_count ()
89-
90-
9185def get_model_checkpoint (args : argparse .Namespace ) -> Optional [ModelCheckpoint ]:
9286 if not args .output_path :
9387 return None
@@ -138,10 +132,16 @@ def main(argv: List[str]) -> None:
138132 # Initialize a trainer
139133 num_nodes = int (os .environ .get ("GROUP_WORLD_SIZE" , 1 ))
140134 num_processes = int (os .environ .get ("LOCAL_WORLD_SIZE" , 1 ))
135+
136+ if torch .cuda .is_available ():
137+ gpus = num_processes
138+ else :
139+ gpus = None
140+
141141 trainer = pl .Trainer (
142142 num_nodes = num_nodes ,
143143 num_processes = num_processes ,
144- gpus = get_gpu_devices () ,
144+ gpus = gpus ,
145145 accelerator = "ddp" ,
146146 logger = logger ,
147147 max_epochs = args .epochs ,
You can’t perform that action at this time.
0 commit comments