Skip to content

Commit 071cedc

Browse files
BennoLossinojeda
authored andcommitted
rust: add derive macro for Zeroable
Add a derive proc-macro for the `Zeroable` trait. The macro supports structs where every field implements the `Zeroable` trait. This way `unsafe` implementations can be avoided. The macro is split into two parts: - a proc-macro to parse generics into impl and ty generics, - a declarative macro that expands to the impl block. Suggested-by: Asahi Lina <[email protected]> Signed-off-by: Benno Lossin <[email protected]> Reviewed-by: Gary Guo <[email protected]> Reviewed-by: Martin Rodriguez Reboredo <[email protected]> Link: https://lore.kernel.org/r/[email protected] [ Added `ignore` to the `lib.rs` example and cleaned trivial nit. ] Signed-off-by: Miguel Ojeda <[email protected]>
1 parent f8badd1 commit 071cedc

File tree

5 files changed

+140
-1
lines changed

5 files changed

+140
-1
lines changed

rust/kernel/init/macros.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,3 +1215,38 @@ macro_rules! __init_internal {
12151215
);
12161216
};
12171217
}
1218+
1219+
#[doc(hidden)]
1220+
#[macro_export]
1221+
macro_rules! __derive_zeroable {
1222+
(parse_input:
1223+
@sig(
1224+
$(#[$($struct_attr:tt)*])*
1225+
$vis:vis struct $name:ident
1226+
$(where $($whr:tt)*)?
1227+
),
1228+
@impl_generics($($impl_generics:tt)*),
1229+
@ty_generics($($ty_generics:tt)*),
1230+
@body({
1231+
$(
1232+
$(#[$($field_attr:tt)*])*
1233+
$field:ident : $field_ty:ty
1234+
),* $(,)?
1235+
}),
1236+
) => {
1237+
// SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
1238+
#[automatically_derived]
1239+
unsafe impl<$($impl_generics)*> $crate::init::Zeroable for $name<$($ty_generics)*>
1240+
where
1241+
$($($whr)*)?
1242+
{}
1243+
const _: () = {
1244+
fn assert_zeroable<T: ?::core::marker::Sized + $crate::init::Zeroable>() {}
1245+
fn ensure_zeroable<$($impl_generics)*>()
1246+
where $($($whr)*)?
1247+
{
1248+
$(assert_zeroable::<$field_ty>();)*
1249+
}
1250+
};
1251+
};
1252+
}

rust/kernel/prelude.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub use core::pin::Pin;
1818
pub use alloc::{boxed::Box, vec::Vec};
1919

2020
#[doc(no_inline)]
21-
pub use macros::{module, pin_data, pinned_drop, vtable};
21+
pub use macros::{module, pin_data, pinned_drop, vtable, Zeroable};
2222

2323
pub use super::build_assert;
2424

rust/macros/lib.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ mod paste;
1111
mod pin_data;
1212
mod pinned_drop;
1313
mod vtable;
14+
mod zeroable;
1415

1516
use proc_macro::TokenStream;
1617

@@ -343,3 +344,22 @@ pub fn paste(input: TokenStream) -> TokenStream {
343344
paste::expand(&mut tokens);
344345
tokens.into_iter().collect()
345346
}
347+
348+
/// Derives the [`Zeroable`] trait for the given struct.
349+
///
350+
/// This can only be used for structs where every field implements the [`Zeroable`] trait.
351+
///
352+
/// # Examples
353+
///
354+
/// ```rust,ignore
355+
/// #[derive(Zeroable)]
356+
/// pub struct DriverData {
357+
/// id: i64,
358+
/// buf_ptr: *mut u8,
359+
/// len: usize,
360+
/// }
361+
/// ```
362+
#[proc_macro_derive(Zeroable)]
363+
pub fn derive_zeroable(input: TokenStream) -> TokenStream {
364+
zeroable::derive(input)
365+
}

rust/macros/quote.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,18 @@ macro_rules! quote_spanned {
124124
));
125125
quote_spanned!(@proc $v $span $($tt)*);
126126
};
127+
(@proc $v:ident $span:ident ; $($tt:tt)*) => {
128+
$v.push(::proc_macro::TokenTree::Punct(
129+
::proc_macro::Punct::new(';', ::proc_macro::Spacing::Alone)
130+
));
131+
quote_spanned!(@proc $v $span $($tt)*);
132+
};
133+
(@proc $v:ident $span:ident + $($tt:tt)*) => {
134+
$v.push(::proc_macro::TokenTree::Punct(
135+
::proc_macro::Punct::new('+', ::proc_macro::Spacing::Alone)
136+
));
137+
quote_spanned!(@proc $v $span $($tt)*);
138+
};
127139
(@proc $v:ident $span:ident $id:ident $($tt:tt)*) => {
128140
$v.push(::proc_macro::TokenTree::Ident(::proc_macro::Ident::new(stringify!($id), $span)));
129141
quote_spanned!(@proc $v $span $($tt)*);

rust/macros/zeroable.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// SPDX-License-Identifier: GPL-2.0
2+
3+
use crate::helpers::{parse_generics, Generics};
4+
use proc_macro::{TokenStream, TokenTree};
5+
6+
pub(crate) fn derive(input: TokenStream) -> TokenStream {
7+
let (
8+
Generics {
9+
impl_generics,
10+
ty_generics,
11+
},
12+
mut rest,
13+
) = parse_generics(input);
14+
// This should be the body of the struct `{...}`.
15+
let last = rest.pop();
16+
// Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
17+
let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
18+
// Are we inside of a generic where we want to add `Zeroable`?
19+
let mut in_generic = !impl_generics.is_empty();
20+
// Have we already inserted `Zeroable`?
21+
let mut inserted = false;
22+
// Level of `<>` nestings.
23+
let mut nested = 0;
24+
for tt in impl_generics {
25+
match &tt {
26+
// If we find a `,`, then we have finished a generic/constant/lifetime parameter.
27+
TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
28+
if in_generic && !inserted {
29+
new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
30+
}
31+
in_generic = true;
32+
inserted = false;
33+
new_impl_generics.push(tt);
34+
}
35+
// If we find `'`, then we are entering a lifetime.
36+
TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
37+
in_generic = false;
38+
new_impl_generics.push(tt);
39+
}
40+
TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
41+
new_impl_generics.push(tt);
42+
if in_generic {
43+
new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
44+
inserted = true;
45+
}
46+
}
47+
TokenTree::Punct(p) if p.as_char() == '<' => {
48+
nested += 1;
49+
new_impl_generics.push(tt);
50+
}
51+
TokenTree::Punct(p) if p.as_char() == '>' => {
52+
assert!(nested > 0);
53+
nested -= 1;
54+
new_impl_generics.push(tt);
55+
}
56+
_ => new_impl_generics.push(tt),
57+
}
58+
}
59+
assert_eq!(nested, 0);
60+
if in_generic && !inserted {
61+
new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
62+
}
63+
quote! {
64+
::kernel::__derive_zeroable!(
65+
parse_input:
66+
@sig(#(#rest)*),
67+
@impl_generics(#(#new_impl_generics)*),
68+
@ty_generics(#(#ty_generics)*),
69+
@body(#last),
70+
);
71+
}
72+
}

0 commit comments

Comments
 (0)