11 changes: 9 additions & 2 deletions comfy/ldm/flux/math.py
@@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     else:
         device = pos.device

-    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
-    omega = 1.0 / (theta**scale)
+    if device.type == "musa":
+        # XXX (MUSA): Unsupported tensor dtype in Neg: Double
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
+        if not isinstance(theta, torch.Tensor):
+            theta = torch.tensor(theta, dtype=torch.float32, device=device)
+        omega = torch.exp(-scale * torch.log(theta + 1e-6))
+    else:
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
+        omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
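Note on the math.py change above: the MUSA branch avoids float64 (the backend's Neg kernel rejects Double, per the XXX comment) by computing the RoPE frequencies in float32 through exp/log rather than a direct power. The two formulations agree up to the 1e-6 epsilon added for log stability. A minimal standalone sketch of the equivalence, using illustrative values rather than ComfyUI code:

import torch

dim, theta = 64, 10000.0  # illustrative values, not tied to any specific model
scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32)

omega_pow = 1.0 / (theta ** scale)                                      # original power form
omega_exp = torch.exp(-scale * torch.log(torch.tensor(theta) + 1e-6))   # float32-friendly exp/log form

# The epsilon perturbs theta by about 1e-10 relative, so the two agree to well within float32 precision.
print(torch.allclose(omega_pow, omega_exp, rtol=1e-4))  # True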
77 changes: 45 additions & 32 deletions comfy/model_management.py
@@ -139,34 +139,37 @@ def get_supported_float8_types():
 except:
     ixuca_available = False

+try:
+    import torchada # noqa: F401
+    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
+except:
+    musa_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU

 def is_intel_xpu():
     global cpu_state
     global xpu_available
     if cpu_state == CPUState.GPU:
-        if xpu_available:
-            return True
+        return xpu_available
     return False

 def is_ascend_npu():
     global npu_available
-    if npu_available:
-        return True
-    return False
+    return npu_available

 def is_mlu():
     global mlu_available
-    if mlu_available:
-        return True
-    return False
+    return mlu_available

 def is_ixuca():
     global ixuca_available
-    if ixuca_available:
-        return True
-    return False
+    return ixuca_available
+
+def is_musa():
+    global musa_available
+    return musa_available

 def get_torch_device():
     global directml_enabled
@@ -311,7 +314,7 @@ def amd_min_version(device=None, min_rdna_version=0):
     return False

 MIN_WEIGHT_MEMORY_RATIO = 0.4
-if is_nvidia():
+if is_nvidia() or is_musa():
     MIN_WEIGHT_MEMORY_RATIO = 0.0

 ENABLE_PYTORCH_ATTENTION = False
@@ -320,7 +323,7 @@ def amd_min_version(device=None, min_rdna_version=0):
 XFORMERS_IS_AVAILABLE = False

 try:
-    if is_nvidia():
+    if is_nvidia() or is_musa():
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
@@ -375,7 +378,7 @@ def amd_min_version(device=None, min_rdna_version=0):

 PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
-    if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
+    if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
         PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
         logging.info("Enabled fp16 accumulation.")
@@ -1020,7 +1023,7 @@ def force_channels_last():
     NUM_STREAMS = args.async_offload
 else:
     # Enable by default on Nvidia and AMD
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         NUM_STREAMS = 2

 if args.disable_async_offload:
@@ -1117,7 +1120,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
 if not args.disable_pinned_memory:
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         if WINDOWS:
             MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
         else:
@@ -1261,6 +1264,8 @@ def pytorch_attention_flash_attention():
             return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
         if is_ixuca():
             return True
+        if is_musa():
+            return True
     return False

 def force_upcast_attention_dtype():
@@ -1392,6 +1397,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if torch.version.hip:
         return True

+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 8:
         return True
@@ -1462,6 +1470,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
                 return True
             return False

+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)

     if is_mlu():
@@ -1484,25 +1495,27 @@ def supports_fp8_compute(device=None):
     if SUPPORT_FP8_OPS:
         return True

-    if not is_nvidia():
-        return False
-
     props = torch.cuda.get_device_properties(device)
-    if props.major >= 9:
-        return True
-    if props.major < 8:
-        return False
-    if props.minor < 9:
-        return False
-
-    if torch_version_numeric < (2, 3):
-        return False
+    if is_nvidia():
+        if props.major >= 9:
+            return True
+        if props.major < 8:
+            return False
+        if props.minor < 9:
+            return False

-    if WINDOWS:
-        if torch_version_numeric < (2, 4):
+        if torch_version_numeric < (2, 3):
             return False

-    return True
+        if WINDOWS:
+            if torch_version_numeric < (2, 4):
+                return False
+
+    elif is_musa():
+        if props.major >= 3:
+            return True
+
+    return False

 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
@@ -1543,7 +1556,7 @@ def unload_all_models():
     free_memory(1e30, get_torch_device())

 def debug_memory_summary():
-    if is_amd() or is_nvidia():
+    if is_amd() or is_nvidia() or is_musa():
         return torch.cuda.memory.memory_summary()
     return ""

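For reference, the availability probe added to model_management.py can be exercised on its own. The sketch below only mirrors the diff's logic; that importing torchada is what exposes torch.musa is an assumption taken from this PR, not from torchada's documentation:

import torch

def musa_is_available() -> bool:
    # Mirror of the model_management.py probe: import torchada, then check torch.musa.
    try:
        import torchada  # noqa: F401
    except Exception:  # the diff uses a bare except; any failure means "not available"
        return False
    return hasattr(torch, "musa") and torch.musa.is_available()

if __name__ == "__main__":
    print("MUSA available:", musa_is_available())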
1 change: 1 addition & 0 deletions requirements.txt
@@ -27,3 +27,4 @@ kornia>=0.7.1
 spandrel
 pydantic~=2.0
 pydantic-settings~=2.0
+torchada>=0.1.11