Commit 3a0964d

Merge pull request #51 from mgiammar/update_device_handling
Update ttsim3d version and match device handling behavior
2 parents: 60f1af9 + af6afa5

File tree: 4 files changed (+64, -149 lines)

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@ classifiers = [
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
     "Typing :: Typed",
 ]
 # add your package dependencies here
@@ -50,7 +51,7 @@ dependencies = [
     "torch-fourier-slice>=v0.2.0",
     "torch-fourier-filter>=v0.2.3",
     "torch-so3>=v0.2.0",
-    "ttsim3d>=v0.3.0",
+    "ttsim3d>=v0.4.0",
     "lmfit",
     "zenodo-get",
 ]

Lines changed: 56 additions & 33 deletions
@@ -1,44 +1,51 @@
 """Computational configuration for 2DTM."""
 
-from typing import Annotated
+from typing import Annotated, Optional, Union
 
 import torch
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field
+
+# Type alias for non-negative integer
+NonNegativeInt = Annotated[int, Field(ge=0)]
 
 
 class ComputationalConfig(BaseModel):
     """Serialization of computational resources allocated for 2DTM.
 
+    NOTE: The field `gpu_ids` is not validated at instantiation past being one of the
+    valid types. For example, if "cuda:0" is specified but no CUDA device is available,
+    the instantiation will succeed, and only upon translating `gpu_ids` to a list of
+    `torch.device` objects will an error be raised. This is done to allow for
+    configuration files to be loaded without requiring the actual hardware to be
+    present at the time of loading.
+
     Attributes
     ----------
-    gpu_ids : list[int]
-        Which GPU(s) to use for computation, defaults to 0 which will use device at
-        index 0. A value of -2 or less corresponds to CPU device. A value of -1 will
-        use all available GPUs.
+    gpu_ids : Optional[Union[int, list[int], str, list[str]]]
+        Field which specifies which GPUs to use for computation. The following types
+        of values are allowed:
+        - A single integer, e.g. 0, which means to use GPU with ID 0.
+        - A list of integers, e.g. [0, 2], which means to use GPUs with IDs 0 and 2.
+        - A device specifier string, e.g. "cuda:0", which means to use GPU with ID 0.
+        - A list of device specifier strings, e.g. ["cuda:0", "cuda:1"], which means to
+          use GPUs with IDs 0 and 1.
+        - The specific string "all" which means to use all available GPUs identified
+          by torch.cuda.device_count().
+        - The specific string "cpu" which means to use CPU.
     num_cpus : int
         Total number of CPUs to use, defaults to 1.
     """
 
-    gpu_ids: int | list[int] = [0]
-    num_cpus: Annotated[int, Field(ge=1)] = 1
-
-    @field_validator("gpu_ids")  # type: ignore
-    def validate_gpu_ids(cls, v):  # pylint: disable=no-self-argument
-        """Validate input value for GPU ids."""
-        if isinstance(v, int):
-            v = [v]
-
-        # Check if -1 appears, it is only value in list
-        if -1 in v and len(v) > 1:
-            raise ValueError(
-                "If -1 (all GPUs) is in the list, it must be the only value."
-            )
-
-        # Check if -2 appears, it is only value in list
-        if -2 in v and len(v) > 1:
-            raise ValueError("If -2 (CPU) is in the list, it must be the only value.")
-
-        return v
+    # Type-hinting here is ensuring non-negative integers, and list of at least one
+    gpu_ids: Optional[
+        Union[
+            str,
+            NonNegativeInt,
+            Annotated[list[NonNegativeInt], Field(min_length=1)],
+            Annotated[list[str], Field(min_length=1)],
+        ]
+    ] = [0]
+    num_cpus: NonNegativeInt = 1
 
     @property
     def gpu_devices(self) -> list[torch.device]:
@@ -48,13 +55,29 @@ def gpu_devices(self) -> list[torch.device]:
         -------
         list[torch.device]
         """
-        # Case where gpu_ids is integer
-        if isinstance(self.gpu_ids, int):
-            self.gpu_ids = [self.gpu_ids]
-
-        if -1 in self.gpu_ids:
+        # Handle special string cases first
+        if self.gpu_ids == "all":
+            if not torch.cuda.is_available():
+                raise ValueError("No CUDA devices available.")
             return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
-        if -2 in self.gpu_ids:
+
+        if self.gpu_ids == "cpu":
             return [torch.device("cpu")]
 
-        return [torch.device(f"cuda:{gpu_id}") for gpu_id in self.gpu_ids]
+        # Normalize to list for uniform processing
+        gpu_list = self.gpu_ids if isinstance(self.gpu_ids, list) else [self.gpu_ids]
+
+        # Process each item in the normalized list
+        devices = []
+        for gpu_id in gpu_list:
+            if isinstance(gpu_id, int):
+                devices.append(torch.device(f"cuda:{gpu_id}"))
+            elif isinstance(gpu_id, str):
+                devices.append(torch.device(gpu_id))
+            else:
+                raise TypeError(
+                    f"Invalid type for gpu_ids element: {type(gpu_id)}. "
+                    "Expected int or str."
+                )
+
+        return devices
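
For reference, a minimal usage sketch of the new device handling (assumes a leopard_em install that includes this change; the devices actually returned depend on the GPUs visible to torch on the machine running it):

import torch

from leopard_em.pydantic_models.config import ComputationalConfig

# Integers, lists of integers, device strings, lists of strings, "all", and "cpu"
# are all accepted at instantiation; hardware is only checked when gpu_devices is read.
cfg_int = ComputationalConfig(gpu_ids=1)                      # resolves to [cuda:1]
cfg_list = ComputationalConfig(gpu_ids=[0, 2])                # resolves to [cuda:0, cuda:2]
cfg_strs = ComputationalConfig(gpu_ids=["cuda:0", "cuda:1"])  # explicit device strings
cfg_cpu = ComputationalConfig(gpu_ids="cpu")

print(cfg_cpu.gpu_devices)  # [device(type='cpu')]

# "all" defers the CUDA check to the gpu_devices property, so a config file using it
# can still be loaded on a CPU-only machine without error.
cfg_all = ComputationalConfig(gpu_ids="all")
if torch.cuda.is_available():
    print(cfg_all.gpu_devices)  # one torch.device per visible GPU

Deferring the hardware check to gpu_devices is what keeps configuration files portable between machines, as the NOTE in the docstring above describes.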

src/leopard_em/pydantic_models/managers/optimize_template_manager.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def make_backend_core_function_kwargs(
             Whether to use refined angles or not. Defaults to True.
         """
         # simulate template volume
-        template = self.simulator.run(gpu_ids=self.computational_config.gpu_ids)
+        template = self.simulator.run(device=self.computational_config.gpu_ids)
 
         # The set of "best" euler angles from match template search
         # Check if refined angles exist, otherwise use the original angles
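
The practical effect of this one-line change is that the raw gpu_ids value from the config is forwarded as the simulator's device argument instead of a gpu_ids keyword. A rough sketch of the call shape, using a hypothetical stand-in for the ttsim3d simulator (the real Simulator.run signature lives in ttsim3d>=0.4.0 and is not reproduced here):

from leopard_em.pydantic_models.config import ComputationalConfig


class FakeSimulator:
    """Hypothetical stand-in; only illustrates the keyword used by the manager."""

    def run(self, device):
        # ttsim3d>=0.4.0 accepts the device specification directly (see diff above).
        print(f"simulating template volume on: {device!r}")


config = ComputationalConfig(gpu_ids=["cuda:0", "cuda:1"])
FakeSimulator().run(device=config.gpu_ids)  # mirrors make_backend_core_function_kwargs
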
Lines changed: 5 additions & 114 deletions
@@ -1,7 +1,6 @@
 """Tests for the ComputationalConfig model"""
 
 import pytest
-import torch
 from pydantic import ValidationError
 
 from leopard_em.pydantic_models.config import ComputationalConfig
@@ -18,122 +17,14 @@ def test_default_values():
     assert config.num_cpus == 1
 
 
-def test_single_gpu_id():
-    """
-    Test that a single integer gpu_id is converted to a list.
-
-    Verifies that when passing a single integer as gpu_ids, it's converted to a list
-    containing that integer.
-    """
-    config = ComputationalConfig(gpu_ids=1)
-    assert config.gpu_ids == [1]
-
-
-def test_multiple_gpu_ids():
-    """
-    Test that multiple gpu_ids are correctly stored as a list.
-
-    Verifies that when passing a list of gpu_ids, they are correctly stored in the
-    config.
-    """
-    config = ComputationalConfig(gpu_ids=[0, 1, 2])
-    assert config.gpu_ids == [0, 1, 2]
-
-
-def test_all_gpus():
-    """
-    Test the special value -1 for gpu_ids.
-
-    Verifies that when passing -1 as gpu_ids, it's stored correctly, which indicates
-    using all available GPUs.
-    """
-    config = ComputationalConfig(gpu_ids=-1)
-    assert config.gpu_ids == [-1]
-
-
-def test_cpu_only():
-    """
-    Test the special value -2 for gpu_ids.
-
-    Verifies that when passing -2 as gpu_ids, it's stored correctly, which indicates
-    using CPU only.
-    """
-    config = ComputationalConfig(gpu_ids=-2)
-    assert config.gpu_ids == [-2]
-
-
 def test_invalid_gpu_ids():
     """
-    Test that invalid combinations of gpu_ids raise errors.
-
-    Verifies that special values -1 and -2 cannot be combined with other gpu IDs.
-    """
-    with pytest.raises(ValueError, match="If -1"):
-        ComputationalConfig(gpu_ids=[-1, 0])
-
-    with pytest.raises(ValueError, match="If -2"):
-        ComputationalConfig(gpu_ids=[-2, 0])
-
-
-def test_num_cpus():
-    """
-    Test that num_cpus is correctly stored.
+    Test invalid gpu_ids values.
 
-    Verifies that the specified number of CPUs is correctly stored in the config.
-    """
-    config = ComputationalConfig(num_cpus=4)
-    assert config.num_cpus == 4
-
-
-def test_invalid_num_cpus():
-    """
-    Test that invalid num_cpus values raise errors.
-
-    Verifies that specifying zero or negative values for num_cpus raises a
-    ValidationError.
+    Verifies that a ValidationError is raised when invalid gpu_ids are provided.
     """
     with pytest.raises(ValidationError):
-        ComputationalConfig(num_cpus=0)
-
-
-def test_gpu_devices_single_gpu():
-    """
-    Test the gpu_devices property for a single GPU configuration.
-
-    Verifies that the correct torch.device object is created for a single GPU.
-    """
-    config = ComputationalConfig(gpu_ids=0)
-    assert config.gpu_devices == [torch.device("cuda:0")]
-
-
-def test_gpu_devices_multiple_gpus():
-    """
-    Test the gpu_devices property for multiple GPUs configuration.
+        ComputationalConfig(gpu_ids=[-1])  # Negative GPU ID is invalid
 
-    Verifies that the correct torch.device objects are created for multiple GPUs.
-    """
-    config = ComputationalConfig(gpu_ids=[0, 1])
-    assert config.gpu_devices == [torch.device("cuda:0"), torch.device("cuda:1")]
-
-
-def test_gpu_devices_all_gpus(monkeypatch):
-    """
-    Test the gpu_devices property when using all available GPUs.
-
-    Uses monkeypatch to set a fixed number of GPUs for testing, and verifies
-    that the correct device objects are created for all available GPUs.
-    """
-    monkeypatch.setattr(torch.cuda, "device_count", lambda: 2)
-    config = ComputationalConfig(gpu_ids=-1)
-    assert config.gpu_devices == [torch.device("cuda:0"), torch.device("cuda:1")]
-
-
-def test_gpu_devices_cpu():
-    """
-    Test the gpu_devices property in CPU-only mode.
-
-    Verifies that the correct torch.device object for CPU is created when using
-    CPU-only mode.
-    """
-    config = ComputationalConfig(gpu_ids=-2)
-    assert config.gpu_devices == [torch.device("cpu")]
+    with pytest.raises(ValidationError):
+        ComputationalConfig(gpu_ids=[])  # Empty list
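
A couple of companion tests one could add for the new string-based handling (not part of this commit; they only construct torch.device objects, so they run on machines without GPUs):

import torch

from leopard_em.pydantic_models.config import ComputationalConfig


def test_gpu_devices_from_device_strings():
    """Device specifier strings map directly to torch.device objects."""
    config = ComputationalConfig(gpu_ids=["cuda:0", "cuda:1"])
    assert config.gpu_devices == [torch.device("cuda:0"), torch.device("cuda:1")]


def test_gpu_devices_cpu_string():
    """The special string "cpu" resolves to a single CPU device."""
    config = ComputationalConfig(gpu_ids="cpu")
    assert config.gpu_devices == [torch.device("cpu")]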
