Skip to content

Commit 59adaa0

Browse files
authored
Merge pull request #223 from oxideai/update/version-0.23.0
Update/version 0.23.0
2 parents 45e4985 + 159086a commit 59adaa0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+2694
-680
lines changed

Cargo.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[workspace.package]
22
# All but mlx-sys should follow the same version. mlx-sys should follow
33
# the version of mlx-c.
4-
version = "0.21.2"
4+
version = "0.23.0"
55
edition = "2021"
66
authors = [
77
"Minghua Wu <michael.wu1107@gmail.com>",
@@ -28,10 +28,10 @@ resolver = "2"
2828

2929
[workspace.dependencies]
3030
# workspace local dependencies
31-
mlx-sys = { version = "=0.1.0", path = "mlx-sys" }
32-
mlx-macros = { version = "0.21", path = "mlx-macros" }
33-
mlx-internal-macros = { version = "0.21", path = "mlx-internal-macros" }
34-
mlx-rs = { version = "0.21.2", path = "mlx-rs" }
31+
mlx-sys = { version = "=0.1.2-release", path = "mlx-sys" }
32+
mlx-macros = { version = "0.23", path = "mlx-macros" }
33+
mlx-internal-macros = { version = "0.23", path = "mlx-internal-macros" }
34+
mlx-rs = { version = "0.23", path = "mlx-rs" }
3535

3636
# external dependencies
3737
thiserror = "1"

examples/mistral/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ mlx-rs.workspace = true
1212
tokenizers = "=0.21.0" # 0.21.1 uses features that went stable in 1.82 while our MSRV is 1.81
1313
thiserror = "1.0"
1414
anyhow = "1.0"
15-
hf-hub = "=0.4.1" # 0.4.1 use features that went stable in 1.82 while our MSRV is 1.81
15+
hf-hub = "=0.4.1" # 0.4.2 uses features that went stable in 1.82 while our MSRV is 1.81
1616
dotenv = "0.15"
1717
serde = { version = "1", features = ["derive"] }
1818
serde_json = "1"

mlx-internal-macros/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ proc-macro = true
1717
syn.workspace = true
1818
quote.workspace = true
1919
darling.workspace = true
20-
proc-macro2.workspace = true
20+
proc-macro2.workspace = true
21+
itertools.workspace = true
Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
use darling::FromMeta;
2+
use itertools::Itertools;
3+
use proc_macro::TokenStream;
4+
use quote::quote;
5+
use syn::{FnArg, Ident, ItemFn, Meta};
6+
7+
const CUSTOM_ATTRIBUTE_OPTIONAL: &str = "optional";
8+
const CUSTOM_ATTRIBUTE_NAMED: &str = "named";
9+
10+
const CUSTOM_ATTRIBUTES: &[&str] = &[CUSTOM_ATTRIBUTE_OPTIONAL, CUSTOM_ATTRIBUTE_NAMED];
11+
12+
#[derive(Default, Debug, FromMeta)]
13+
#[darling(default)]
14+
struct Customize {
15+
root: Option<syn::LitStr>,
16+
default_dtype: Option<syn::Path>,
17+
}
18+
19+
fn arg_type(attrs: &[syn::Attribute]) -> ArgType {
20+
for attr in attrs {
21+
if attr.path().is_ident(CUSTOM_ATTRIBUTE_OPTIONAL) {
22+
return ArgType::NamedOptional;
23+
} else if attr.path().is_ident(CUSTOM_ATTRIBUTE_NAMED) {
24+
return ArgType::Named;
25+
}
26+
}
27+
ArgType::Positional
28+
}
29+
30+
fn remove_attribute(attrs: &mut Vec<syn::Attribute>, targets: &[&str]) {
31+
attrs.retain(|attr| !targets.iter().any(|target| !attr.path().is_ident(target)));
32+
}
33+
34+
/// Remove "$" prefix from the string
35+
fn remove_prefix_from_str(s: &str) -> String {
36+
s.trim_start_matches("$").to_string()
37+
}
38+
39+
pub fn expand_generate_macro(
40+
attr: Option<Meta>,
41+
mut item: ItemFn, // The original function should be kept as is
42+
) -> Result<TokenStream, syn::Error> {
43+
let customize = match attr {
44+
Some(attr) => Customize::from_meta(&attr).map_err(|e| syn::Error::new_spanned(attr, e))?,
45+
None => Customize::default(),
46+
};
47+
48+
// The mod path where the function can be accessed publicly
49+
let (fn_mod_path, doc_mod_path) = match customize.root {
50+
Some(lit_str) => {
51+
let tokens: proc_macro2::TokenStream = lit_str.parse()?;
52+
let s = remove_prefix_from_str(&lit_str.value());
53+
(quote! { #tokens }, s)
54+
}
55+
None => (quote! { $crate::ops }, "crate::ops".into()),
56+
};
57+
58+
let (default_generics, dtype_generics) =
59+
handle_generic_args(&item.sig.generics, &customize.default_dtype);
60+
61+
let args = item
62+
.sig
63+
.inputs
64+
.iter_mut()
65+
.map(|arg| match arg {
66+
FnArg::Receiver(_) => Err(syn::Error::new_spanned(arg, "self is not allowed")),
67+
FnArg::Typed(pat_type) => Ok(pat_type),
68+
})
69+
.collect::<Result<Vec<_>, _>>()?;
70+
71+
let mut parsed_args = parse_args(args);
72+
73+
// Check if the last optional argument is `stream`
74+
if let Some(arg) = parsed_args.last() {
75+
if arg.ident != "stream" {
76+
return Err(syn::Error::new_spanned(
77+
&item,
78+
"the last optional argument must be `stream`",
79+
));
80+
}
81+
}
82+
// Remove the last optional argument `stream`
83+
parsed_args.pop();
84+
85+
// Remove "_device" suffix from the macro name if it exists
86+
let fn_ident = &item.sig.ident;
87+
88+
let generated = generate_macro(
89+
&fn_mod_path,
90+
&doc_mod_path,
91+
fn_ident,
92+
&parsed_args,
93+
&default_generics,
94+
&dtype_generics,
95+
)?;
96+
97+
let output = quote! {
98+
#item
99+
#generated
100+
};
101+
102+
Ok(output.into())
103+
}
104+
105+
/// If there are generic arguments, the last argument is assumed to be `dtype`.
106+
///
107+
/// Returns two `syn::Generics`:
108+
/// 1. With the last argument set to `f32`
109+
/// 2. With the last argument set to `$dtype`
110+
fn handle_generic_args(
111+
generic_args: &syn::Generics,
112+
default_dtype: &Option<syn::Path>,
113+
) -> (proc_macro2::TokenStream, Option<proc_macro2::TokenStream>) {
114+
// Count number of generic type arguments
115+
let count = generic_args
116+
.params
117+
.iter()
118+
.filter(|param| matches!(param, syn::GenericParam::Type(_)))
119+
.count();
120+
121+
if count == 0 {
122+
return (quote! {}, None);
123+
}
124+
125+
// All generics arguments except for the last one will be inferred
126+
let infer_tokens = vec![quote! { _ }; count - 1];
127+
128+
let default_generics = match default_dtype {
129+
Some(path) => quote! { ::<#(#infer_tokens,)* #path> },
130+
None => quote! { ::<#(#infer_tokens,)* f32> },
131+
};
132+
let dtype_generics = quote! { ::<#(#infer_tokens,)* $dtype> };
133+
134+
(default_generics, Some(dtype_generics))
135+
}
136+
137+
#[derive(Debug, Clone, Copy)]
138+
enum ArgType {
139+
Positional,
140+
Named,
141+
NamedOptional,
142+
}
143+
144+
struct Arg {
145+
ident: Ident,
146+
arg_type: ArgType,
147+
}
148+
149+
fn parse_args(args: Vec<&mut syn::PatType>) -> Vec<Arg> {
150+
let mut is_prev_optional = false;
151+
let mut parsed = Vec::new();
152+
for arg in args {
153+
match &*arg.pat {
154+
syn::Pat::Ident(ident) => {
155+
let arg_type = arg_type(&arg.attrs);
156+
157+
let is_positional = matches!(arg_type, ArgType::Positional);
158+
if is_prev_optional && is_positional {
159+
panic!("positional argument cannot follow an optional argument");
160+
}
161+
is_prev_optional = matches!(arg_type, ArgType::NamedOptional);
162+
163+
parsed.push(Arg {
164+
ident: ident.ident.clone(),
165+
arg_type,
166+
});
167+
}
168+
_ => panic!("unsupported pattern"),
169+
}
170+
171+
remove_attribute(&mut arg.attrs, CUSTOM_ATTRIBUTES);
172+
}
173+
parsed
174+
}
175+
176+
fn generate_macro(
177+
fn_mod_path: &proc_macro2::TokenStream,
178+
doc_mod_path: &str,
179+
fn_ident: &Ident,
180+
args: &[Arg],
181+
default_generics: &proc_macro2::TokenStream,
182+
dtype_generics: &Option<proc_macro2::TokenStream>,
183+
) -> Result<proc_macro2::TokenStream, syn::Error> {
184+
let mut trimmed_fn_ident_str = fn_ident.to_string();
185+
if trimmed_fn_ident_str.ends_with("_device") {
186+
trimmed_fn_ident_str = trimmed_fn_ident_str.trim_end_matches("_device").to_string();
187+
}
188+
let trimmed_fn_ident = Ident::new(&trimmed_fn_ident_str, fn_ident.span());
189+
190+
let mut macro_variants = Vec::new();
191+
192+
generate_macro_variants(
193+
fn_mod_path,
194+
fn_ident,
195+
&trimmed_fn_ident,
196+
args,
197+
default_generics,
198+
dtype_generics,
199+
&mut macro_variants,
200+
);
201+
202+
let macro_docs = format!(
203+
"Macro generated for the function [`{}::{}`]. See the function documentation for more details.",
204+
doc_mod_path, trimmed_fn_ident
205+
);
206+
207+
let generated = quote! {
208+
#[doc = #macro_docs]
209+
#[macro_export]
210+
macro_rules! #trimmed_fn_ident {
211+
#(
212+
#macro_variants
213+
)*
214+
}
215+
};
216+
217+
Ok(generated)
218+
}
219+
220+
fn generate_macro_variants(
221+
fn_mod_path: &proc_macro2::TokenStream,
222+
fn_ident: &Ident,
223+
trimmed_fn_ident: &Ident,
224+
args: &[Arg],
225+
default_generics: &proc_macro2::TokenStream,
226+
dtype_generics: &Option<proc_macro2::TokenStream>,
227+
macro_variants: &mut Vec<proc_macro2::TokenStream>,
228+
) {
229+
let args_ident = args.iter().map(|arg| &arg.ident).collect::<Vec<_>>();
230+
let args_type = args.iter().map(|arg| arg.arg_type).collect::<Vec<_>>();
231+
let mut optional_indices = Vec::new();
232+
let mut selected = Vec::with_capacity(args.len());
233+
for (idx, arg) in args.iter().enumerate() {
234+
match arg.arg_type {
235+
ArgType::Positional => {
236+
selected.push(true);
237+
}
238+
ArgType::Named => {
239+
selected.push(true);
240+
}
241+
ArgType::NamedOptional => {
242+
selected.push(false);
243+
optional_indices.push(idx);
244+
}
245+
}
246+
}
247+
248+
for perms in 0..optional_indices.len() + 1 {
249+
// Select `perms` number of optional arguments
250+
for selected_indice in optional_indices.iter().permutations(perms) {
251+
selected_indice.iter().for_each(|&&i| selected[i] = true);
252+
253+
generate_macro_variants_for_selected_args(
254+
fn_mod_path,
255+
fn_ident,
256+
trimmed_fn_ident,
257+
&args_ident,
258+
&args_type,
259+
&selected,
260+
default_generics,
261+
dtype_generics,
262+
macro_variants,
263+
);
264+
265+
// Clear the selected flag for the next iteration
266+
selected_indice.iter().for_each(|&&i| selected[i] = false);
267+
}
268+
}
269+
}
270+
271+
#[allow(clippy::too_many_arguments)]
272+
fn generate_macro_variants_for_selected_args(
273+
fn_mod_path: &proc_macro2::TokenStream,
274+
fn_ident: &Ident,
275+
trimmed_fn_ident: &Ident,
276+
args_ident: &[&Ident],
277+
args_type: &[ArgType],
278+
selected: &[bool],
279+
default_generics: &proc_macro2::TokenStream,
280+
dtype_generics: &Option<proc_macro2::TokenStream>,
281+
macro_variants: &mut Vec<proc_macro2::TokenStream>,
282+
) {
283+
let macro_args: Vec<proc_macro2::TokenStream> = args_ident
284+
.iter()
285+
.zip(args_type.iter())
286+
.zip(selected.iter())
287+
.filter_map(|((ident, arg_type), &selected)| match selected {
288+
true => {
289+
let token = match arg_type {
290+
ArgType::Positional => quote! { $#ident:expr },
291+
ArgType::Named => quote! { #ident=$#ident:expr },
292+
ArgType::NamedOptional => quote! { #ident=$#ident:expr },
293+
};
294+
Some(token)
295+
}
296+
false => None,
297+
})
298+
.collect();
299+
300+
let input: Vec<proc_macro2::TokenStream> = args_ident
301+
.iter()
302+
.zip(selected.iter())
303+
.map(|(ident, &selected)| {
304+
if selected {
305+
quote! { $#ident }
306+
} else {
307+
quote! { None }
308+
}
309+
})
310+
.collect();
311+
312+
let variant_body = quote! {
313+
(
314+
#(#macro_args),*
315+
) => {
316+
#fn_mod_path::#trimmed_fn_ident #default_generics(#(#input,)*)
317+
};
318+
(
319+
#(#macro_args,)*
320+
stream=$stream:expr
321+
) => {
322+
#fn_mod_path::#fn_ident #default_generics(#(#input,)* $stream)
323+
};
324+
};
325+
326+
macro_variants.push(variant_body);
327+
328+
if let Some(dtype_generics) = &dtype_generics {
329+
let variant_body = quote! {
330+
(
331+
#(#macro_args,)*
332+
dtype=$dtype:ty
333+
) => {
334+
#fn_mod_path::#trimmed_fn_ident #dtype_generics(#(#input,)*)
335+
};
336+
(
337+
#(#macro_args,)*
338+
dtype=$dtype:ty,
339+
stream=$stream:expr
340+
) => {
341+
#fn_mod_path::#fn_ident #dtype_generics(#(#input,)* $stream)
342+
};
343+
};
344+
345+
macro_variants.push(variant_body);
346+
}
347+
}

0 commit comments

Comments
 (0)