Features:
- PyTorch CUDAExtension for reliable NVCC compilation
- Automatic fallback to CPU if CUDA/ROCm unavailable
- Smart Architecture Detection: compiles only for the active GPU to save RAM/time
- MAX_JOBS control to prevent OOM
"""
1011
1112import os
# VERSION
# ============================================================================

VERSION = "4.2.6"

# ============================================================================
# PRE-FLIGHT CHECKS
# ============================================================================

# Default to a serial build (MAX_JOBS=1) to prevent out-of-memory failures on
# small machines (e.g. Colab free tier). An explicit MAX_JOBS in the caller's
# environment always takes precedence.
os.environ["MAX_JOBS"] = os.environ.get("MAX_JOBS", "1")
2930
def log(msg: str, level: str = "INFO") -> None:
    """Print one build-log line, flushed immediately for live CI/pip output.

    Args:
        msg: Message text to print.
        level: Severity tag. Previously accepted but ignored; non-INFO levels
            are now surfaced in the prefix so warnings/errors stand out.
    """
    # INFO keeps the historical "[CRAYON-BUILD] msg" format unchanged.
    tag = "[CRAYON-BUILD]" if level == "INFO" else f"[CRAYON-BUILD][{level}]"
    # flush=True: setup.py output is often captured/buffered by pip and CI.
    print(f"{tag} {msg}", flush=True)
@@ -49,6 +50,45 @@ def log(msg: str, level: str = "INFO") -> None:
# A ROCm toolchain counts as present only when the hipcc compiler binary
# exists under ROCM_HOME (located during the pre-flight checks above).
HAS_ROCM = os.path.exists(os.path.join(ROCM_HOME, "bin", "hipcc"))
5051
5152
# ============================================================================
# ARCHITECTURE SELECTION
# ============================================================================

def get_cuda_arch_flags():
    """Return the NVCC flags selecting which CUDA architectures to build.

    Behavior:
      * CRAYON_GENERIC_BUILD=1 -> compile for all common architectures
        (used when producing a redistributable PyPI wheel).
      * Otherwise -> compile ONLY for the locally detected GPU, which is
        faster and uses far less RAM than a multi-arch build.
      * If detection fails (or no GPU is visible), fall back to sm_75 (T4),
        a safe default for Colab.

    Returns:
        list[str]: base optimization flags plus one or more ``-gencode`` flags.
    """
    base_flags = ["-O3", "-std=c++17", "--expt-relaxed-constexpr"]

    # Generic multi-arch build for distribution (wheel).
    if os.environ.get("CRAYON_GENERIC_BUILD", "0") == "1":
        log("Building for ALL common CUDA architectures (Generic Wheel)")
        return base_flags + [
            "-gencode=arch=compute_70,code=sm_70",  # V100
            "-gencode=arch=compute_75,code=sm_75",  # T4
            "-gencode=arch=compute_80,code=sm_80",  # A100
            "-gencode=arch=compute_86,code=sm_86",  # RTX 3090
            "-gencode=arch=compute_90,code=sm_90",  # H100
        ]

    # Local build (Colab / user machine): target only the active GPU.
    if TORCH_CUDA_AVAILABLE:
        try:
            major, minor = torch.cuda.get_device_capability()
            # e.g. (7, 5) -> "75"; no separators, or the -gencode flag breaks.
            arch = f"{major}{minor}"
            log(f"Detected GPU: SM {major}.{minor} -> Compiling for sm_{arch} ONLY")
            return base_flags + [f"-gencode=arch=compute_{arch},code=sm_{arch}"]
        except Exception as e:
            log(f"Error detecting GPU capability: {e}. Falling back to common archs.")

    # Fallback when detection fails or no GPU is present (but CUDA_HOME exists).
    return base_flags + [
        "-gencode=arch=compute_75,code=sm_75",  # T4 (safe default for Colab)
    ]
5292# ============================================================================
5393# EXTENSION CONFIGURATION
5494# ============================================================================
@@ -73,24 +113,18 @@ def log(msg: str, level: str = "INFO") -> None:
73113
# --- 2. CUDA Extension (via PyTorch) ---
if TORCH_CUDA_AVAILABLE and not FORCE_CPU and CUDAExtension:
    # Pick -gencode flags for the detected GPU (or all archs for wheels).
    nvcc_flags = get_cuda_arch_flags()
    log(f"Configuring CUDA extension (max_jobs={os.environ['MAX_JOBS']})")

    ext_modules.append(CUDAExtension(
        name="crayon.c_ext.crayon_cuda",
        sources=["src/crayon/c_ext/gpu_engine_cuda.cu"],
        extra_compile_args={
            # Host (C++) compiler flags.
            "cxx": ["-O3", "-std=c++17"],
            # Device (NVCC) flags, including the selected -gencode set.
            "nvcc": nvcc_flags,
        },
    ))

elif not FORCE_CPU and CUDAExtension:
    log("Skipping CUDA extension (PyTorch CUDA not found or CUDA_HOME missing)")
0 commit comments