Skip to content

Commit 2664a21

Browse files
authored
[Metal] Make fast math mode optional (#3205)
* Add ability to toggle fast math mode in metal. Chose how to apply based on os version. * Move available macro and friends to utils * Isolate #[allow(deprecated)] to the actually deprecated method * doc * Use objc2::available macro instead
1 parent 08d7b64 commit 2664a21

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

candle-metal-kernels/src/kernel.rs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ use crate::source::{
22
AFFINE, BINARY, CAST, CONV, FILL, INDEXING, MLX_GEMM, MLX_SORT, QUANTIZED, RANDOM, REDUCE,
33
SDPA, SORT, TERNARY, UNARY,
44
};
5+
use crate::utils::get_env_bool;
56
use crate::{
6-
ComputePipeline, ConstantValues, Device, Function, Library, MTLCompileOptions, MTLMathMode,
7-
MetalKernelError, Source,
7+
ComputePipeline, ConstantValues, Device, Function, Library, MTLCompileOptions,
8+
MTLMathFloatingPointFunctions, MTLMathMode, MetalKernelError, Source,
89
};
10+
use objc2::available;
11+
use objc2::rc::Retained;
912
use std::collections::HashMap;
1013
use std::sync::RwLock;
1114

@@ -113,9 +116,7 @@ impl Kernels {
113116
} else {
114117
let lib = {
115118
let source_content = self.get_library_source(source);
116-
let compile_options = MTLCompileOptions::new();
117-
//unsafe { compile_options.setEnableLogging(true) };
118-
compile_options.setMathMode(MTLMathMode::Fast);
119+
let compile_options = get_compile_options();
119120
device
120121
.new_library_with_source(source_content, Some(&compile_options))
121122
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
@@ -176,3 +177,26 @@ impl Kernels {
176177
self.load_pipeline_with_constants(device, source, name, None)
177178
}
178179
}
180+
181+
fn get_compile_options() -> Retained<MTLCompileOptions> {
182+
let compile_options = MTLCompileOptions::new();
183+
//unsafe { compile_options.setEnableLogging(true) };
184+
185+
let fast_math_enabled = get_env_bool("CANDLE_METAL_ENABLE_FAST_MATH", true);
186+
// Ref availability:
187+
// https://developer.apple.com/documentation/metal/mtlcompileoptions/mathmode
188+
if available!(macos = 15, ios = 18) {
189+
if fast_math_enabled {
190+
compile_options.setMathMode(MTLMathMode::Fast);
191+
compile_options.setMathFloatingPointFunctions(MTLMathFloatingPointFunctions::Fast);
192+
} else {
193+
compile_options.setMathMode(MTLMathMode::Relaxed);
194+
compile_options.setMathFloatingPointFunctions(MTLMathFloatingPointFunctions::Precise);
195+
}
196+
} else {
197+
// For older OS versions we use the old api
198+
#[allow(deprecated)]
199+
compile_options.setFastMathEnabled(fast_math_enabled);
200+
}
201+
compile_options
202+
}

candle-metal-kernels/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use metal::{
1616
BlitCommandEncoder, Buffer, CommandQueue, ComputeCommandEncoder, ComputePipeline,
1717
ConstantValues, Device, Function, Library, MTLResourceOptions, Value,
1818
};
19-
use objc2_metal::{MTLCompileOptions, MTLMathMode, MTLSize};
19+
use objc2_metal::{MTLCompileOptions, MTLMathFloatingPointFunctions, MTLMathMode, MTLSize};
2020
use source::Source;
2121
pub use utils::BufferOffset;
2222
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};

candle-metal-kernels/src/utils.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::metal::{Buffer, CommandBuffer, ComputeCommandEncoder, ComputePipeline};
2-
use objc2_metal::MTLSize;
2+
use crate::MTLSize;
3+
use std::ffi::OsStr;
34
use std::ops::Deref;
45
use std::sync::{RwLockReadGuard, RwLockWriteGuard};
56

@@ -236,3 +237,14 @@ impl<'a, T> From<RwLockWriteGuard<'a, T>> for RwLockGuard<'a, T> {
236237
RwLockGuard::Write(g)
237238
}
238239
}
240+
241+
fn is_truthy(s: String) -> bool {
242+
match s.as_str() {
243+
"true" | "t" | "yes" | "y" | "1" => true,
244+
_ => false,
245+
}
246+
}
247+
248+
pub(crate) fn get_env_bool<K: AsRef<OsStr>>(key: K, default: bool) -> bool {
249+
std::env::var(key).map(is_truthy).unwrap_or(default)
250+
}

0 commit comments

Comments
 (0)