Skip to content

Commit a7fbc63

Browse files
committed
Merge branch 'main' into metal-tensor-fix-send-sync
2 parents 0bbf9c7 + 0950959 commit a7fbc63

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+3405
-3770
lines changed

candle-core/src/cpu/erf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mod evaluate {
77
//! Provides functions that don't have a numerical solution and must
88
//! be solved computationally (e.g. evaluation of a polynomial)
99
10-
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
10+
/// evaluates a polynomial at `z` where `coeff` are the coefficients
1111
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
1212
/// coefficient
1313
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to

candle-core/src/custom_op.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ pub struct UgIOp1 {
381381
#[cfg(feature = "cuda")]
382382
func: cudarc::driver::CudaFunction,
383383
#[cfg(feature = "metal")]
384-
func: candle_metal_kernels::metal_utils::ComputePipeline,
384+
func: candle_metal_kernels::metal::ComputePipeline,
385385
}
386386

387387
impl UgIOp1 {

candle-core/src/metal_backend/device.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{DType, Result};
22
use candle_metal_kernels::{
3-
metal_utils::{
3+
metal::{
44
Buffer, BufferMap, CommandBuffer, Commands, ComputePipeline, Device, MTLResourceOptions,
55
},
66
Kernels,

candle-core/src/metal_backend/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvT
55
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
66
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
77
use candle_metal_kernels::{
8-
metal_utils::{Buffer, Commands, Device, MTLResourceOptions},
8+
metal::{Buffer, Commands, Device, MTLResourceOptions},
99
BufferOffset, CallConvTranspose2dCfg, Kernels,
1010
};
1111
use objc2_foundation::NSRange;
@@ -1832,7 +1832,7 @@ impl MetalStorage {
18321832
let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
18331833
let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
18341834
let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
1835-
use candle_metal_kernels::binary::contiguous;
1835+
use candle_metal_kernels::kernels::binary::contiguous;
18361836

18371837
let (kernel_name, dtype) = match (op, self.dtype) {
18381838
("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
@@ -1919,7 +1919,7 @@ impl MetalStorage {
19191919
.map_err(MetalError::from)?;
19201920
(buffer, dtype)
19211921
} else {
1922-
use candle_metal_kernels::binary::strided;
1922+
use candle_metal_kernels::kernels::binary::strided;
19231923

19241924
let (kernel_name, dtype) = match (op, self.dtype) {
19251925
("badd", DType::F32) => (strided::add::FLOAT, self.dtype),

candle-core/src/quantized/metal.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::{GgmlDType, QStorage};
22
use crate::backend::BackendStorage;
33
use crate::{DType, MetalDevice, MetalStorage, Result, Shape, D};
4-
use candle_metal_kernels::metal_utils::Buffer;
4+
use candle_metal_kernels::metal::Buffer;
55
use std::sync::Arc;
66

77
pub struct QMetalStorage {

candle-core/tests/pth_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ fn test_pth_with_key() {
1414
}
1515

1616
#[test]
17-
fn test_pth_fortran_congiguous() {
17+
fn test_pth_fortran_contiguous() {
1818
let tensors =
1919
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
2020
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();

candle-metal-kernels/examples/metal_benchmarks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use anyhow::Result;
22
use candle_metal_kernels::{
3-
metal_utils::{create_command_buffer, Device},
3+
metal::{create_command_buffer, Device},
44
GemmDType,
55
};
66
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.

candle-metal-kernels/src/err.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use crate::kernels::sdpa::SdpaDType;
2+
3+
#[derive(thiserror::Error, Debug)]
4+
pub enum MetalKernelError {
5+
#[error("Could not lock kernel map: {0}")]
6+
LockError(String),
7+
#[error("Error while loading library: {0}")]
8+
LoadLibraryError(String),
9+
#[error("Error while loading function: {0}")]
10+
LoadFunctionError(String),
11+
#[error("Failed to create compute function")]
12+
FailedToCreateComputeFunction,
13+
#[error("Failed to create metal resource: {0}")]
14+
FailedToCreateResource(String),
15+
#[error("Failed to create pipeline")]
16+
FailedToCreatePipeline(String),
17+
#[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")]
18+
MatMulNonContiguous {
19+
lhs_stride: Vec<usize>,
20+
rhs_stride: Vec<usize>,
21+
mnk: (usize, usize, usize),
22+
},
23+
#[error("Sdpa {variation} head size was {got}, expectd {expected:?}")]
24+
SdpaHeadSizeMismatch {
25+
variation: &'static str,
26+
got: usize,
27+
expected: Vec<usize>,
28+
},
29+
#[error("Sdpa {variation} got dtype {got:?}")]
30+
SdpaHeadDTypeMismatch {
31+
variation: &'static str,
32+
got: SdpaDType,
33+
},
34+
}
35+
36+
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
37+
fn from(e: std::sync::PoisonError<T>) -> Self {
38+
Self::LockError(e.to_string())
39+
}
40+
}

candle-metal-kernels/src/kernel.rs

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
use crate::source::{
2+
AFFINE, BINARY, CAST, CONV, FILL, INDEXING, MLX_GEMM, MLX_SORT, QUANTIZED, RANDOM, REDUCE,
3+
SDPA, SORT, TERNARY, UNARY,
4+
};
5+
use crate::{
6+
ComputePipeline, ConstantValues, Device, Function, Library, MTLCompileOptions, MTLMathMode,
7+
MetalKernelError, Source,
8+
};
9+
use std::collections::HashMap;
10+
use std::sync::RwLock;
11+
12+
#[derive(Debug, Clone)]
13+
pub enum KernelName {
14+
Ref(&'static str),
15+
Value(String),
16+
}
17+
18+
impl AsRef<str> for KernelName {
19+
fn as_ref(&self) -> &str {
20+
match self {
21+
Self::Ref(r) => r,
22+
Self::Value(v) => v.as_str(),
23+
}
24+
}
25+
}
26+
27+
impl std::hash::Hash for KernelName {
28+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
29+
match self {
30+
Self::Ref(r) => r.hash(state),
31+
Self::Value(v) => v.hash(state),
32+
}
33+
}
34+
}
35+
36+
impl PartialEq for KernelName {
37+
fn eq(&self, other: &Self) -> bool {
38+
let v1: &str = self.as_ref();
39+
let v2: &str = other.as_ref();
40+
v1 == v2
41+
}
42+
}
43+
44+
impl Eq for KernelName {}
45+
46+
impl From<&'static str> for KernelName {
47+
fn from(value: &'static str) -> Self {
48+
Self::Ref(value)
49+
}
50+
}
51+
52+
impl From<String> for KernelName {
53+
fn from(value: String) -> Self {
54+
Self::Value(value)
55+
}
56+
}
57+
58+
type Libraries = HashMap<Source, Library>;
59+
type Pipelines = HashMap<(KernelName, Option<ConstantValues>), ComputePipeline>;
60+
61+
#[derive(Debug)]
62+
pub struct Kernels {
63+
libraries: RwLock<Libraries>,
64+
pipelines: RwLock<Pipelines>,
65+
}
66+
67+
impl Default for Kernels {
68+
fn default() -> Self {
69+
Self::new()
70+
}
71+
}
72+
73+
impl Kernels {
74+
pub fn new() -> Self {
75+
let libraries = RwLock::new(Libraries::new());
76+
let pipelines = RwLock::new(Pipelines::new());
77+
Self {
78+
libraries,
79+
pipelines,
80+
}
81+
}
82+
83+
fn get_library_source(&self, source: Source) -> &'static str {
84+
match source {
85+
Source::Affine => AFFINE,
86+
Source::Binary => BINARY,
87+
Source::Cast => CAST,
88+
Source::Conv => CONV,
89+
Source::Fill => FILL,
90+
Source::Gemm => MLX_GEMM,
91+
Source::Indexing => INDEXING,
92+
Source::MlxSort => MLX_SORT,
93+
Source::Quantized => QUANTIZED,
94+
Source::Random => RANDOM,
95+
Source::Reduce => REDUCE,
96+
Source::Sort => SORT,
97+
Source::Ternary => TERNARY,
98+
Source::Unary => UNARY,
99+
Source::Sdpa => SDPA,
100+
}
101+
}
102+
103+
/// Load the given library from its [`source`].
104+
/// If this has been previously loaded it will just fetch it from cache.
105+
pub fn load_library(
106+
&self,
107+
device: &Device,
108+
source: Source,
109+
) -> Result<Library, MetalKernelError> {
110+
let mut libraries = self.libraries.write()?;
111+
if let Some(lib) = libraries.get(&source) {
112+
Ok(lib.clone())
113+
} else {
114+
let lib = {
115+
let source_content = self.get_library_source(source);
116+
let compile_options = MTLCompileOptions::new();
117+
//unsafe { compile_options.setEnableLogging(true) };
118+
unsafe { compile_options.setMathMode(MTLMathMode::Fast) };
119+
device
120+
.new_library_with_source(source_content, Some(&compile_options))
121+
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
122+
};
123+
libraries.insert(source, lib.clone());
124+
Ok(lib)
125+
}
126+
}
127+
128+
fn load_function(
129+
&self,
130+
device: &Device,
131+
source: Source,
132+
name: &str,
133+
constants: Option<&ConstantValues>,
134+
) -> Result<Function, MetalKernelError> {
135+
let func = self
136+
.load_library(device, source)?
137+
.get_function(name, constants)
138+
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
139+
Ok(func)
140+
}
141+
142+
/// Load the given pipeline
143+
/// loads the library from source, then gets the function [`name`] from
144+
/// that source
145+
pub fn load_pipeline_with_constants(
146+
&self,
147+
device: &Device,
148+
source: Source,
149+
name: impl Into<KernelName>,
150+
constants: Option<ConstantValues>,
151+
) -> Result<ComputePipeline, MetalKernelError> {
152+
let mut pipelines = self.pipelines.write()?;
153+
let key = (name.into(), constants);
154+
if let Some(pipeline) = pipelines.get(&key) {
155+
Ok(pipeline.clone())
156+
} else {
157+
let (name, constants) = key;
158+
let func = self.load_function(device, source, name.as_ref(), constants.as_ref())?;
159+
let pipeline = device
160+
.new_compute_pipeline_state_with_function(&func)
161+
.map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
162+
pipelines.insert((name, constants), pipeline.clone());
163+
164+
Ok(pipeline)
165+
}
166+
}
167+
168+
/// Load the given pipeline
169+
/// loads the library from source, then gets the function [`name`] from
170+
/// that source (without constants)
171+
pub fn load_pipeline(
172+
&self,
173+
device: &Device,
174+
source: Source,
175+
name: impl Into<KernelName>,
176+
) -> Result<ComputePipeline, MetalKernelError> {
177+
self.load_pipeline_with_constants(device, source, name, None)
178+
}
179+
}

0 commit comments

Comments
 (0)