|
11 | 11 | def nccl_import(message: Optional[str] = None) -> str: |
12 | 12 | nccl_test = ( |
13 | 13 | # detect if nccl is available and the user is expecting it to be used |
14 | | - # CuPy must be checked first otherwise util.find_spec assumes it presents and check nccl immediately and lead to crash |
15 | | - util.find_spec("cupy") is not None and util.find_spec("cupy.cuda.nccl") is not None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1 |
| 14 | + # cupy.cuda.nccl comes with cupy installation so check the cupy |
| 15 | + util.find_spec("cupy") is not None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1 |
16 | 16 | ) |
17 | 17 | if nccl_test: |
18 | | - # try importing it |
19 | | - try: |
20 | | - import_module("cupy.cuda.nccl") # noqa: F401 |
21 | | - |
22 | | - # if succesful, set the message to None |
| 18 | + # if cupy is present, this import will not throw error. The NCCL existence is checked with nccl.avaiable |
| 19 | + import cupy.cuda.nccl as nccl |
| 20 | + if nccl.available: |
| 21 | + # if succesfull, set the message to None |
23 | 22 | nccl_message = None |
24 | | - # if unable to import but the package is installed |
25 | | - except (ImportError, ModuleNotFoundError) as e: |
| 23 | + else: |
| 24 | + # if unable to import but the package is installed |
26 | 25 | nccl_message = ( |
27 | | - f"Fail to import cupy.cuda.nccl, Falling back to pure MPI (error: {e})." |
28 | | - "Please ensure your CUDA NCCL environment is set up correctly " |
29 | | - "for more detials visit 'https://docs.cupy.dev/en/stable/install.html'" |
30 | | - ) |
| 26 | + f"cupy is installed but cupy.cuda.nccl not available, Falling back to pure MPI." |
| 27 | + "Please ensure your CUDA NCCL environment is set up correctly " |
| 28 | + "for more details visit 'https://docs.cupy.dev/en/stable/install.html'" |
| 29 | + ) |
31 | 30 | print(UserWarning(nccl_message)) |
32 | 31 | else: |
33 | 32 | nccl_message = ( |
|
0 commit comments