@@ -2,10 +2,13 @@ use crate::source::{
22 AFFINE , BINARY , CAST , CONV , FILL , INDEXING , MLX_GEMM , MLX_SORT , QUANTIZED , RANDOM , REDUCE ,
33 SDPA , SORT , TERNARY , UNARY ,
44} ;
5+ use crate :: utils:: get_env_bool;
56use crate :: {
6- ComputePipeline , ConstantValues , Device , Function , Library , MTLCompileOptions , MTLMathMode ,
7- MetalKernelError , Source ,
7+ ComputePipeline , ConstantValues , Device , Function , Library , MTLCompileOptions ,
8+ MTLMathFloatingPointFunctions , MTLMathMode , MetalKernelError , Source ,
89} ;
10+ use objc2:: available;
11+ use objc2:: rc:: Retained ;
912use std:: collections:: HashMap ;
1013use std:: sync:: RwLock ;
1114
@@ -113,9 +116,7 @@ impl Kernels {
113116 } else {
114117 let lib = {
115118 let source_content = self . get_library_source ( source) ;
116- let compile_options = MTLCompileOptions :: new ( ) ;
117- //unsafe { compile_options.setEnableLogging(true) };
118- compile_options. setMathMode ( MTLMathMode :: Fast ) ;
119+ let compile_options = get_compile_options ( ) ;
119120 device
120121 . new_library_with_source ( source_content, Some ( & compile_options) )
121122 . map_err ( |e| MetalKernelError :: LoadLibraryError ( e. to_string ( ) ) ) ?
@@ -176,3 +177,26 @@ impl Kernels {
176177 self . load_pipeline_with_constants ( device, source, name, None )
177178 }
178179}
180+
181+ fn get_compile_options ( ) -> Retained < MTLCompileOptions > {
182+ let compile_options = MTLCompileOptions :: new ( ) ;
183+ //unsafe { compile_options.setEnableLogging(true) };
184+
185+ let fast_math_enabled = get_env_bool ( "CANDLE_METAL_ENABLE_FAST_MATH" , true ) ;
186+ // Ref availability:
187+ // https://developer.apple.com/documentation/metal/mtlcompileoptions/mathmode
188+ if available ! ( macos = 15 , ios = 18 ) {
189+ if fast_math_enabled {
190+ compile_options. setMathMode ( MTLMathMode :: Fast ) ;
191+ compile_options. setMathFloatingPointFunctions ( MTLMathFloatingPointFunctions :: Fast ) ;
192+ } else {
193+ compile_options. setMathMode ( MTLMathMode :: Relaxed ) ;
194+ compile_options. setMathFloatingPointFunctions ( MTLMathFloatingPointFunctions :: Precise ) ;
195+ }
196+ } else {
197+ // For older OS versions we use the old api
198+ #[ allow( deprecated) ]
199+ compile_options. setFastMathEnabled ( fast_math_enabled) ;
200+ }
201+ compile_options
202+ }
0 commit comments