Skip to content

Commit 595b523

Browse files
committed
Support ignoring parameters for #[derive(Trace)]
The 'ignored_params' are indicated with an attribute and zerogc will not add automatically generated 'Trace + GcSafe' bounds to these parameters. The compiler will error if this is incorrect.
1 parent 9864cbf commit 595b523

File tree

3 files changed

+321
-27
lines changed

3 files changed

+321
-27
lines changed

libs/derive/Cargo.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "zerogc-derive"
33
description = "Procedural derive for zerogc's garbage collection"
4-
version = "0.1.1"
4+
version = "0.1.2"
55
authors = ["Techcable <[email protected]>"]
66
repository = "https://github.com/DuckLogic/zerogc"
77
readme = "../../README.md"
@@ -11,8 +11,13 @@ edition = "2018"
1111
[lib]
1212
proc-macro = true
1313

14+
[dev-dependencies]
15+
zerogc = { version = "0.1.2", path = "../.." }
16+
1417
[dependencies]
1518
# Proc macros
16-
syn = "1"
17-
quote = "1"
19+
syn = "1.0.55"
20+
quote = "1.0.8"
1821
proc-macro2 = "1"
22+
# Itertools
23+
itertools = "0.9"

libs/derive/src/lib.rs

Lines changed: 202 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1+
#![feature(backtrace)]
12
extern crate proc_macro;
23

34
use quote::{quote, quote_spanned};
45
use syn::{
5-
parse_macro_input, parenthesized, parse_quote, DeriveInput, Data,
6-
Error, Generics, GenericParam, TypeParamBound, Fields, Member,
7-
Index, Type, GenericArgument, Attribute, PathArguments,
6+
parse_macro_input, parenthesized, parse_quote, DeriveInput,
7+
Data, Error, Generics, GenericParam, TypeParamBound, Fields,
8+
Member, Index, Type, GenericArgument, Attribute, PathArguments,
9+
Meta, NestedMeta, TypeParam, WherePredicate, PredicateType,
10+
Token
811
};
912
use proc_macro2::{Ident, TokenStream, Span};
1013
use syn::spanned::Spanned;
1114
use syn::parse::{ParseStream, Parse};
15+
use std::collections::HashSet;
16+
use syn::export::fmt::Display;
17+
use std::io::Write;
1218

