1313import numpy as np
1414from transformers import is_torch_npu_available
1515
16+ try :
17+ import torch_musa
18+ except Exception :
19+ pass
20+
1621logger = logging .getLogger (__name__ )
1722
1823
@@ -106,6 +111,8 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
106111 return [f"cuda:{ i } " for i in range (torch .cuda .device_count ())]
107112 elif is_torch_npu_available ():
108113 return [f"npu:{ i } " for i in range (torch .npu .device_count ())]
114+ elif hasattr (torch , "musa" ) and torch .musa .is_available ():
115+ return [f"musa:{ i } " for i in range (torch .musa .device_count ())]
109116 elif torch .backends .mps .is_available ():
110117 try :
111118 return [f"mps:{ i } " for i in range (torch .mps .device_count ())]
@@ -116,12 +123,18 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
116123 elif isinstance (devices , str ):
117124 return [devices ]
118125 elif isinstance (devices , int ):
119- return [f"cuda:{ devices } " ]
126+ if hasattr (torch , "musa" ) and torch .musa .is_available ():
127+ return [f"musa:{ devices } " ]
128+ else :
129+ return [f"cuda:{ devices } " ]
120130 elif isinstance (devices , list ):
121131 if isinstance (devices [0 ], str ):
122132 return devices
123133 elif isinstance (devices [0 ], int ):
124- return [f"cuda:{ device } " for device in devices ]
134+ if hasattr (torch , "musa" ) and torch .musa .is_available ():
135+ return [f"musa:{ device } " for device in devices ]
136+ else :
137+ return [f"cuda:{ device } " for device in devices ]
125138 else :
126139 raise ValueError ("devices should be a string or an integer or a list of strings or a list of integers." )
127140 else :
0 commit comments