Skip to content

Commit 89253e9

Browse files
BiologicalExplosion and huzhan authored
Support Cambricon MLU (#6964)
Co-authored-by: huzhan <huzhan@cambricon.com>
1 parent 3ea3bc8 commit 89253e9

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,13 @@ For models compatible with Ascend Extension for PyTorch (torch_npu). To get star
260260
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
261261
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
262262

263+
#### Cambricon MLUs
264+
265+
For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a step-by-step guide tailored to your platform and installation method:
266+
267+
1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html) page.
268+
2. Next, install PyTorch (torch_mlu) by following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html) page.
269+
3. Launch ComfyUI by running `python main.py --listen`
263270

264271
# Running
265272

comfy/model_management.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ class CPUState(Enum):
9595
except:
9696
npu_available = False
9797

98+
# Detect Cambricon MLU support. Importing torch_mlu registers the torch.mlu
# backend; probing device_count() forces driver/runtime initialization so a
# broken install is caught here instead of failing later at first use.
try:
    import torch_mlu  # noqa: F401
    _ = torch.mlu.device_count()
    mlu_available = torch.mlu.is_available()
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
    # swallowed; any import or runtime failure simply means "no MLU".
    mlu_available = False
104+
98105
if args.cpu:
99106
cpu_state = CPUState.CPU
100107

@@ -112,6 +119,12 @@ def is_ascend_npu():
112119
return True
113120
return False
114121

122+
def is_mlu():
    """Report whether a Cambricon MLU backend was detected at import time."""
    global mlu_available
    # mlu_available is set once during module import; just coerce it to bool.
    return bool(mlu_available)
127+
115128
def get_torch_device():
116129
global directml_enabled
117130
global cpu_state
@@ -127,6 +140,8 @@ def get_torch_device():
127140
return torch.device("xpu", torch.xpu.current_device())
128141
elif is_ascend_npu():
129142
return torch.device("npu", torch.npu.current_device())
143+
elif is_mlu():
144+
return torch.device("mlu", torch.mlu.current_device())
130145
else:
131146
return torch.device(torch.cuda.current_device())
132147

@@ -153,6 +168,12 @@ def get_total_memory(dev=None, torch_total_too=False):
153168
_, mem_total_npu = torch.npu.mem_get_info(dev)
154169
mem_total_torch = mem_reserved
155170
mem_total = mem_total_npu
171+
elif is_mlu():
172+
stats = torch.mlu.memory_stats(dev)
173+
mem_reserved = stats['reserved_bytes.all.current']
174+
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
175+
mem_total_torch = mem_reserved
176+
mem_total = mem_total_mlu
156177
else:
157178
stats = torch.cuda.memory_stats(dev)
158179
mem_reserved = stats['reserved_bytes.all.current']
@@ -232,7 +253,7 @@ def is_amd():
232253
if torch_version_numeric[0] >= 2:
233254
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
234255
ENABLE_PYTORCH_ATTENTION = True
235-
if is_intel_xpu() or is_ascend_npu():
256+
if is_intel_xpu() or is_ascend_npu() or is_mlu():
236257
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
237258
ENABLE_PYTORCH_ATTENTION = True
238259
except:
@@ -316,6 +337,8 @@ def get_torch_device_name(device):
316337
return "{} {}".format(device, torch.xpu.get_device_name(device))
317338
elif is_ascend_npu():
318339
return "{} {}".format(device, torch.npu.get_device_name(device))
340+
elif is_mlu():
341+
return "{} {}".format(device, torch.mlu.get_device_name(device))
319342
else:
320343
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
321344

@@ -905,6 +928,8 @@ def xformers_enabled():
905928
return False
906929
if is_ascend_npu():
907930
return False
931+
if is_mlu():
932+
return False
908933
if directml_enabled:
909934
return False
910935
return XFORMERS_IS_AVAILABLE
@@ -936,6 +961,8 @@ def pytorch_attention_flash_attention():
936961
return True
937962
if is_ascend_npu():
938963
return True
964+
if is_mlu():
965+
return True
939966
if is_amd():
940967
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
941968
return False
@@ -984,6 +1011,13 @@ def get_free_memory(dev=None, torch_free_too=False):
9841011
mem_free_npu, _ = torch.npu.mem_get_info(dev)
9851012
mem_free_torch = mem_reserved - mem_active
9861013
mem_free_total = mem_free_npu + mem_free_torch
1014+
elif is_mlu():
1015+
stats = torch.mlu.memory_stats(dev)
1016+
mem_active = stats['active_bytes.all.current']
1017+
mem_reserved = stats['reserved_bytes.all.current']
1018+
mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
1019+
mem_free_torch = mem_reserved - mem_active
1020+
mem_free_total = mem_free_mlu + mem_free_torch
9871021
else:
9881022
stats = torch.cuda.memory_stats(dev)
9891023
mem_active = stats['active_bytes.all.current']
@@ -1053,6 +1087,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
10531087
if is_ascend_npu():
10541088
return True
10551089

1090+
if is_mlu():
1091+
return True
1092+
10561093
if torch.version.hip:
10571094
return True
10581095

@@ -1121,6 +1158,11 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
11211158
return False
11221159

11231160
props = torch.cuda.get_device_properties(device)
1161+
1162+
if is_mlu():
1163+
if props.major > 3:
1164+
return True
1165+
11241166
if props.major >= 8:
11251167
return True
11261168

0 commit comments

Comments
 (0)