@@ -112,7 +112,7 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
112112 return [f"cuda:{ i } " for i in range (torch .cuda .device_count ())]
113113 elif is_torch_npu_available ():
114114 return [f"npu:{ i } " for i in range (torch .npu .device_count ())]
115- elif torch .musa .is_available ():
115+ elif hasattr ( torch , "musa" ) and torch .musa .is_available ():
116116 return [f"musa:{ i } " for i in range (torch .musa .device_count ())]
117117 elif torch .backends .mps .is_available ():
118118 return ["mps" ]
@@ -121,15 +121,15 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
121121 elif isinstance (devices , str ):
122122 return [devices ]
123123 elif isinstance (devices , int ):
124- if torch .musa .is_available ():
124+ if hasattr ( torch , "musa" ) and torch .musa .is_available ():
125125 return [f"musa:{ devices } " ]
126126 else :
127127 return [f"cuda:{ devices } " ]
128128 elif isinstance (devices , list ):
129129 if isinstance (devices [0 ], str ):
130130 return devices
131131 elif isinstance (devices [0 ], int ):
132- if torch .musa .is_available ():
132+ if hasattr ( torch , "musa" ) and torch .musa .is_available ():
133133 return [f"musa:{ device } " for device in devices ]
134134 else :
135135 return [f"cuda:{ device } " for device in devices ]
0 commit comments