From 8d9fc66a1e209b6ac47aa799a4262a8b3d9ddc8e Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Mon, 23 Jun 2025 03:13:09 +0300 Subject: [PATCH] Loop unroll --- Cargo.toml | 4 + src/utils/loop_xform/Cargo.toml | 22 ++ src/utils/loop_xform/mod.rs | 43 +++ src/utils/loop_xform/parse/mod.rs | 169 +++++++++++ src/utils/loop_xform/parse/tests/mod.rs | 7 + .../loop_xform/parse/tests/unroll_runtime.rs | 269 ++++++++++++++++++ .../parse/tests/unroll_with_remainder.rs | 45 +++ src/utils/loop_xform/transform/mod.rs | 16 ++ .../transform/unroll_runtime/mod.rs | 38 +++ .../transform/unroll_runtime/tests.rs | 64 +++++ .../transform/unroll_with_remainder/mod.rs | 129 +++++++++ .../transform/unroll_with_remainder/tests.rs | 245 ++++++++++++++++ .../utils/loop_break_labeller/mod.rs | 34 +++ .../utils/loop_break_labeller/tests.rs | 69 +++++ tests/utils/loop_xform/Cargo.toml | 20 ++ tests/utils/loop_xform/mod.rs | 128 +++++++++ tests/utils/loop_xform/naive.rs | 96 +++++++ tests/utils/loop_xform/unroll_runtime.rs | 120 ++++++++ .../utils/loop_xform/unroll_with_remainder.rs | 258 +++++++++++++++++ 19 files changed, 1776 insertions(+) create mode 100644 src/utils/loop_xform/Cargo.toml create mode 100644 src/utils/loop_xform/mod.rs create mode 100644 src/utils/loop_xform/parse/mod.rs create mode 100644 src/utils/loop_xform/parse/tests/mod.rs create mode 100644 src/utils/loop_xform/parse/tests/unroll_runtime.rs create mode 100644 src/utils/loop_xform/parse/tests/unroll_with_remainder.rs create mode 100644 src/utils/loop_xform/transform/mod.rs create mode 100644 src/utils/loop_xform/transform/unroll_runtime/mod.rs create mode 100644 src/utils/loop_xform/transform/unroll_runtime/tests.rs create mode 100644 src/utils/loop_xform/transform/unroll_with_remainder/mod.rs create mode 100644 src/utils/loop_xform/transform/unroll_with_remainder/tests.rs create mode 100644 src/utils/loop_xform/transform/utils/loop_break_labeller/mod.rs create mode 100644 src/utils/loop_xform/transform/utils/loop_break_labeller/tests.rs create mode 100644 tests/utils/loop_xform/Cargo.toml create mode 100644 tests/utils/loop_xform/mod.rs create mode 100644 tests/utils/loop_xform/naive.rs create mode 100644 tests/utils/loop_xform/unroll_runtime.rs create mode 100644 tests/utils/loop_xform/unroll_with_remainder.rs diff --git a/Cargo.toml b/Cargo.toml index ccd922d..1f1b991 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,8 @@ members = [ "src/memory/bytevacuumer", "src/memory/bytestreamer", "src/misc/md5", + "src/utils/loop_xform", + "tests/utils/loop_xform", ] [workspace.package] @@ -131,6 +133,8 @@ panic_in_result_fn = { level = "allow", priority = 0 } implicit_return = { level = "allow", priority = 0 } absolute_paths = { level = "allow", priority = 0 } question_mark_used = { level = "allow", priority = 0 } +std_instead_of_alloc = { level = "allow", priority = 0 } +single_call_fn = { level = "allow", priority = 0 } [profile.release] panic = 'abort' diff --git a/src/utils/loop_xform/Cargo.toml b/src/utils/loop_xform/Cargo.toml new file mode 100644 index 0000000..16093f3 --- /dev/null +++ b/src/utils/loop_xform/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "rawspeed-utils-loop_xform" +version.workspace = true +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +documentation.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +proc-macro2 = { version = "1.0", default-features = false, features = [] } +syn = { version = "2.0", default-features = false, features = ["proc-macro", "parsing", "full", "visit-mut", "printing"] } +quote = { version = "1.0", default-features = false, features = [] } + +[lib] +proc-macro = true +path = "mod.rs" diff --git a/src/utils/loop_xform/mod.rs b/src/utils/loop_xform/mod.rs new file mode 100644 index 0000000..a0e1d51 --- /dev/null +++ b/src/utils/loop_xform/mod.rs @@ -0,0 +1,43 @@ +mod kw { + syn::custom_keyword!(runtime); + syn::custom_keyword!(with_remainder); +} + +#[derive(PartialEq, Eq, Debug)] +enum UnrollMethod { + Runtime, + WithRemainder, +} + +enum Item { + LoopUnrollAttr(LoopUnrollConf), +} + +struct LoopUnrollConf { + pub unroll_method: UnrollMethod, + pub unroll_factor: usize, + pub for_loop: syn::ExprForLoop, + pub rest_of_tokenstream: proc_macro2::TokenStream, +} + +#[proc_macro] +pub fn enable_loop_xforms( + tokens: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let input = syn::parse_macro_input!(tokens as Item); + + match input { + Item::LoopUnrollAttr(c) => { + #[cfg(clippy)] + { + use quote::ToTokens as _; + return c.for_loop.to_token_stream().into(); + } + #[cfg_attr(clippy, allow(unreachable_code))] + transform::perform_loop_unroll(c).into() + } + } +} + +mod parse; +mod transform; diff --git a/src/utils/loop_xform/parse/mod.rs b/src/utils/loop_xform/parse/mod.rs new file mode 100644 index 0000000..70103cc --- /dev/null +++ b/src/utils/loop_xform/parse/mod.rs @@ -0,0 +1,169 @@ +use super::Item; +use super::LoopUnrollConf; +use super::UnrollMethod; +use super::kw; +use syn::Attribute; +use syn::LitInt; +use syn::Result; +use syn::parenthesized; +use syn::parse::Parse; +use syn::parse::ParseStream; + +impl Parse for UnrollMethod { + fn parse(input: ParseStream<'_>) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::runtime) { + input.parse::()?; + Ok(UnrollMethod::Runtime) + } else if lookahead.peek(kw::with_remainder) { + input.parse::()?; + Ok(UnrollMethod::WithRemainder) + } else { + Err(lookahead.error()) + } + } +} + +#[allow(clippy::single_call_fn)] +fn parse_method( + meta: &syn::meta::ParseNestedMeta<'_>, + unroll_method: &mut Option, +) -> Result<()> { + assert!(meta.path.is_ident("method")); + + if unroll_method.is_some() { + return Err( + meta.error("only a single unroll method shall be specified") + ); + } + + let content; + parenthesized!(content in meta.input); + let head = content.fork(); + match content.parse::() { + Ok(m) => *unroll_method = Some(m), + Err(_) => { + return Err(head.error("expected valid unroll method")); + } + } + if !content.is_empty() { + return Err(syn::Error::new_spanned( + content.parse::()?, + "unexpected garbage in unroll method argument", + )); + } + Ok(()) +} + +#[allow(clippy::single_call_fn)] +fn parse_factor( + meta: &syn::meta::ParseNestedMeta<'_>, + unroll_factor: &mut Option, +) -> Result<()> { + assert!(meta.path.is_ident("factor")); + + if unroll_factor.is_some() { + return Err( + meta.error("only a single unroll factor shall be specified") + ); + } + let content; + parenthesized!(content in meta.input); + let lit: LitInt = content.parse()?; + if !lit.suffix().is_empty() { + return Err(syn::Error::new_spanned( + lit, + "unroll factor should not have any suffix", + )); + } + if !content.is_empty() { + return Err(syn::Error::new_spanned( + content.parse::()?, + "unexpected garbage in unroll factor argument", + )); + } + let n: usize = lit.base10_parse()?; + if n < 1 { + return Err(meta.error("Unroll factor can not be zero")); + } + *unroll_factor = Some(n); + Ok(()) +} + +#[allow(clippy::single_call_fn)] +fn parse_attr(attr: &Attribute) -> Result<(UnrollMethod, usize)> { + if !attr.path().is_ident("loop_unroll") { + return Err(syn::Error::new_spanned( + attr, + "`loop_unroll` attribute expected", + )); + } + + let mut unroll_method: Option = None; + let mut unroll_factor: Option = None; + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("method") { + return parse_method(&meta, &mut unroll_method); + } + if meta.path.is_ident("factor") { + return parse_factor(&meta, &mut unroll_factor); + } + Err(meta.error( + "unrecognized parameter, expected `method(...)` and `factor(..)`", + )) + })?; + + let Some(unroll_method) = unroll_method else { + return Err(syn::Error::new_spanned( + attr, + "The attribute must specify unroll `method`", + )); + }; + + let Some(unroll_factor) = unroll_factor else { + return Err(syn::Error::new_spanned( + attr, + "The attribute must specify `factor`", + )); + }; + + Ok((unroll_method, unroll_factor)) +} + +impl Parse for Item { + fn parse(input: ParseStream<'_>) -> Result { + let attrs = input.call(Attribute::parse_outer)?; + + let (unroll_method, unroll_factor); + if let Some(attr) = attrs.first() { + if let Some(ea) = attrs.get(1) { + return Err(syn::Error::new_spanned( + ea, + "There should only be a single attribute", + )); + } + + (unroll_method, unroll_factor) = parse_attr(attr)?; + } else { + return Err(syn::Error::new_spanned( + input.parse::()?, + "There must be an attribute", + )); + } + + let for_loop = input.parse::()?; + let remainder: proc_macro2::TokenStream = input.parse()?; + assert!(input.is_empty()); + + Ok(Item::LoopUnrollAttr(LoopUnrollConf { + unroll_method, + unroll_factor, + for_loop, + rest_of_tokenstream: remainder, + })) + } +} + +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod tests; diff --git a/src/utils/loop_xform/parse/tests/mod.rs b/src/utils/loop_xform/parse/tests/mod.rs new file mode 100644 index 0000000..d01ff24 --- /dev/null +++ b/src/utils/loop_xform/parse/tests/mod.rs @@ -0,0 +1,7 @@ +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod unroll_runtime; + +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod unroll_with_remainder; diff --git a/src/utils/loop_xform/parse/tests/unroll_runtime.rs b/src/utils/loop_xform/parse/tests/unroll_runtime.rs new file mode 100644 index 0000000..b72b8d9 --- /dev/null +++ b/src/utils/loop_xform/parse/tests/unroll_runtime.rs @@ -0,0 +1,269 @@ +use crate::Item; +use crate::UnrollMethod; +use quote::ToTokens as _; +use quote::quote; + +#[test] +#[should_panic(expected = "There must be an attribute")] +fn t0_test() { + let tokens = quote! {}; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "`loop_unroll` attribute expected")] +fn t1_test() { + let tokens = quote! { + #[attr] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic( + expected = "expected attribute arguments in parentheses: #[loop_unroll(...)" +)] +fn t2_test() { + let tokens = quote! { + #[loop_unroll] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "The attribute must specify unroll `method`")] +fn t3_test() { + let tokens = quote! { + #[loop_unroll()] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "unexpected end of input, expected parentheses")] +fn t4_test() { + let tokens = quote! { + #[loop_unroll(method)] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic( + expected = "unexpected end of input, expected valid unroll method" +)] +fn t5_test() { + let tokens = quote! { + #[loop_unroll(method())] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "expected valid unroll method")] +fn t6_test() { + let tokens = quote! { + #[loop_unroll(method(run1time))] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "The attribute must specify `factor`")] +fn t7_test() { + let tokens = quote! { + #[loop_unroll(method(runtime))] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "expected `,")] +fn t8_test() { + let tokens = quote! { + #[loop_unroll(method(runtime) factor)] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "unexpected end of input, expected parentheses")] +fn t9_test() { + let tokens = quote! { + #[loop_unroll(method(runtime), factor)] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "unexpected end of input, expected integer literal")] +fn t10_test() { + let tokens = quote! { + #[loop_unroll(method(runtime), factor())] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "unexpected end of input, expected `for`")] +fn t11_test() { + let tokens = quote! { + #[loop_unroll(method(runtime), factor(42))] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "unroll factor should not have any suffix")] +fn t12_test() { + let tokens = quote! { + #[loop_unroll(method(runtime), factor(42u16))] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "Unroll factor can not be zero")] +fn t13_test() { + let tokens = quote! { + #[loop_unroll(method(runtime), factor(0))] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "unexpected end of input, expected `for`")] +fn t14_test() { + let tokens = quote! { + #[loop_unroll(method(runtime), factor(1))] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +#[should_panic(expected = "There should only be a single attribute")] +fn t15_test() { + let tokens = quote! { + #[loop_unroll(method(runtime), factor(1))] + #[loop_unroll(method(runtime), factor(1))] + }; + match syn::parse2::(tokens) { + Ok(_) => (), + Err(e) => { + panic!("{}", e) + } + } +} + +#[test] +fn good_test() { + let tokens = quote! { + #[loop_unroll(method(runtime), factor(42))] + 'loop_label: for elt in iter { body } rest + }; + match syn::parse2::(tokens).unwrap() { + Item::LoopUnrollAttr(loop_unroll_conf) => { + assert_eq!(loop_unroll_conf.unroll_method, UnrollMethod::Runtime); + assert_eq!(loop_unroll_conf.unroll_factor, 42); + assert_eq!( + loop_unroll_conf + .for_loop + .label + .to_token_stream() + .to_string(), + ("'loop_label :") + ); + assert_eq!( + loop_unroll_conf.for_loop.pat.to_token_stream().to_string(), + ("elt") + ); + assert_eq!( + loop_unroll_conf.for_loop.expr.to_token_stream().to_string(), + ("iter") + ); + assert_eq!( + loop_unroll_conf.for_loop.body.to_token_stream().to_string(), + ("{ body }") + ); + assert_eq!( + loop_unroll_conf.rest_of_tokenstream.to_string(), + "rest" + ); + } + } +} diff --git a/src/utils/loop_xform/parse/tests/unroll_with_remainder.rs b/src/utils/loop_xform/parse/tests/unroll_with_remainder.rs new file mode 100644 index 0000000..a4a75ae --- /dev/null +++ b/src/utils/loop_xform/parse/tests/unroll_with_remainder.rs @@ -0,0 +1,45 @@ +use crate::Item; +use crate::UnrollMethod; +use quote::ToTokens as _; +use quote::quote; + +#[test] +fn good_test() { + let tokens = quote! { + #[loop_unroll(method(with_remainder), factor(42))] + 'loop_label: for elt in iter { body } rest + }; + match syn::parse2::(tokens).unwrap() { + Item::LoopUnrollAttr(loop_unroll_conf) => { + assert_eq!( + loop_unroll_conf.unroll_method, + UnrollMethod::WithRemainder + ); + assert_eq!(loop_unroll_conf.unroll_factor, 42); + assert_eq!( + loop_unroll_conf + .for_loop + .label + .to_token_stream() + .to_string(), + ("'loop_label :") + ); + assert_eq!( + loop_unroll_conf.for_loop.pat.to_token_stream().to_string(), + ("elt") + ); + assert_eq!( + loop_unroll_conf.for_loop.expr.to_token_stream().to_string(), + ("iter") + ); + assert_eq!( + loop_unroll_conf.for_loop.body.to_token_stream().to_string(), + ("{ body }") + ); + assert_eq!( + loop_unroll_conf.rest_of_tokenstream.to_string(), + "rest" + ); + } + } +} diff --git a/src/utils/loop_xform/transform/mod.rs b/src/utils/loop_xform/transform/mod.rs new file mode 100644 index 0000000..15b58c1 --- /dev/null +++ b/src/utils/loop_xform/transform/mod.rs @@ -0,0 +1,16 @@ +use super::LoopUnrollConf; +use super::UnrollMethod; + +pub fn perform_loop_unroll(c: LoopUnrollConf) -> proc_macro2::TokenStream { + match c.unroll_method { + UnrollMethod::Runtime => unroll_runtime::transform(&c), + UnrollMethod::WithRemainder => unroll_with_remainder::transform(c), + } +} + +mod utils { + pub mod loop_break_labeller; +} + +mod unroll_runtime; +mod unroll_with_remainder; diff --git a/src/utils/loop_xform/transform/unroll_runtime/mod.rs b/src/utils/loop_xform/transform/unroll_runtime/mod.rs new file mode 100644 index 0000000..db3b1fc --- /dev/null +++ b/src/utils/loop_xform/transform/unroll_runtime/mod.rs @@ -0,0 +1,38 @@ +use super::super::LoopUnrollConf; +use quote::quote; + +#[allow(clippy::single_call_fn)] +pub fn transform(ast: &LoopUnrollConf) -> proc_macro2::TokenStream { + let label = &ast.for_loop.label; + let pat = &ast.for_loop.pat; + let expr = &ast.for_loop.expr; + let body_stmts = &ast.for_loop.body.stmts; + let remainder = &ast.rest_of_tokenstream; + + let iter = syn::Ident::new_raw("iter", proc_macro2::Span::mixed_site()); + + let new_body = core::iter::repeat_with(|| { + quote! { + if let Some(#pat) = #iter.next() { + #(#body_stmts)* + } else { + break; + } + } + }) + .take(ast.unroll_factor); + + quote! { + { + let mut #iter = #expr; + #label while true { + #(#new_body)* + } + } + #remainder + } +} + +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod tests; diff --git a/src/utils/loop_xform/transform/unroll_runtime/tests.rs b/src/utils/loop_xform/transform/unroll_runtime/tests.rs new file mode 100644 index 0000000..cf1ac27 --- /dev/null +++ b/src/utils/loop_xform/transform/unroll_runtime/tests.rs @@ -0,0 +1,64 @@ +use super::super::UnrollMethod; +use super::LoopUnrollConf; +use quote::ToTokens as _; +use quote::quote; +use syn::ExprForLoop; + +#[test] +fn unroll1_test() { + let conf = LoopUnrollConf { + unroll_method: UnrollMethod::Runtime, + for_loop: syn::parse2::( + quote! { 'loop_label: for elt in iter { body } }, + ) + .unwrap(), + unroll_factor: 1, + rest_of_tokenstream: quote! { rest }, + }; + + let res = super::transform(&conf); + assert_eq!( + res.to_string(), + quote! { + { + let mut r#iter = iter; + 'loop_label : while true { + if let Some(elt) = r#iter.next() { body } else { break; } + } + } + rest + } + .to_token_stream() + .to_string() + ); +} + +#[test] +fn unroll2_test() { + let conf = LoopUnrollConf { + unroll_method: UnrollMethod::Runtime, + for_loop: syn::parse2::( + quote! { 'loop_label: for elt in iter { body } }, + ) + .unwrap(), + unroll_factor: 2, + rest_of_tokenstream: quote! { rest }, + }; + + let res = super::transform(&conf); + assert_eq!( + res.to_string(), + quote! { + { + let mut r#iter = iter; + 'loop_label : while true { + if let Some(elt) = r#iter.next() { body } else { break; } + if let Some(elt) = r#iter.next() { body } else { break; } + } + } + rest + } + .to_token_stream() + .to_string() + ); +} diff --git a/src/utils/loop_xform/transform/unroll_with_remainder/mod.rs b/src/utils/loop_xform/transform/unroll_with_remainder/mod.rs new file mode 100644 index 0000000..eb4739a --- /dev/null +++ b/src/utils/loop_xform/transform/unroll_with_remainder/mod.rs @@ -0,0 +1,129 @@ +use super::super::LoopUnrollConf; +use quote::{ToTokens as _, quote}; +use syn::{Label, Lifetime}; + +#[allow(clippy::single_call_fn)] +pub fn transform(ast: LoopUnrollConf) -> proc_macro2::TokenStream { + let mut for_loop = ast.for_loop; + + let label_outer = match for_loop.label { + Some(l) => l.name.clone(), + None => Lifetime::new("'loop_label", proc_macro2::Span::mixed_site()), + }; + + for_loop.label = Some(Label { + name: label_outer.clone(), + colon_token: syn::token::Colon { + spans: [proc_macro2::Span::mixed_site()], + }, + }); + + super::utils::loop_break_labeller::LabelUnlabelledBreaks::visit( + &mut for_loop, + ); + + let label_inner = + Lifetime::new("'label_inner", proc_macro2::Span::mixed_site()); + + let iter = syn::Ident::new_raw("iter", proc_macro2::Span::mixed_site()); + + let mut iter_evals = vec![]; + for i in 0..ast.unroll_factor { + iter_evals.push(syn::Ident::new_raw( + &format!("iter_{}_of_{}", i + 1, ast.unroll_factor).to_owned(), + proc_macro2::Span::mixed_site(), + )); + } + + let p = Pieces { + label_outer, + pat: for_loop.pat, + expr: for_loop.expr, + body_stmts: for_loop.body.stmts, + remainder: ast.rest_of_tokenstream, + label_inner, + iter, + iter_evals, + }; + builder(&p) +} + +struct Pieces { + label_outer: Lifetime, + pat: Box, + expr: Box, + body_stmts: Vec, + remainder: proc_macro2::TokenStream, + label_inner: Lifetime, + iter: syn::Ident, + iter_evals: Vec, +} + +fn builder(s: &Pieces) -> proc_macro2::TokenStream { + let label_outer = &s.label_outer; + let pat = &s.pat; + let expr = &s.expr; + let body_stmts = &s.body_stmts; + let remainder = &s.remainder; + let label_inner = &s.label_inner; + let iter = &s.iter; + + let mut prologue = quote! {}; + for curr_pat in s.iter_evals.iter().rev() { + quote! { + let mut #curr_pat = None; + } + .to_tokens(&mut prologue); + } + + let mut new_body = quote! {}; + for curr_pat in &s.iter_evals { + quote! { + { + let #pat = #curr_pat.take().unwrap(); + #(#body_stmts)* + } + } + .to_tokens(&mut new_body); + } + + for curr_pat in s.iter_evals.iter().rev() { + new_body = quote! { + #curr_pat = #iter.next(); + if #curr_pat.is_some() { + #new_body + } else { + break #label_inner; + } + }; + } + + let mut epilogue = quote! {}; + for curr_pat in s.iter_evals.iter().rev().skip(1).rev() { + quote! { + if let Some(#pat) = #curr_pat.take() { + #(#body_stmts)* + } else { + break #label_outer; + } + } + .to_tokens(&mut epilogue); + } + + quote! { + #label_outer: while true { + let mut #iter = #expr; + #prologue + #label_inner: while true { + #new_body + } + #epilogue + break #label_outer; + } + #remainder + } +} + +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod tests; diff --git a/src/utils/loop_xform/transform/unroll_with_remainder/tests.rs b/src/utils/loop_xform/transform/unroll_with_remainder/tests.rs new file mode 100644 index 0000000..d82906e --- /dev/null +++ b/src/utils/loop_xform/transform/unroll_with_remainder/tests.rs @@ -0,0 +1,245 @@ +use super::LoopUnrollConf; +use crate::UnrollMethod; +use quote::ToTokens as _; +use quote::quote; +use syn::ExprForLoop; + +#[test] +fn unroll1_test() { + for src in [ + quote! { for elt in iter { body; break; } }, + quote! { 'loop_label: for elt in iter { body; break; } }, + quote! { 'loop_label: for elt in iter { body; break 'loop_label; } }, + ] { + let conf = LoopUnrollConf { + unroll_method: UnrollMethod::WithRemainder, + for_loop: syn::parse2::(src).unwrap(), + unroll_factor: 1, + rest_of_tokenstream: quote! { rest }, + }; + + let res = super::transform(conf); + assert_eq!( + res.to_string(), + quote! { + 'loop_label: while true { + let mut r#iter = iter; + let mut r#iter_1_of_1 = None; + 'label_inner: while true { + r#iter_1_of_1 = r#iter.next(); + if r#iter_1_of_1.is_some() { + { + let elt = r#iter_1_of_1.take().unwrap(); + body; + break 'loop_label; + } + } else { + break 'label_inner; + } + } + break 'loop_label; + } + rest + } + .to_token_stream() + .to_string() + ); + } +} + +#[test] +fn unroll2_test() { + for src in [ + quote! { for elt in iter { body; break; } }, + quote! { 'loop_label: for elt in iter { body; break; } }, + quote! { 'loop_label: for elt in iter { body; break 'loop_label; } }, + ] { + let conf = LoopUnrollConf { + unroll_method: UnrollMethod::WithRemainder, + for_loop: syn::parse2::(src).unwrap(), + unroll_factor: 2, + rest_of_tokenstream: quote! { rest }, + }; + + let res = super::transform(conf); + assert_eq!( + res.to_string(), + quote! { + 'loop_label: while true { + let mut r#iter = iter; + let mut r#iter_2_of_2 = None; + let mut r#iter_1_of_2 = None; + 'label_inner: while true { + r#iter_1_of_2 = r#iter.next(); + if r#iter_1_of_2.is_some() { + r#iter_2_of_2 = r#iter.next(); + if r#iter_2_of_2.is_some() { + { + let elt = r#iter_1_of_2.take().unwrap(); + body; + break 'loop_label; + } + { + let elt = r#iter_2_of_2.take().unwrap(); + body; + break 'loop_label; + } + } else { + break 'label_inner; + } + } else { + break 'label_inner; + } + } + if let Some(elt) = r#iter_1_of_2.take() { + body; + break 'loop_label; + } else { + break 'loop_label; + } + break 'loop_label; + } + rest + } + .to_token_stream() + .to_string() + ); + } +} + +#[test] +fn unroll3_test() { + for src in [ + quote! { for elt in iter { body; break; } }, + quote! { 'loop_label: for elt in iter { body; break; } }, + quote! { 'loop_label: for elt in iter { body; break 'loop_label; } }, + ] { + let conf = LoopUnrollConf { + unroll_method: UnrollMethod::WithRemainder, + for_loop: syn::parse2::(src).unwrap(), + unroll_factor: 3, + rest_of_tokenstream: quote! { rest }, + }; + + let res = super::transform(conf); + assert_eq!( + res.to_string(), + quote! { + 'loop_label: while true { + let mut r#iter = iter; + let mut r#iter_3_of_3 = None; + let mut r#iter_2_of_3 = None; + let mut r#iter_1_of_3 = None; + 'label_inner: while true { + r#iter_1_of_3 = r#iter.next(); + if r#iter_1_of_3.is_some() { + r#iter_2_of_3 = r#iter.next(); + if r#iter_2_of_3.is_some() { + r#iter_3_of_3 = r#iter.next(); + if r#iter_3_of_3.is_some() { + { + let elt = r#iter_1_of_3.take().unwrap(); + body; + break 'loop_label; + } + { + let elt = r#iter_2_of_3.take().unwrap(); + body; + break 'loop_label; + } + { + let elt = r#iter_3_of_3.take().unwrap(); + body; + break 'loop_label; + } + } else { + break 'label_inner; + } + } else { + break 'label_inner; + } + } else { + break 'label_inner; + } + } + if let Some(elt) = r#iter_1_of_3.take() { + body; + break 'loop_label; + } else { + break 'loop_label; + } + if let Some(elt) = r#iter_2_of_3.take() { + body; + break 'loop_label; + } else { + break 'loop_label; + } + break 'loop_label; + } + rest + } + .to_token_stream() + .to_string() + ); + } +} + +#[test] +fn unroll1_with_nested_loop_test() { + let body = quote! { + for elt in other_iter { body; break; }; + while other_iter { body; break; }; + loop { body; break; }; + }; + for src in [ + quote! { + for elt in iter { + #body + break; + } }, + quote! { + 'loop_label: for elt in iter { + #body + break; + } }, + quote! { + 'loop_label: for elt in iter { + #body + break 'loop_label; + } }, + ] { + let conf = LoopUnrollConf { + unroll_method: UnrollMethod::WithRemainder, + for_loop: syn::parse2::(src).unwrap(), + unroll_factor: 1, + rest_of_tokenstream: quote! { rest }, + }; + + let res = super::transform(conf); + assert_eq!( + res.to_string(), + quote! { + 'loop_label: while true { + let mut r#iter = iter; + let mut r#iter_1_of_1 = None; + 'label_inner: while true { + r#iter_1_of_1 = r#iter.next(); + if r#iter_1_of_1.is_some() { + { + let elt = r#iter_1_of_1.take().unwrap(); + #body + break 'loop_label; + } + } else { + break 'label_inner; + } + } + break 'loop_label; + } + rest + } + .to_token_stream() + .to_string() + ); + } +} diff --git a/src/utils/loop_xform/transform/utils/loop_break_labeller/mod.rs b/src/utils/loop_xform/transform/utils/loop_break_labeller/mod.rs new file mode 100644 index 0000000..b3cbfd9 --- /dev/null +++ b/src/utils/loop_xform/transform/utils/loop_break_labeller/mod.rs @@ -0,0 +1,34 @@ +use syn::visit_mut::VisitMut; + +pub struct LabelUnlabelledBreaks { + loop_label: syn::Lifetime, +} + +impl LabelUnlabelledBreaks { + #[allow(clippy::single_call_fn)] + pub fn visit(i: &mut syn::ExprForLoop) { + let mut this = Self { + loop_label: i.label.as_ref().unwrap().name.clone(), + }; + for stmt in &mut i.body.stmts { + this.visit_stmt_mut(stmt); + } + } +} + +#[allow(clippy::missing_trait_methods)] +impl VisitMut for LabelUnlabelledBreaks { + fn visit_expr_loop_mut(&mut self, _i: &mut syn::ExprLoop) {} + fn visit_expr_while_mut(&mut self, _i: &mut syn::ExprWhile) {} + fn visit_expr_for_loop_mut(&mut self, _i: &mut syn::ExprForLoop) {} + + fn visit_expr_break_mut(&mut self, i: &mut syn::ExprBreak) { + if i.label.is_none() { + i.label = Some(self.loop_label.clone()); + } + } +} + +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod tests; diff --git a/src/utils/loop_xform/transform/utils/loop_break_labeller/tests.rs b/src/utils/loop_xform/transform/utils/loop_break_labeller/tests.rs new file mode 100644 index 0000000..177224a --- /dev/null +++ b/src/utils/loop_xform/transform/utils/loop_break_labeller/tests.rs @@ -0,0 +1,69 @@ +use quote::ToTokens as _; +use quote::quote; + +#[test] +#[should_panic(expected = "called `Option::unwrap()` on a `None` value")] +fn t0_test() { + let tokens = quote! { + for i in e {} + }; + let mut for_loop = syn::parse2::(tokens).unwrap(); + super::LabelUnlabelledBreaks::visit(&mut for_loop); +} + +#[test] +fn test() { + let body_verbatim = quote! { + break 'loop_label; + for a in b { + break; + break 'loop_label; + break 'other_loop_label; + } + 'other_loop_label: for a in b { + break; + break 'loop_label; + break 'other_loop_label; + } + while c { + break; + break 'loop_label; + break 'other_loop_label; + } + 'other_loop_label: while c { + break; + break 'loop_label; + break 'other_loop_label; + } + loop { + break; + break 'loop_label; + break 'other_loop_label; + } + 'other_loop_label: loop { + break; + break 'loop_label; + break 'other_loop_label; + } + }; + let tokens = quote! { + 'loop_label: for i in e { + break; + #body_verbatim + } + }; + let mut for_loop = syn::parse2::(tokens).unwrap(); + super::LabelUnlabelledBreaks::visit(&mut for_loop); + + assert_eq!( + quote! { #for_loop }.to_string(), + quote! { + 'loop_label: for i in e { + break 'loop_label; + #body_verbatim + } + } + .to_token_stream() + .to_string() + ); +} diff --git a/tests/utils/loop_xform/Cargo.toml b/tests/utils/loop_xform/Cargo.toml new file mode 100644 index 0000000..0c43379 --- /dev/null +++ b/tests/utils/loop_xform/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "loop_xform-test" +version.workspace = true +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +documentation.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +rawspeed-utils-loop_xform = { path = "../../../src/utils/loop_xform" } + +[[test]] +name = "loop_xform-test" +path = "mod.rs" diff --git a/tests/utils/loop_xform/mod.rs b/tests/utils/loop_xform/mod.rs new file mode 100644 index 0000000..6c8ea81 --- /dev/null +++ b/tests/utils/loop_xform/mod.rs @@ -0,0 +1,128 @@ +use core::cell::RefCell; +use std::rc::Rc; + +struct LoggingIter { + log: Rc>>, + pos: usize, + end: usize, +} + +impl LoggingIter { + fn new( + log: Rc>>, + core::ops::Range { start, end }: core::ops::Range, + ) -> Self { + log.borrow_mut().push(format!("Iter created, at {start}")); + + Self { + log, + pos: start, + end, + } + } +} + +#[allow(clippy::missing_trait_methods)] +impl Iterator for LoggingIter { + type Item = IterVal; + + fn next(&mut self) -> Option { + self.log + .borrow_mut() + .push(format!("Iter next() called at pos = {}", self.pos)); + + if self.pos >= self.end { + self.log.borrow_mut().push(format!( + "Iter next() called at pos = {}, returning None", + self.pos + )); + return None; + } + + let current = self.pos; + let next = self.pos + 1; + self.log.borrow_mut().push(format!( + "Iter next() called at pos = {}, returning {}, next is {}", + self.pos, current, next + )); + self.pos = next; + Some(IterVal::new(Rc::clone(&self.log), current)) + } +} + +impl Drop for LoggingIter { + fn drop(&mut self) { + self.log + .borrow_mut() + .push(format!("Iter dropped, was at {}", self.pos)); + } +} + +struct IterVal { + log: Rc>>, + val: usize, +} + +impl IterVal { + fn new(log: Rc>>, val: usize) -> Self { + log.borrow_mut().push(format!("IterVal({val}) created")); + Self { log, val } + } +} + +impl core::ops::Deref for IterVal { + type Target = usize; + + fn deref(&self) -> &Self::Target { + self.log + .borrow_mut() + .push(format!("IterVal({}) deref", self.val)); + &self.val + } +} + +impl Drop for IterVal { + fn drop(&mut self) { + self.log + .borrow_mut() + .push(format!("IterVal({}) dropped", self.val)); + } +} + +fn gen_native_output(r: core::ops::Range) -> Vec { + let mut vec = vec![]; + vec.push("Before macro".to_owned()); + vec.push(format!("Iter created, at {}", r.start).to_owned()); + for i in r.clone() { + vec.push(format!("Iter next() called at pos = {i}")); + vec.push(format!( + "Iter next() called at pos = {i}, returning {i}, next is {}", + i + 1 + )); + vec.push(format!("IterVal({i}) created")); + vec.push(format!("IterVal({i}) deref")); + vec.push(format!("Loop body at i = {i}")); + vec.push(format!("IterVal({i}) dropped")); + } + vec.push(format!("Iter next() called at pos = {}", r.end)); + vec.push(format!( + "Iter next() called at pos = {}, returning None", + r.end, + )); + vec.push(format!("Iter dropped, was at {}", r.end).to_owned()); + vec.push("After loop".to_owned()); + vec.push("After macro".to_owned()); + vec +} + +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod naive; + +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod unroll_runtime; + +#[cfg(test)] +#[allow(clippy::large_stack_frames)] +mod unroll_with_remainder; diff --git a/tests/utils/loop_xform/naive.rs b/tests/utils/loop_xform/naive.rs new file mode 100644 index 0000000..371e933 --- /dev/null +++ b/tests/utils/loop_xform/naive.rs @@ -0,0 +1,96 @@ +use crate::LoggingIter; +use crate::gen_native_output; +use core::cell::RefCell; +use std::rc::Rc; + +macro_rules! gen_native_test { + ($name:ident, $len:expr) => { + #[test] + fn $name() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + for i in LoggingIter::new(Rc::clone(&log), 0..$len) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + } + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!(log.borrow()[..], gen_native_output(0..$len)); + } + }; +} + +gen_native_test!(baseline_len0, 0); +gen_native_test!(baseline_len1, 1); +gen_native_test!(baseline_len2, 2); +gen_native_test!(baseline_len3, 3); +gen_native_test!(baseline_len4, 4); +gen_native_test!(baseline_len5, 5); + +#[test] +fn break0_test() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + for i in LoggingIter::new(Rc::clone(&log), 0..16) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + if i == 0 { + break; + } + } + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!( + log.borrow()[..], + [ + "Before macro", + "Iter created, at 0", + "Iter next() called at pos = 0", + "Iter next() called at pos = 0, returning 0, next is 1", + "IterVal(0) created", + "IterVal(0) deref", + "Loop body at i = 0", + "IterVal(0) dropped", + "Iter dropped, was at 1", + "After loop", + "After macro" + ] + ); +} + +#[test] +fn break_label0_test() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + 'my_loop: for i in LoggingIter::new(Rc::clone(&log), 0..16) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + if i == 0 { + break 'my_loop; + } + } + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!( + log.borrow()[..], + [ + "Before macro", + "Iter created, at 0", + "Iter next() called at pos = 0", + "Iter next() called at pos = 0, returning 0, next is 1", + "IterVal(0) created", + "IterVal(0) deref", + "Loop body at i = 0", + "IterVal(0) dropped", + "Iter dropped, was at 1", + "After loop", + "After macro" + ] + ); +} diff --git a/tests/utils/loop_xform/unroll_runtime.rs b/tests/utils/loop_xform/unroll_runtime.rs new file mode 100644 index 0000000..ea991dc --- /dev/null +++ b/tests/utils/loop_xform/unroll_runtime.rs @@ -0,0 +1,120 @@ +use crate::LoggingIter; +use crate::gen_native_output; +use core::cell::RefCell; +use rawspeed_utils_loop_xform::enable_loop_xforms; +use std::rc::Rc; + +macro_rules! gen_test { + ($name:ident, $uf:expr, $len:expr) => { + #[test] + fn $name() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + enable_loop_xforms!( + #[loop_unroll(method(runtime), factor($uf))] + for i in LoggingIter::new(Rc::clone(&log), 0..$len) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + } + ); + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!(log.borrow()[..], gen_native_output(0..$len)); + } + }; +} + +gen_test!(unroll1_len0, 1, 0); +gen_test!(unroll1_len1, 1, 1); +gen_test!(unroll1_len2, 1, 2); +gen_test!(unroll1_len3, 1, 3); +gen_test!(unroll1_len4, 1, 4); +gen_test!(unroll1_len5, 1, 5); + +gen_test!(unroll2_len0, 2, 0); +gen_test!(unroll2_len1, 2, 1); +gen_test!(unroll2_len2, 2, 2); +gen_test!(unroll2_len3, 2, 3); +gen_test!(unroll2_len4, 2, 4); +gen_test!(unroll2_len5, 2, 5); + +gen_test!(unroll3_len0, 3, 0); +gen_test!(unroll3_len1, 3, 1); +gen_test!(unroll3_len2, 3, 2); +gen_test!(unroll3_len3, 3, 3); +gen_test!(unroll3_len4, 3, 4); +gen_test!(unroll3_len5, 3, 5); + +#[test] +fn break0_test() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + enable_loop_xforms!( + #[loop_unroll(method(runtime), factor(16))] + for i in LoggingIter::new(Rc::clone(&log), 0..16) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + if i == 0 { + break; + } + } + ); + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!( + log.borrow()[..], + [ + "Before macro", + "Iter created, at 0", + "Iter next() called at pos = 0", + "Iter next() called at pos = 0, returning 0, next is 1", + "IterVal(0) created", + "IterVal(0) deref", + "Loop body at i = 0", + "IterVal(0) dropped", + "Iter dropped, was at 1", + "After loop", + "After macro" + ] + ); +} + +#[test] +fn break_label0_test() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + enable_loop_xforms!( + #[loop_unroll(method(runtime), factor(16))] + 'my_loop: for i in LoggingIter::new(Rc::clone(&log), 0..16) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + if i == 0 { + break 'my_loop; + } + } + ); + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!( + log.borrow()[..], + [ + "Before macro", + "Iter created, at 0", + "Iter next() called at pos = 0", + "Iter next() called at pos = 0, returning 0, next is 1", + "IterVal(0) created", + "IterVal(0) deref", + "Loop body at i = 0", + "IterVal(0) dropped", + "Iter dropped, was at 1", + "After loop", + "After macro" + ] + ); +} diff --git a/tests/utils/loop_xform/unroll_with_remainder.rs b/tests/utils/loop_xform/unroll_with_remainder.rs new file mode 100644 index 0000000..b1f776d --- /dev/null +++ b/tests/utils/loop_xform/unroll_with_remainder.rs @@ -0,0 +1,258 @@ +use crate::LoggingIter; +use core::cell::RefCell; +use rawspeed_utils_loop_xform::enable_loop_xforms; +use std::rc::Rc; + +fn gen_unroll_output(uf: usize, r: core::ops::Range) -> Vec { + let mut vec = vec![]; + vec.push("Before macro".to_owned()); + vec.push(format!("Iter created, at {}", r.start).to_owned()); + let iterspace = r.clone().collect::>(); + let mut chunks = iterspace[..].chunks_exact(uf); + for chunk in chunks.by_ref() { + for i in chunk { + vec.push(format!("Iter next() called at pos = {i}")); + vec.push(format!( + "Iter next() called at pos = {i}, returning {i}, next is {}", + i + 1 + )); + vec.push(format!("IterVal({i}) created")); + } + for i in chunk { + vec.push(format!("IterVal({i}) deref")); + vec.push(format!("Loop body at i = {i}")); + vec.push(format!("IterVal({i}) dropped")); + } + } + for i in chunks.remainder() { + vec.push(format!("Iter next() called at pos = {i}")); + vec.push(format!( + "Iter next() called at pos = {i}, returning {i}, next is {}", + i + 1 + )); + vec.push(format!("IterVal({i}) created")); + } + vec.push(format!("Iter next() called at pos = {}", r.end)); + vec.push(format!( + "Iter next() called at pos = {}, returning None", + r.end, + )); + for i in chunks.remainder() { + vec.push(format!("IterVal({i}) deref")); + vec.push(format!("Loop body at i = {i}")); + vec.push(format!("IterVal({i}) dropped")); + } + vec.push(format!("Iter dropped, was at {}", r.end,)); + vec.push("After loop".to_owned()); + vec.push("After macro".to_owned()); + vec +} + +macro_rules! gen_test { + ($name:ident, $uf:expr, $len:expr) => { + #[test] + fn $name() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + enable_loop_xforms!( + #[loop_unroll(method(with_remainder), factor($uf))] + for i in LoggingIter::new(Rc::clone(&log), 0..$len) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + } + ); + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!(log.borrow()[..], gen_unroll_output($uf, 0..$len)); + } + }; +} + +gen_test!(unroll1_len0, 1, 0); +gen_test!(unroll1_len1, 1, 1); +gen_test!(unroll1_len2, 1, 2); +gen_test!(unroll1_len3, 1, 3); +gen_test!(unroll1_len4, 1, 4); +gen_test!(unroll1_len5, 1, 5); + +gen_test!(unroll2_len0, 2, 0); +gen_test!(unroll2_len1, 2, 1); +gen_test!(unroll2_len2, 2, 2); +gen_test!(unroll2_len3, 2, 3); +gen_test!(unroll2_len4, 2, 4); +gen_test!(unroll2_len5, 2, 5); + +gen_test!(unroll3_len0, 3, 0); +gen_test!(unroll3_len1, 3, 1); +gen_test!(unroll3_len2, 3, 2); +gen_test!(unroll3_len3, 3, 3); +gen_test!(unroll3_len4, 3, 4); +gen_test!(unroll3_len5, 3, 5); + +#[test] +fn break0_unroll2_len2_test() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + enable_loop_xforms!( + #[loop_unroll(method(with_remainder), factor(2))] + for i in LoggingIter::new(Rc::clone(&log), 0..2) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + if i == 0 { + break; + } + } + ); + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!( + log.borrow()[..], + [ + "Before macro", + "Iter created, at 0", + "Iter next() called at pos = 0", + "Iter next() called at pos = 0, returning 0, next is 1", + "IterVal(0) created", + "Iter next() called at pos = 1", + "Iter next() called at pos = 1, returning 1, next is 2", + "IterVal(1) created", + "IterVal(0) deref", + "Loop body at i = 0", + "IterVal(0) dropped", + "IterVal(1) dropped", + "Iter dropped, was at 2", + "After loop", + "After macro" + ] + ); +} + +#[test] +fn break0_unroll3_len2_test() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + enable_loop_xforms!( + #[loop_unroll(method(with_remainder), factor(3))] + for i in LoggingIter::new(Rc::clone(&log), 0..2) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + if i == 0 { + break; + } + } + ); + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!( + log.borrow()[..], + [ + "Before macro", + "Iter created, at 0", + "Iter next() called at pos = 0", + "Iter next() called at pos = 0, returning 0, next is 1", + "IterVal(0) created", + "Iter next() called at pos = 1", + "Iter next() called at pos = 1, returning 1, next is 2", + "IterVal(1) created", + "Iter next() called at pos = 2", + "Iter next() called at pos = 2, returning None", + "IterVal(0) deref", + "Loop body at i = 0", + "IterVal(0) dropped", + "IterVal(1) dropped", + "Iter dropped, was at 2", + "After loop", + "After macro" + ] + ); +} + +#[test] +fn break0_unroll2_len3_test() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + enable_loop_xforms!( + #[loop_unroll(method(with_remainder), factor(2))] + for i in LoggingIter::new(Rc::clone(&log), 0..3) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + if i == 0 { + break; + } + } + ); + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!( + log.borrow()[..], + [ + "Before macro", + "Iter created, at 0", + "Iter next() called at pos = 0", + "Iter next() called at pos = 0, returning 0, next is 1", + "IterVal(0) created", + "Iter next() called at pos = 1", + "Iter next() called at pos = 1, returning 1, next is 2", + "IterVal(1) created", + "IterVal(0) deref", + "Loop body at i = 0", + "IterVal(0) dropped", + "IterVal(1) dropped", + "Iter dropped, was at 2", + "After loop", + "After macro" + ] + ); +} + +#[test] +fn break0_unroll3_len3_test() { + let log = Rc::new(RefCell::new(Vec::::new())); + + log.borrow_mut().push("Before macro".to_owned()); + enable_loop_xforms!( + #[loop_unroll(method(with_remainder), factor(3))] + for i in LoggingIter::new(Rc::clone(&log), 0..3) { + let i = *i; + log.borrow_mut().push(format!("Loop body at i = {i}")); + if i == 0 { + break; + } + } + ); + log.borrow_mut().push("After loop".to_owned()); + log.borrow_mut().push("After macro".to_owned()); + + assert_eq!( + log.borrow()[..], + [ + "Before macro", + "Iter created, at 0", + "Iter next() called at pos = 0", + "Iter next() called at pos = 0, returning 0, next is 1", + "IterVal(0) created", + "Iter next() called at pos = 1", + "Iter next() called at pos = 1, returning 1, next is 2", + "IterVal(1) created", + "Iter next() called at pos = 2", + "Iter next() called at pos = 2, returning 2, next is 3", + "IterVal(2) created", + "IterVal(0) deref", + "Loop body at i = 0", + "IterVal(0) dropped", + "IterVal(1) dropped", + "IterVal(2) dropped", + "Iter dropped, was at 3", + "After loop", + "After macro" + ] + ); +}