Skip to content

Commit 7cb170d

Browse files
committed
Add conditional install for cuda
1 parent 0fdf3f7 commit 7cb170d

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

install.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,25 @@ def pip_install(package: str) -> None:
1212
)
1313

1414

15+
def try_get_cuda_version() -> str | None:
16+
try:
17+
import torch
18+
return torch.version.cuda
19+
except ImportError or AttributeError:
20+
return None
21+
22+
1523
def main() -> None:
16-
pip_install("-e .")
24+
25+
cuda_version = try_get_cuda_version()
26+
if cuda_version is not None:
27+
if cuda_version.startswith("12."):
28+
pip_install("-e .[cuda-12]")
29+
else:
30+
pip_install("-e .[cuda]")
31+
else:
32+
# Default install
33+
pip_install("-e .[cpu]")
1734

1835

1936
if __name__ == "__main__":

src/inference_core_nodes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ("__version__", "NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS")
22

3-
__version__ = "0.2.0"
3+
__version__ = "0.2.1"
44

55

66
def _get_node_mappings():

0 commit comments

Comments
 (0)