Skip to content

Commit 44e5525

Browse files
authored
Merge pull request #1350 from qiyulei-mt/musa_support
support musa backend in FlagEmbedding
2 parents fdc6786 + 6c9dba5 commit 44e5525

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
import numpy as np
1414
from transformers import is_torch_npu_available
1515

16+
try:
17+
import torch_musa
18+
except Exception:
19+
pass
20+
1621
logger = logging.getLogger(__name__)
1722

1823

@@ -106,6 +111,8 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
106111
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
107112
elif is_torch_npu_available():
108113
return [f"npu:{i}" for i in range(torch.npu.device_count())]
114+
elif hasattr(torch, "musa") and torch.musa.is_available():
115+
return [f"musa:{i}" for i in range(torch.musa.device_count())]
109116
elif torch.backends.mps.is_available():
110117
try:
111118
return [f"mps:{i}" for i in range(torch.mps.device_count())]
@@ -116,12 +123,18 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
116123
elif isinstance(devices, str):
117124
return [devices]
118125
elif isinstance(devices, int):
119-
return [f"cuda:{devices}"]
126+
if hasattr(torch, "musa") and torch.musa.is_available():
127+
return [f"musa:{devices}"]
128+
else:
129+
return [f"cuda:{devices}"]
120130
elif isinstance(devices, list):
121131
if isinstance(devices[0], str):
122132
return devices
123133
elif isinstance(devices[0], int):
124-
return [f"cuda:{device}" for device in devices]
134+
if hasattr(torch, "musa") and torch.musa.is_available():
135+
return [f"musa:{device}" for device in devices]
136+
else:
137+
return [f"cuda:{device}" for device in devices]
125138
else:
126139
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
127140
else:

FlagEmbedding/abc/inference/AbsReranker.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from tqdm import tqdm, trange
1313
from transformers import is_torch_npu_available
1414

15+
try:
16+
import torch_musa
17+
except Exception:
18+
pass
19+
1520
logger = logging.getLogger(__name__)
1621

1722

@@ -107,19 +112,27 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
107112
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
108113
elif is_torch_npu_available():
109114
return [f"npu:{i}" for i in range(torch.npu.device_count())]
115+
elif hasattr(torch, "musa") and torch.musa.is_available():
116+
return [f"musa:{i}" for i in range(torch.musa.device_count())]
110117
elif torch.backends.mps.is_available():
111118
return ["mps"]
112119
else:
113120
return ["cpu"]
114121
elif isinstance(devices, str):
115122
return [devices]
116123
elif isinstance(devices, int):
117-
return [f"cuda:{devices}"]
124+
if hasattr(torch, "musa") and torch.musa.is_available():
125+
return [f"musa:{devices}"]
126+
else:
127+
return [f"cuda:{devices}"]
118128
elif isinstance(devices, list):
119129
if isinstance(devices[0], str):
120130
return devices
121131
elif isinstance(devices[0], int):
122-
return [f"cuda:{device}" for device in devices]
132+
if hasattr(torch, "musa") and torch.musa.is_available():
133+
return [f"musa:{device}" for device in devices]
134+
else:
135+
return [f"cuda:{device}" for device in devices]
123136
else:
124137
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
125138
else:

0 commit comments

Comments
 (0)