1319
struct MutableFieldOpts {
1420
public: bool
@@ -91,6 +97,7 @@ impl Parse for GcFieldAttrs {
9197

9298
struct GcTypeAttrs {
9399
is_copy: bool,
100+
ignore_params: HashSet<Ident>
94101
}
95102
impl GcTypeAttrs {
96103
pub fn find(attrs: &[Attribute]) -> Result<Self, Error> {
@@ -107,6 +114,7 @@ impl Default for GcTypeAttrs {
107114
fn default() -> Self {
108115
GcTypeAttrs {
109116
is_copy: false,
117+
ignore_params: HashSet::new()
110118
}
111119
}
112120
}
@@ -116,14 +124,67 @@ impl Parse for GcTypeAttrs {
116124
parenthesized!(input in raw_input);
117125
let mut result = GcTypeAttrs::default();
118126
while !input.is_empty() {
119-
let flag_name = input.parse::<Ident>()?;
120-
if flag_name == "copy" {
127+
let meta = input.parse::<Meta>()?;
128+
if meta.path().is_ident("copy") {
129+
if !matches!(meta, Meta::Path(_)) {
130+
return Err(Error::new(
131+
meta.span(),
132+
"Malformed attribute for #[zerogc(copy)]"
133+
))
134+
}
135+
if result.is_copy {
136+
return Err(Error::new(
137+
meta.span(),
138+
"Duplicate flags: #[zerogc(copy)]"
139+
))
140+
}
121141
result.is_copy = true;
142+
} else if meta.path().is_ident("ignore_params") {
143+
if !result.ignore_params.is_empty() {
144+
return Err(Error::new(
145+
meta.span(),
146+
"Duplicate flags: #[zerogc(ignore_params)]"
147+
))
148+
}
149+
let list = match meta {
150+
Meta::List(ref list) if list.nested.is_empty() => {
151+
return Err(Error::new(
152+
list.span(),
153+
"Empty list for #[zerogc(ignore_parameters)]"
154+
))
155+
}
156+
Meta::List(list) => list,
157+
_ => return Err(Error::new(
158+
meta.span(),
159+
"Expected a list attribute for #[zerogc(ignore_params)]"
160+
))
161+
};
162+
for nested in list.nested {
163+
match nested {
164+
NestedMeta::Meta(Meta::Path(ref p)) if
165+
p.get_ident().is_some() => {
166+
let ident = p.get_ident().unwrap();
167+
if !result.ignore_params.insert(ident.clone()) {
168+
return Err(Error::new(
169+
ident.span(),
170+
"Duplicate parameter to ignore"
171+
));
172+
}
173+
}
174+
_ => return Err(Error::new(
175+
nested.span(),
176+
"Invalid list value for #[zerogc(ignore_param)]"
177+
))
178+
}
179+
}
122180
} else {
123181
return Err(Error::new(
124-
input.span(), "Unknown type flag"
182+
meta.span(), "Unknown type flag"
125183
))
126184
}
185+
if input.peek(Token![,]) {
186+
input.parse::<Token![,]>()?;
187+
}
127188
}
128189
Ok(result)
129190
}
@@ -132,20 +193,30 @@ impl Parse for GcTypeAttrs {
132193
#[proc_macro_derive(Trace, attributes(zerogc))]
133194
pub fn derive_trace(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
134195
let input = parse_macro_input!(input as DeriveInput);
135-
let trace_impl = impl_trace(&input)
196+
let attrs = match GcTypeAttrs::find(&*input.attrs) {
197+
Ok(attrs) => attrs,
198+
Err(e) => return e.to_compile_error().into()
199+
};
200+
let trace_impl = impl_trace(&input, &attrs)
136201
.unwrap_or_else(|e| e.to_compile_error());
137202
let brand_impl = impl_brand(&input)
138203
.unwrap_or_else(|e| e.to_compile_error());
139-
let gc_safe_impl = impl_gc_safe(&input)
204+
let gc_safe_impl = impl_gc_safe(&input, &attrs)
140205
.unwrap_or_else(|e| e.to_compile_error());
141206
let extra_impls = impl_extras(&input)
142207
.unwrap_or_else(|e| e.to_compile_error());
143-
From::from(quote! {
208+
let t = From::from(quote! {
144209
#trace_impl
145210
#brand_impl
146211
#gc_safe_impl
147212
#extra_impls
148-
})
213+
});
214+
debug_derive(
215+
"derive(Trace)",
216+
&format_args!("#[derive(Trace) for {}", input.ident),
217+
&t
218+
);
219+
t
149220
}
150221

151222
fn trace_fields(fields: &Fields, access_ref: &mut dyn FnMut(Member) -> TokenStream) -> TokenStream {
@@ -275,13 +346,22 @@ fn impl_brand(target: &DeriveInput) -> Result<TokenStream, Error> {
275346
let name = &target.ident;
276347
let mut generics: Generics = target.generics.clone();
277348
let mut rewritten_params = Vec::new();
349+
let mut rewritten_restrictions = Vec::new();
278350
for param in &mut generics.params {
279351
let rewritten_param: GenericArgument;
280352
match param {
281353
GenericParam::Type(ref mut type_param) => {
354+
let original_bounds = type_param.bounds.iter().cloned().collect::<Vec<_>>();
282355
type_param.bounds.push(parse_quote!(::zerogc::GcBrand<'new_gc, S>));
283356
let param_name = &type_param.ident;
284-
rewritten_param = parse_quote!(<#param_name as ::zerogc::GcBrand<'new_gc, S>::Branded);
357+
let rewritten_type: Type = parse_quote!(<#param_name as ::zerogc::GcBrand<'new_gc, S>>::Branded);
358+
rewritten_restrictions.push(WherePredicate::Type(PredicateType {
359+
lifetimes: None,
360+
bounded_ty: rewritten_type.clone(),
361+
colon_token: Default::default(),
362+
bounds: original_bounds.into_iter().collect()
363+
}));
364+
rewritten_param = GenericArgument::Type(rewritten_type);
285365
},
286366
GenericParam::Lifetime(ref l) => {
287367
/*
@@ -307,20 +387,22 @@ fn impl_brand(target: &DeriveInput) -> Result<TokenStream, Error> {
307387
let mut impl_generics = generics.clone();
308388
impl_generics.params.push(GenericParam::Lifetime(parse_quote!('new_gc)));
309389
impl_generics.params.push(GenericParam::Type(parse_quote!(S: ::zerogc::CollectorId)));
310-
let (_, ty_generics, where_clause) = generics.split_for_impl();
311-
let (impl_generics, _, _) = impl_generics.split_for_impl();
390+
impl_generics.make_where_clause().predicates.extend(rewritten_restrictions);
391+
let (_, ty_generics, _) = generics.split_for_impl();
392+
let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
312393
Ok(quote! {
313394
unsafe impl #impl_generics ::zerogc::GcBrand<'new_gc, S>
314395
for #name #ty_generics #where_clause {
315396
type Branded = #name::<#(#rewritten_params),*>;
316397
}
317398
})
318399
}
319-
fn impl_trace(target: &DeriveInput) -> Result<TokenStream, Error> {
400+
fn impl_trace(target: &DeriveInput, attrs: &GcTypeAttrs) -> Result<TokenStream, Error> {
320401
let name = &target.ident;
321-
let generics = add_trait_bounds(
322-
&target.generics, parse_quote!(zerogc::Trace)
323-
);
402+
let generics = add_trait_bounds_except(
403+
&target.generics, parse_quote!(zerogc::Trace),
404+
&attrs.ignore_params
405+
)?;
324406
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
325407
let field_types: Vec<&Type>;
326408
let trace_impl: TokenStream;
@@ -399,12 +481,12 @@ fn impl_trace(target: &DeriveInput) -> Result<TokenStream, Error> {
399481
}
400482
})
401483
}
402-
fn impl_gc_safe(target: &DeriveInput) -> Result<TokenStream, Error> {
484+
fn impl_gc_safe(target: &DeriveInput, attrs: &GcTypeAttrs) -> Result<TokenStream, Error> {
403485
let name = &target.ident;
404-
let generics = add_trait_bounds(
405-
&target.generics, parse_quote!(zerogc::GcSafe)
406-
);
407-
let attrs = GcTypeAttrs::find(&*target.attrs)?;
486+
let generics = add_trait_bounds_except(
487+
&target.generics, parse_quote!(zerogc::GcSafe),
488+
&attrs.ignore_params
489+
)?;
408490
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
409491
let field_types: Vec<&Type> = match target.data {
410492
Data::Struct(ref data) => {
@@ -468,12 +550,108 @@ fn impl_gc_safe(target: &DeriveInput) -> Result<TokenStream, Error> {
468550
})
469551
}
470552

471-
fn add_trait_bounds(generics: &Generics, bound: TypeParamBound) -> Generics {
553+
fn add_trait_bounds_except(
554+
generics: &Generics, bound: TypeParamBound,
555+
ignored_params: &HashSet<Ident>
556+
) -> Result<Generics, Error> {
557+
let mut actually_ignored_args = HashSet::<Ident>::new();
558+
let generics = add_trait_bounds(
559+
&generics, bound,
560+
&mut |param: &TypeParam| {
561+
if ignored_params.contains(&param.ident) {
562+
actually_ignored_args.insert(param.ident.clone());
563+
true
564+
} else {
565+
false
566+
}
567+
}
568+
);
569+
if actually_ignored_args != *ignored_params {
570+
let missing = ignored_params - &actually_ignored_args;
571+
assert!(!missing.is_empty());
572+
let mut combined_error: Option<Error> = None;
573+
for missing in missing {
574+
let error = Error::new(
575+
missing.span(),
576+
"Unknown parameter",
577+
);
578+
match combined_error {
579+
Some(ref mut combined_error) => {
580+
combined_error.combine(error);
581+
},
582+
None => {
583+
combined_error = Some(error);
584+
}
585+
}
586+
}
587+
return Err(combined_error.unwrap());
588+
}
589+
Ok(generics)
590+
}
591+
592+
fn add_trait_bounds(
593+
generics: &Generics, bound: TypeParamBound,
594+
should_ignore: &mut dyn FnMut(&TypeParam) -> bool
595+
) -> Generics {
472596
let mut result: Generics = (*generics).clone();
473-
for param in &mut result.params {
597+
'paramLoop: for param in &mut result.params {
474598
if let GenericParam::Type(ref mut type_param) = *param {
599+
if should_ignore(type_param) {
600+
continue 'paramLoop;
601+
}
475602
type_param.bounds.push(bound.clone());
476603
}
477604
}
478605
result
479606
}
607+
608+
fn debug_derive(key: &str, message: &dyn Display, value: &dyn Display) {
609+
match ::std::env::var_os("DEBUG_DERIVE") {
610+
Some(var) if var == "*" ||
611+
var.to_string_lossy().contains(key) => {
612+
// Enable this debug
613+
},
614+
_ => return,
615+
}
616+
eprintln!("{}:", message);
617+
use std::process::{Command, Stdio};
618+
let original_input = format!("{}", value);
619+
let cmd_res = Command::new("rustfmt")
620+
.stdin(Stdio::piped())
621+
.stdout(Stdio::piped())
622+
.stderr(Stdio::piped())
623+
.spawn()
624+
.and_then(|mut child| {
625+
let mut stdin = child.stdin.take().unwrap();
626+
stdin.write_all(original_input.as_bytes())?;
627+
drop(stdin);
628+
child.wait_with_output()
629+
});
630+
match cmd_res {
631+
Ok(output) if output.status.success() => {
632+
let formatted = String::from_utf8(output.stdout).unwrap();
633+
for line in formatted.lines() {
634+
eprintln!(" {}", line);
635+
}
636+
},
637+
// Fallthrough on failure
638+
Ok(output) => {
639+
eprintln!("Rustfmt error [code={}]:", output.status.code().map_or_else(
640+
|| String::from("?"),
641+
|i| format!("{}", i)
642+
));
643+
let err_msg = String::from_utf8(output.stderr).unwrap();
644+
for line in err_msg.lines() {
645+
eprintln!(" {}", line);
646+
}
647+
eprintln!("Original input: [[[[");
648+
for line in original_input.lines() {
649+
eprintln!("{}", line);
650+
}
651+
eprintln!("]]]]");
652+
}
653+
Err(e) => {
654+
eprintln!("Failed to run rustfmt: {}", e)
655+
}
656+
}
657+
}

0 commit comments

Comments
 (0)