Skip to content

Commit f78fc10

Browse files
committed
Use torchruntime's device utilities
1 parent 9d5a371 commit f78fc10

File tree

4 files changed

+7
-10
lines changed

4 files changed

+7
-10
lines changed

sdkit/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def __init__(self) -> None:
5858
https://github.com/sczhou/CodeFormer/blob/master/LICENSE
5959
"""
6060

61-
from sdkit.utils import get_torch_platform
61+
from torchruntime.utils import get_installed_torch_platform
6262

63-
self.device = get_torch_platform()[0]
63+
self.device = get_installed_torch_platform()[0]
6464

6565
# hacky approach, but we need to enforce full precision for some devices
6666
# we also need to force full precision for these devices (haven't implemented this yet):
@@ -73,7 +73,7 @@ def device(self):
7373
def device(self, d):
7474
self._device = d
7575

76-
from sdkit.utils import get_device
76+
from torchruntime.utils import get_device
7777

7878
if d.split(":")[0] in ("cpu", "mps"):
7979
from sdkit.utils import log

sdkit/utils/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@
4444
)
4545
from .device_utils import (
4646
has_amd_gpu,
47-
get_torch_platform,
48-
get_device_count,
49-
get_device_name,
50-
get_device,
5147
mem_get_info,
5248
memory_allocated,
5349
memory_stats,

sdkit/utils/memory_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def gc(context: Context):
2020

2121

2222
def get_device_usage(device, log_info=False, process_usage_only=True, log_prefix=""):
23-
import torch
24-
from sdkit.utils import log, is_cpu_device, get_device, mem_get_info, memory_allocated, memory_stats
23+
from torchruntime.utils import get_device
24+
from sdkit.utils import log, is_cpu_device, mem_get_info, memory_allocated, memory_stats
2525

2626
if isinstance(device, str):
2727
device = get_device(device)
@@ -120,7 +120,7 @@ def get_tensors_in_memory(device):
120120
prevent garbage-collection of all the tensors in memory.**
121121
"""
122122
import torch
123-
from .device_utils import get_device
123+
from torchruntime.utils import get_device
124124

125125
if isinstance(device, str):
126126
device = get_device(device)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@
1717
"controlnet-aux==0.0.6",
1818
"invisible-watermark==0.2.0", # required for SD XL
1919
"huggingface_hub==0.21.4",
20+
"torchruntime>=1.7.0",
2021
],
2122
)

0 commit comments

Comments
 (0)