Skip to content

Commit fe48ffa

Browse files
author
Noé Pion
committed
crash instead of trying to set the available option
1 parent ae396b6 commit fe48ffa

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

nerfstudio/configs/base_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ class MachineConfig(PrintableConfig):
6868
"""current machine's rank (for DDP)"""
6969
dist_url: str = "auto"
7070
"""distributed connection point (for DDP)"""
71-
device_type: Literal["cpu", "cuda", "mps"] | None = None
72-
"""device type to use for training. If none set, script will do its best to set the value."""
71+
device_type: Literal["cpu", "cuda", "mps"]
72+
"""device type to use for training"""
7373

7474

7575
@dataclass

nerfstudio/scripts/train.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
from nerfstudio.configs.method_configs import AnnotatedBaseConfigUnion
6363
from nerfstudio.engine.trainer import TrainerConfig
6464
from nerfstudio.utils import comms, profiler
65-
from nerfstudio.utils.best_device import get_best_device
65+
from nerfstudio.utils.available_devices import get_available_devices
6666
from nerfstudio.utils.rich_utils import CONSOLE
6767

6868
DEFAULT_TIMEOUT = timedelta(minutes=30)
@@ -227,6 +227,15 @@ def launch(
227227
def main(config: TrainerConfig) -> None:
228228
"""Main function."""
229229

230+
# Check if the specified device type is available
231+
available_device_types = get_available_devices()
232+
if config.machine.device_type not in available_device_types:
233+
raise RuntimeError(
234+
f"Specified device type '{config.machine.device_type}' is not available. "
235+
f"Available device types: {available_device_types}. "
236+
"Please specify a valid device type using the CLI option: --machine.device_type [cuda|mps|cpu]"
237+
)
238+
230239
if config.data:
231240
CONSOLE.log("Using --data alias for --data.pipeline.datamanager.data")
232241
config.pipeline.datamanager.data = config.data
@@ -239,12 +248,6 @@ def main(config: TrainerConfig) -> None:
239248
CONSOLE.log(f"Loading pre-set config from: {config.load_config}")
240249
config = yaml.load(config.load_config.read_text(), Loader=yaml.Loader)
241250

242-
if not hasattr(config.machine, "device_type") or config.machine.device_type is None:
243-
device_type, reason = get_best_device()
244-
config.machine.device_type = device_type
245-
CONSOLE.log(f"No device specified: {reason}")
246-
CONSOLE.log("You can force a certain device type with --machine.device_type [cuda|mps|cpu]")
247-
248251
config.set_timestamp()
249252

250253
# print and save config
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Literal, Tuple
15+
from typing import List, Literal
1616

1717
import torch
1818

1919

20-
def get_best_device() -> Tuple[Literal["cpu", "cuda", "mps"], str]:
21-
"""Determine the best available device to run nerfstudio inference.
20+
def get_available_devices() -> List[Literal["cpu", "cuda", "mps"]]:
21+
"""Determine the available devices on the machine
2222
2323
Returns:
24-
tuple: (device_type, reason) where device_type is the selected device and
25-
reason is an explanation of why it was chosen
24+
list: List of available device types
2625
"""
26+
available_devices: List[Literal["cpu", "cuda", "mps"]] = []
2727
if torch.cuda.is_available():
28-
return "cuda", "CUDA GPU available - using for optimal performance"
28+
available_devices.append("cuda")
2929
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
30-
return "mps", "Apple Metal (MPS) available - using for accelerated performance"
31-
else:
32-
return "cpu", "No GPU/MPS detected - falling back to CPU"
30+
available_devices.append("mps")
31+
available_devices.append("cpu")
32+
return available_devices

0 commit comments

Comments
 (0)