Skip to content

Commit 555d554

Browse files
NoezorNoé Pionjb-ye
authored
Improve default launch device for train (#3523)
* feat: improve launch device * add licence header * crash instead of trying to set the available option * set cuda as default --------- Co-authored-by: Noé Pion <[email protected]> Co-authored-by: J.Y. <[email protected]>
1 parent a8888e7 commit 555d554

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

nerfstudio/scripts/train.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion
6363
from nerfstudio.engine.trainer import TrainerConfig
6464
from nerfstudio.utils import comms, profiler
65+
from nerfstudio.utils.available_devices import get_available_devices
6566
from nerfstudio.utils.rich_utils import CONSOLE
6667

6768
DEFAULT_TIMEOUT = timedelta(minutes=30)
@@ -226,6 +227,15 @@ def launch(
226227
def main(config: TrainerConfig) -> None:
227228
"""Main function."""
228229

230+
# Check if the specified device type is available
231+
available_device_types = get_available_devices()
232+
if config.machine.device_type not in available_device_types:
233+
raise RuntimeError(
234+
f"Specified device type '{config.machine.device_type}' is not available. "
235+
f"Available device types: {available_device_types}. "
236+
"Please specify a valid device type using the CLI option: --machine.device_type [cuda|mps|cpu]"
237+
)
238+
229239
if config.data:
230240
CONSOLE.log("Using --data alias for --data.pipeline.datamanager.data")
231241
config.pipeline.datamanager.data = config.data
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Literal
16+
17+
import torch
18+
19+
20+
def get_available_devices() -> List[Literal["cpu", "cuda", "mps"]]:
21+
"""Determine the available devices on the machine
22+
23+
Returns:
24+
list: List of available device types
25+
"""
26+
available_devices: List[Literal["cpu", "cuda", "mps"]] = []
27+
if torch.cuda.is_available():
28+
available_devices.append("cuda")
29+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
30+
available_devices.append("mps")
31+
available_devices.append("cpu")
32+
return available_devices

0 commit comments

Comments
 (0)