33from __future__ import annotations
44
55import sys
6- from typing import NoReturn
76
87import numpy as np
98import torch
1211from tiatoolbox import logger
1312
1413
15- def is_torch_compile_compatible () -> NoReturn :
14+ def is_torch_compile_compatible () -> bool :
1615 """Check if the current GPU is compatible with torch-compile.
1716
17+ Returns:
18+ True if current GPU is compatible with torch-compile, False otherwise.
19+
1820 Raises:
1921 Warning if GPU is not compatible with `torch.compile`.
2022
2123 """
24+ gpu_compatibility = True
2225 if torch .cuda .is_available (): # pragma: no cover
2326 device_cap = torch .cuda .get_device_capability ()
2427 if device_cap not in ((7 , 0 ), (8 , 0 ), (9 , 0 )):
@@ -28,13 +31,17 @@ def is_torch_compile_compatible() -> NoReturn:
2831 "Speedup numbers may be lower than expected." ,
2932 stacklevel = 2 ,
3033 )
34+ gpu_compatibility = False
3135 else :
3236 logger .warning (
3337 "No GPU detected or cuda not installed, "
3438 "torch.compile is only supported on selected NVIDIA GPUs. "
3539 "Speedup numbers may be lower than expected." ,
3640 stacklevel = 2 ,
3741 )
42+ gpu_compatibility = False
43+
44+ return gpu_compatibility
3845
3946
4047def compile_model (
@@ -68,12 +75,24 @@ def compile_model(
6875 return model
6976
7077 # Check if GPU is compatible with torch.compile
71- is_torch_compile_compatible ()
78+ gpu_compatibility = is_torch_compile_compatible ()
79+
80+ if not gpu_compatibility :
81+ return model
82+
83+ if sys .platform == "win32" : # pragma: no cover
84+ msg = (
85+ "`torch.compile` is not supported on Windows. Please see "
86+ "https://github.com/pytorch/pytorch/issues/122094."
87+ )
88+ logger .warning (msg = msg )
89+ return model
7290
7391 # This check will be removed when torch.compile is supported in Python 3.12+
7492 if sys .version_info > (3 , 12 ): # pragma: no cover
93+ msg = "torch-compile is currently not supported in Python 3.12+."
7594 logger .warning (
76- ( "torch-compile is currently not supported in Python 3.12+. " ,) ,
95+ msg = msg ,
7796 )
7897 return model
7998
0 commit comments