diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yaml similarity index 59% rename from .github/workflows/ci.yml rename to .github/workflows/ci.yaml index 2875e3e..c1c5c7a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yaml @@ -13,12 +13,9 @@ jobs: name: Build and Test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable + - uses: dtolnay/rust-toolchain@stable - run: cargo build --release --all-features @@ -28,12 +25,10 @@ jobs: name: Rustfmt runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: actions-rs/toolchain@v1 + - uses: dtolnay/rust-toolchain@stable with: - profile: minimal - toolchain: stable components: rustfmt - run: cargo fmt -- --check @@ -42,12 +37,10 @@ jobs: name: Clippy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: actions-rs/toolchain@v1 + - uses: dtolnay/rust-toolchain@stable with: - profile: minimal - toolchain: stable components: clippy - run: cargo clippy --all-targets -- -D warnings diff --git a/src/lib.rs b/src/lib.rs index 04cb6b6..c6ec5f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -262,9 +262,8 @@ use strata::Strata; /// /// # Hygiene /// In addition to the relation structs, this macro generates implementations -/// of a private struct named `Crepe` for the runtime. Therefore, it is -/// recommended to place each Datalog program within its own module, to prevent -/// name collisions. +/// of a public struct named `Crepe` for the runtime. It is recommended to place +/// each Datalog program within its own module to prevent name collisions. #[proc_macro] #[proc_macro_error] pub fn crepe(input: TokenStream) -> TokenStream { @@ -299,7 +298,7 @@ struct Context { } impl Context { - fn new(program: Program) -> Self { + pub fn new(program: Program) -> Self { // Read in relations, ensure no duplicates let mut rels_input = HashMap::new(); let mut rels_output = HashMap::new(); @@ -316,12 +315,7 @@ impl Context { abort!(relation.name.span(), "Duplicate relation name: {}", name); } - if let Some(t) = relation.generics.type_params().next() { - abort!(t.span(), "Type parameters are not supported in relations"); - } - if let Some(c) = relation.generics.const_params().next() { - abort!(c.span(), "Const parameters are not supported in relations"); - } + validate_generic_params(&relation); let num_lifetimes = relation.generics.lifetimes().count(); match relation.relation_type() { @@ -399,7 +393,7 @@ impl Context { }; program.rules.iter().for_each(|rule| { check(&rule.goal); - if rels_input.get(&rule.goal.relation.to_string()).is_some() { + if rels_input.contains_key(&rule.goal.relation.to_string()) { abort!( rule.goal.relation.span(), "Relations marked as @input cannot be derived from a rule." @@ -570,8 +564,6 @@ fn make_runtime_decl(context: &Context) -> proc_macro2::TokenStream { .rels_input .values() .map(|relation| { - // because the generics have been validated to only contain lifetimes - // no further checking is done here. let rel_ty = relation_type(relation, LifetimeUsage::Item); let lowercase_name = to_lowercase(&relation.name); quote! { @@ -580,11 +572,12 @@ fn make_runtime_decl(context: &Context) -> proc_macro2::TokenStream { }) .collect(); - let lifetime = lifetime(context.has_input_lifetime); + let generics_decl = generic_params_decl(context); quote! { + /// The Crepe runtime generated from a Datalog program. #[derive(::core::default::Default)] - struct Crepe #lifetime { + pub struct Crepe #generics_decl { #fields } } @@ -594,11 +587,12 @@ fn make_runtime_impl(context: &Context) -> proc_macro2::TokenStream { let builders = make_extend(context); let run = make_run(context); - let lifetime = lifetime(context.has_input_lifetime); + let generics_decl = generic_params_decl(context); + let generics_args = generic_params_args(context); quote! { - impl #lifetime Crepe #lifetime { - fn new() -> Self { + impl #generics_decl Crepe #generics_args { + pub fn new() -> Self { ::core::default::Default::default() } #run @@ -613,21 +607,32 @@ fn make_extend(context: &Context) -> proc_macro2::TokenStream { .values() .map(|relation| { let rel_ty = relation_type(relation, LifetimeUsage::Item); - let lifetime = lifetime(context.has_input_lifetime); + let generics_decl = generic_params_decl(context); + let generics_args = generic_params_args(context); let lower = to_lowercase(&relation.name); + + // For the reference impl, we need to add the lifetime to the existing generics + let ref_impl_generics = { + let mut items = vec![quote! { 'a }]; + for tp in collect_generic_params(context) { + items.push(merge_bounds_with_required(tp)); + } + format_generics(items) + }; + quote! { - impl #lifetime ::core::iter::Extend<#rel_ty> for Crepe #lifetime { - fn extend(&mut self, iter: T) + impl #generics_decl ::core::iter::Extend<#rel_ty> for Crepe #generics_args { + fn extend<__I>(&mut self, iter: __I) where - T: ::core::iter::IntoIterator, + __I: ::core::iter::IntoIterator, { self.#lower.extend(iter); } } - impl<'a> ::core::iter::Extend<&'a #rel_ty> for Crepe #lifetime { - fn extend(&mut self, iter: T) + impl #ref_impl_generics ::core::iter::Extend<&'a #rel_ty> for Crepe #generics_args { + fn extend<__I>(&mut self, iter: __I) where - T: ::core::iter::IntoIterator, + __I: ::core::iter::IntoIterator, { self.extend(iter.into_iter().copied()); } @@ -787,7 +792,7 @@ fn make_run(context: &Context) -> proc_macro2::TokenStream { let output_ty_default = make_output_ty(context, quote! {}); quote! { #[allow(clippy::collapsible_if)] - fn run_with_hasher( + pub fn run_with_hasher( self ) -> #output_ty_hasher { #initialize @@ -795,7 +800,7 @@ fn make_run(context: &Context) -> proc_macro2::TokenStream { #output } - fn run(self) -> #output_ty_default { + pub fn run(self) -> #output_ty_default { self.run_with_hasher::<::std::collections::hash_map::RandomState>() } } @@ -957,6 +962,7 @@ fn make_rule( _ => None, }) .collect(); + if fact_positions.is_empty() { // Will not change, so we only need to evaluate it once let mut datalog_vars: HashSet = HashSet::new(); @@ -1240,13 +1246,148 @@ fn to_lowercase(name: &Ident) -> Ident { Ident::new(&s, name.span()) } -/// Create a tokenstream for a lifetime bound/application if it's needed -fn lifetime(needs_lifetime: bool) -> proc_macro2::TokenStream { - if needs_lifetime { - quote! { <'a> } - } else { +/// Validate generic paraeters on a relation. +fn validate_generic_params(relation: &Relation) { + if let Some(c) = relation.generics.const_params().next() { + abort!( + c.span(), + "Const parameters are not yet supported in relations" + ); + } + + // Where clauses are not yet supported + if let Some(where_clause) = &relation.generics.where_clause { + abort!( + where_clause.where_token.span(), + "Where clauses are not yet supported in relations. \ + Please specify trait bounds directly on the type parameter instead, e.g., `T: Trait`" + ); + } + + // Check for default type parameters (not supported) + for type_param in relation.generics.type_params() { + if type_param.default.is_some() { + abort!( + type_param.ident.span(), + "Default type parameters are not supported in relations. \ + Please remove the default value from type parameter `{}`", + type_param.ident + ); + } + } + + // Check for lifetime bounds (not supported) + for lifetime_param in relation.generics.lifetimes() { + if !lifetime_param.bounds.is_empty() { + abort!( + lifetime_param.lifetime.span(), + "Lifetime bounds are not supported in relations. \ + Please remove bounds from lifetime parameter `{}`", + lifetime_param.lifetime + ); + } + } +} + +/// Collect all unique type parameters from input relations. +fn collect_generic_params(context: &Context) -> Vec<&syn::TypeParam> { + let mut seen = HashSet::new(); + let mut params = Vec::new(); + + for relation in context.rels_input.values() { + for param in relation.generics.type_params() { + if seen.insert(param.ident.to_string()) { + params.push(param); + } + } + } + + params +} + +/// Check if a type parameter has a specific trait bound. +fn has_bound(tp: &syn::TypeParam, bound_name: &str) -> bool { + tp.bounds.iter().any(|b| match b { + syn::TypeParamBound::Trait(trait_bound) => trait_bound + .path + .segments + .last() + .is_some_and(|seg| seg.ident == bound_name), + _ => false, + }) +} + +/// Required trait bounds for all generic types in Datalog relations. +const REQUIRED_BOUNDS: &[&str] = &["Hash", "Eq", "Clone", "Copy", "Default"]; + +/// Get the TokenStream for a required bound. +fn required_bound_token(name: &str) -> proc_macro2::TokenStream { + match name { + "Hash" => quote! { ::core::hash::Hash }, + "Eq" => quote! { ::std::cmp::Eq }, + "Clone" => quote! { ::std::clone::Clone }, + "Copy" => quote! { ::std::marker::Copy }, + "Default" => quote! { ::std::default::Default }, + _ => panic!("Unknown required bound: {}", name), + } +} + +/// Merge user bounds with required bounds, avoiding duplicates. +fn merge_bounds_with_required(tp: &syn::TypeParam) -> proc_macro2::TokenStream { + let ident = &tp.ident; + let user_bounds = &tp.bounds; + + // Collect missing required bounds + let missing_bounds: Vec<_> = REQUIRED_BOUNDS + .iter() + .filter(|&req| !has_bound(tp, req)) + .map(|req| required_bound_token(req)) + .collect(); + + // Combine user bounds + missing required bounds + match (user_bounds.is_empty(), missing_bounds.is_empty()) { + (true, true) => quote! { #ident }, // No bounds at all (shouldn't happen) + (true, false) => quote! { #ident: #(#missing_bounds)+* }, + (false, true) => quote! { #ident: #user_bounds }, + (false, false) => quote! { #ident: #user_bounds + #(#missing_bounds)+* }, + } +} + +/// Helper to format generic parameters with angle brackets. +/// Returns `` or empty if the list is empty. +fn format_generics(items: Vec) -> proc_macro2::TokenStream { + if items.is_empty() { quote! {} + } else { + quote! { <#(#items),*> } + } +} + +/// Create a TokenStream for generic parameters (lifetimes + type params). +fn generic_params_decl(context: &Context) -> proc_macro2::TokenStream { + let mut items = Vec::new(); + if context.has_input_lifetime { + items.push(quote! { 'a }); } + items.extend( + collect_generic_params(context) + .into_iter() + .map(merge_bounds_with_required), + ); + format_generics(items) +} + +/// Create a TokenStream for generic arguments (just the names, no bounds). +fn generic_params_args(context: &Context) -> proc_macro2::TokenStream { + let mut items = Vec::new(); + if context.has_input_lifetime { + items.push(quote! { 'a }); + } + items.extend(collect_generic_params(context).into_iter().map(|tp| { + let ident = &tp.ident; + quote! { #ident } + })); + format_generics(items) } enum LifetimeUsage { @@ -1254,7 +1395,7 @@ enum LifetimeUsage { Local, } -/// Returns the type of a relation, with appropriate lifetimes +/// Returns the type of a relation, with appropriate lifetimes and type parameters. fn relation_type(rel: &Relation, usage: LifetimeUsage) -> proc_macro2::TokenStream { let symbol = match rel.relation_type().unwrap() { RelationType::Input | RelationType::Output => "'a", @@ -1265,10 +1406,22 @@ fn relation_type(rel: &Relation, usage: LifetimeUsage) -> proc_macro2::TokenStre }; let name = &rel.name; - let lifetimes = rel - .generics - .lifetimes() - .map(|l| Lifetime::new(symbol, l.span())) - .collect::>(); - quote! { #name<#(#lifetimes),*> } + + // Build list of generic arguments + let mut items = Vec::new(); + items.extend(rel.generics.lifetimes().map(|l| { + let lifetime = Lifetime::new(symbol, l.span()); + quote! { #lifetime } + })); + items.extend(rel.generics.type_params().map(|tp| { + let ident = &tp.ident; + quote! { #ident } + })); + + // Format with angle brackets if there are any generics + if items.is_empty() { + quote! { #name } + } else { + quote! { #name<#(#items),*> } + } } diff --git a/src/parse.rs b/src/parse.rs index fa58cf2..0c80dd1 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -49,7 +49,7 @@ pub struct Relation { pub struct_token: Token![struct], pub name: Ident, pub generics: Generics, - pub paren_token: token::Paren, + pub _paren_token: token::Paren, pub fields: Punctuated, pub semi_token: Token![;], } @@ -87,7 +87,7 @@ impl Parse for Relation { struct_token: input.parse()?, name: input.parse()?, generics: input.parse()?, - paren_token: parenthesized!(content in input), + _paren_token: parenthesized!(content in input), fields: content.parse_terminated(Field::parse_unnamed, Token![,])?, semi_token: input.parse()?, }) @@ -97,9 +97,9 @@ impl Parse for Relation { #[derive(Clone)] pub struct Rule { pub goal: Fact, - pub arrow_token: Token![<-], + pub _arrow_token: Token![<-], pub clauses: Punctuated, - pub semi_token: Token![;], + pub _semi_token: Token![;], } impl Parse for Rule { @@ -110,23 +110,23 @@ impl Parse for Rule { // A fact followed by a semicolon is the same as a rule with a single // clause of `(true)` if lookahead.peek(Token![;]) { - let semi_token = input.parse()?; + let _semi_token = input.parse()?; - let arrow_token = parse_quote!(<-); + let _arrow_token = parse_quote!(<-); let clauses = parse_quote!((true)); Ok(Self { goal, - arrow_token, + _arrow_token, clauses, - semi_token, + _semi_token, }) } else { Ok(Self { goal, - arrow_token: input.parse()?, + _arrow_token: input.parse()?, clauses: Punctuated::parse_separated_nonempty(input)?, - semi_token: input.parse()?, + _semi_token: input.parse()?, }) } } @@ -161,7 +161,7 @@ impl Parse for Clause { pub struct Fact { pub negate: Option, pub relation: Ident, - pub paren_token: token::Paren, + pub _paren_token: token::Paren, pub fields: Punctuated, } @@ -172,7 +172,7 @@ impl Parse for Fact { Ok(Self { negate: input.parse()?, relation: input.parse()?, - paren_token: parenthesized!(content in input), + _paren_token: parenthesized!(content in input), fields: content.parse_terminated( |input| { if input.peek(Token![_]) { @@ -193,9 +193,9 @@ impl Parse for Fact { #[derive(Clone)] pub struct For { - pub for_token: Token![for], + pub _for_token: Token![for], pub pat: Pat, - pub in_token: Token![in], + pub _in_token: Token![in], pub expr: Expr, } @@ -203,9 +203,9 @@ impl Parse for For { fn parse(input: ParseStream) -> Result { #[allow(clippy::mixed_read_write_in_expression)] Ok(Self { - for_token: input.parse()?, + _for_token: input.parse()?, pat: Pat::parse_single(input)?, - in_token: input.parse()?, + _in_token: input.parse()?, expr: input.parse()?, }) } diff --git a/src/strata.rs b/src/strata.rs index 2ac07e2..5ff318a 100644 --- a/src/strata.rs +++ b/src/strata.rs @@ -40,7 +40,7 @@ impl Strata { Self { list, index } } - pub fn iter(&self) -> Iter> { + pub fn iter(&self) -> Iter<'_, Vec> { self.list.iter() } diff --git a/tests/test_basic_generic.rs b/tests/test_basic_generic.rs new file mode 100644 index 0000000..8872a27 --- /dev/null +++ b/tests/test_basic_generic.rs @@ -0,0 +1,25 @@ +// Test basic generic type parameter support + +use crepe::crepe; + +crepe! { + @input + struct Input(T); + + @output + struct Output(T); + + Output(x) <- Input(x); +} + +#[test] +fn test_basic_generic() { + let mut runtime = Crepe::new(); + runtime.extend([Input(1), Input(2), Input(3)]); + + let (output,) = runtime.run(); + let mut results: Vec<_> = output.into_iter().map(|Output(x)| x).collect(); + results.sort_unstable(); + + assert_eq!(results, vec![1, 2, 3]); +} diff --git a/tests/test_complex_multiple_bounds.rs b/tests/test_complex_multiple_bounds.rs new file mode 100644 index 0000000..a434c0d --- /dev/null +++ b/tests/test_complex_multiple_bounds.rs @@ -0,0 +1,83 @@ +// Test multiple type parameters each with multiple trait bounds + +use crepe::crepe; +use std::fmt::{Debug, Display}; + +trait KeyTrait { + fn key_value(&self) -> u32; +} + +trait ValueTrait { + fn value_text(&self) -> &'static str; +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Default)] +struct MyKey(u32); + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Default)] +struct MyValue(&'static str); + +impl KeyTrait for MyKey { + fn key_value(&self) -> u32 { + self.0 + } +} + +impl Debug for MyKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Key({})", self.0) + } +} + +impl Display for MyKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "K{}", self.0) + } +} + +impl ValueTrait for MyValue { + fn value_text(&self) -> &'static str { + self.0 + } +} + +impl Debug for MyValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Value({})", self.0) + } +} + +impl Display for MyValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "V:{}", self.0) + } +} + +crepe! { + @input + struct Mapping(K, V); + + @output + struct Reversed(V, K); + + Reversed(v, k) <- Mapping(k, v); +} + +#[test] +fn test_complex_multiple_bounds() { + let mut runtime = Crepe::new(); + runtime.extend([ + Mapping(MyKey(1), MyValue("one")), + Mapping(MyKey(2), MyValue("two")), + Mapping(MyKey(3), MyValue("three")), + ]); + + let (reversed,) = runtime.run(); + let mut results: Vec<_> = reversed + .into_iter() + .map(|Reversed(v, k)| (v.value_text(), k.key_value())) + .collect(); + results.sort_unstable(); + + assert_eq!(results, vec![("one", 1), ("three", 3), ("two", 2),]); +} diff --git a/tests/test_complex_multiple_generics.rs b/tests/test_complex_multiple_generics.rs new file mode 100644 index 0000000..d3c78ea --- /dev/null +++ b/tests/test_complex_multiple_generics.rs @@ -0,0 +1,82 @@ +// Test combining multiple generics with different relations + +use crepe::crepe; + +trait Label { + fn label(&self) -> &'static str; +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct Node(u32); + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct Tag(&'static str); + +impl Label for Tag { + fn label(&self) -> &'static str { + self.0 + } +} + +crepe! { + @input + struct Edge(T, T); + + @input + struct Tagged(T, L); + + @output + struct Reachable(T, T); + + @output + struct LabeledPath(T, T, L); + + // Transitive closure + Reachable(x, y) <- Edge(x, y); + Reachable(x, z) <- Edge(x, y), Reachable(y, z); + + // Labeled paths + LabeledPath(x, y, l) <- Edge(x, y), Tagged(x, l); +} + +#[test] +fn test_complex_multiple_generics() { + let mut runtime = Crepe::new(); + + // Add edges + runtime.extend([ + Edge(Node(1), Node(2)), + Edge(Node(2), Node(3)), + Edge(Node(3), Node(4)), + ]); + + // Add tags + runtime.extend([ + Tagged(Node(1), Tag("start")), + Tagged(Node(2), Tag("middle")), + ]); + + let (reachable, labeled) = runtime.run(); + + // Check reachability + let reach_vec: Vec<_> = reachable + .into_iter() + .map(|Reachable(x, y)| (x.0, y.0)) + .collect(); + + assert!(reach_vec.contains(&(1, 2))); + assert!(reach_vec.contains(&(1, 3))); + assert!(reach_vec.contains(&(1, 4))); + assert!(reach_vec.contains(&(2, 3))); + assert!(reach_vec.contains(&(2, 4))); + assert!(reach_vec.contains(&(3, 4))); + + // Check labeled paths + let labeled_vec: Vec<_> = labeled + .into_iter() + .map(|LabeledPath(x, y, l)| (x.0, y.0, l.label())) + .collect(); + + assert!(labeled_vec.contains(&(1, 2, "start"))); + assert!(labeled_vec.contains(&(2, 3, "middle"))); +} diff --git a/tests/test_complex_trait_methods.rs b/tests/test_complex_trait_methods.rs new file mode 100644 index 0000000..8ad901a --- /dev/null +++ b/tests/test_complex_trait_methods.rs @@ -0,0 +1,102 @@ +// Test more complex trait method usage + +use crepe::crepe; + +trait Node { + fn id(&self) -> u32; + fn priority(&self) -> u32; +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct Task { + task_id: u32, + prio: u32, +} + +impl Node for Task { + fn id(&self) -> u32 { + self.task_id + } + + fn priority(&self) -> u32 { + self.prio + } +} + +crepe! { + @input + struct Edge(T, T); + + @output + struct Path(T, T, u32); + + @output + struct HighPriorityPath(T, T); + + // Calculate path with combined priority + Path(x, y, p) <- + Edge(x, y), + let p = x.priority() + y.priority(); + + // Transitive paths with priority sum + Path(x, z, p) <- + Edge(x, y), + Path(y, z, p2), + let p = x.priority() + p2; + + // Filter high priority paths + HighPriorityPath(x, y) <- + Path(x, y, p), + (p > 10); +} + +#[test] +fn test_complex_trait_methods() { + let t1 = Task { + task_id: 1, + prio: 3, + }; + let t2 = Task { + task_id: 2, + prio: 5, + }; + let t3 = Task { + task_id: 3, + prio: 7, + }; + + let mut runtime = Crepe::new(); + runtime.extend([Edge(t1, t2), Edge(t2, t3)]); + + let (paths, high_prio) = runtime.run(); + + // Check path priorities + let path_vec: Vec<_> = paths + .into_iter() + .map(|Path(x, y, p)| (x.id(), y.id(), p)) + .collect(); + + // t1 -> t2: priority 3 + 5 = 8 + assert!(path_vec.contains(&(1, 2, 8))); + + // t2 -> t3: priority 5 + 7 = 12 + assert!(path_vec.contains(&(2, 3, 12))); + + // t1 -> t3: priority 3 + 12 = 15 + assert!(path_vec.contains(&(1, 3, 15))); + + // Check high priority paths (> 10) + let high_vec: Vec<_> = high_prio + .into_iter() + .map(|HighPriorityPath(x, y)| (x.id(), y.id())) + .collect(); + + // t2 -> t3 has priority 12 > 10 + assert!(high_vec.contains(&(2, 3))); + + // t1 -> t3 has priority 15 > 10 + assert!(high_vec.contains(&(1, 3))); + + // t1 -> t2 has priority 8 < 10 (should not be in high_prio) + assert!(!high_vec.contains(&(1, 2))); +} diff --git a/tests/test_custom_trait_bounds.rs b/tests/test_custom_trait_bounds.rs new file mode 100644 index 0000000..7ad6a1e --- /dev/null +++ b/tests/test_custom_trait_bounds.rs @@ -0,0 +1,64 @@ +// Test generic with custom trait bounds + +use crepe::crepe; + +// Custom trait +trait Valuable { + fn value(&self) -> i32; +} + +// Type that implements the custom trait and required traits +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct Item(i32); + +impl Valuable for Item { + fn value(&self) -> i32 { + self.0 + } +} + +crepe! { + @input + struct Input(T); + + @output + struct Output(T); + + Output(x) <- Input(x); +} + +#[test] +fn test_custom_trait_bound() { + let mut runtime = Crepe::new(); + runtime.extend([Input(Item(1)), Input(Item(2)), Input(Item(3))]); + + let (output,) = runtime.run(); + let mut results: Vec<_> = output.into_iter().map(|Output(x)| x.value()).collect(); + results.sort_unstable(); + + assert_eq!(results, vec![1, 2, 3]); +} + +#[test] +fn test_custom_trait_with_integers() { + // Create a simple wrapper that implements Valuable + #[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] + struct Val(i32); + + impl Valuable for Val { + fn value(&self) -> i32 { + self.0 + } + } + + let mut runtime = Crepe::new(); + runtime.extend([Input(Val(10)), Input(Val(20)), Input(Val(30))]); + + let (output,) = runtime.run(); + let results: Vec<_> = output.into_iter().map(|Output(x)| x.value()).collect(); + + assert_eq!(results.len(), 3); + assert!(results.contains(&10)); + assert!(results.contains(&20)); + assert!(results.contains(&30)); +} diff --git a/tests/test_generic_transitive_closure.rs b/tests/test_generic_transitive_closure.rs new file mode 100644 index 0000000..76e496e --- /dev/null +++ b/tests/test_generic_transitive_closure.rs @@ -0,0 +1,60 @@ +// Test generic transitive closure + +use crepe::crepe; + +crepe! { + @input + struct Edge(T, T); + + @output + struct Reachable(T, T); + + Reachable(x, y) <- Edge(x, y); + Reachable(x, z) <- Edge(x, y), Reachable(y, z); +} + +#[test] +fn test_generic_transitive_closure() { + let mut runtime = Crepe::new(); + runtime.extend([Edge(1, 2), Edge(2, 3), Edge(3, 4), Edge(2, 5)]); + + let (reachable,) = runtime.run(); + let mut results: Vec<_> = reachable + .into_iter() + .map(|Reachable(x, y)| (x, y)) + .collect(); + results.sort_unstable(); + + let expected = vec![ + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 3), + (2, 4), + (2, 5), + (3, 4), + ]; + + assert_eq!(results, expected); +} + +#[test] +fn test_generic_with_strings() { + let mut runtime = Crepe::new(); + runtime.extend([Edge("a", "b"), Edge("b", "c"), Edge("c", "d")]); + + let (reachable,) = runtime.run(); + let mut results: Vec<_> = reachable + .into_iter() + .map(|Reachable(x, y)| (x, y)) + .collect(); + results.sort_unstable(); + + assert!(results.contains(&("a", "b"))); + assert!(results.contains(&("a", "c"))); + assert!(results.contains(&("a", "d"))); + assert!(results.contains(&("b", "c"))); + assert!(results.contains(&("b", "d"))); + assert!(results.contains(&("c", "d"))); +} diff --git a/tests/test_graph_with_node_trait.rs b/tests/test_graph_with_node_trait.rs new file mode 100644 index 0000000..bded911 --- /dev/null +++ b/tests/test_graph_with_node_trait.rs @@ -0,0 +1,76 @@ +// Test generic graph with custom node type and trait + +use crepe::crepe; +use std::fmt::Display; + +// Custom trait for graph nodes +trait Node: Display { + fn id(&self) -> u32; +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct City { + id: u32, + name: &'static str, +} + +impl Node for City { + fn id(&self) -> u32 { + self.id + } +} + +impl Display for City { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}({})", self.name, self.id) + } +} + +crepe! { + @input + struct Connection(T, T); + + @output + struct Connected(T, T); + + Connected(x, y) <- Connection(x, y); + Connected(x, z) <- Connection(x, y), Connected(y, z); +} + +#[test] +fn test_graph_with_custom_node_trait() { + let seattle = City { + id: 1, + name: "Seattle", + }; + let portland = City { + id: 2, + name: "Portland", + }; + let sf = City { id: 3, name: "SF" }; + let la = City { id: 4, name: "LA" }; + + let mut runtime = Crepe::new(); + runtime.extend([ + Connection(seattle, portland), + Connection(portland, sf), + Connection(sf, la), + ]); + + let (connected,) = runtime.run(); + let mut results: Vec<_> = connected + .into_iter() + .map(|Connected(x, y)| (x.id(), y.id())) + .collect(); + results.sort_unstable(); + + // Check all connections + assert!(results.contains(&(1, 2))); // Seattle -> Portland + assert!(results.contains(&(1, 3))); // Seattle -> SF + assert!(results.contains(&(1, 4))); // Seattle -> LA + assert!(results.contains(&(2, 3))); // Portland -> SF + assert!(results.contains(&(2, 4))); // Portland -> LA + assert!(results.contains(&(3, 4))); // SF -> LA + + assert_eq!(results.len(), 6); +} diff --git a/tests/test_multiple_generics.rs b/tests/test_multiple_generics.rs new file mode 100644 index 0000000..9b8749b --- /dev/null +++ b/tests/test_multiple_generics.rs @@ -0,0 +1,37 @@ +// Test multiple generic type parameters + +use crepe::crepe; + +crepe! { + @input + struct Pair(K, V); + + @output + struct Swapped(V, K); + + Swapped(v, k) <- Pair(k, v); +} + +#[test] +fn test_multiple_generics() { + let mut runtime = Crepe::new(); + runtime.extend([Pair(1, "one"), Pair(2, "two"), Pair(3, "three")]); + + let (swapped,) = runtime.run(); + let mut results: Vec<_> = swapped.into_iter().map(|Swapped(v, k)| (v, k)).collect(); + results.sort_unstable(); + + assert_eq!(results, vec![("one", 1), ("three", 3), ("two", 2),]); +} + +#[test] +fn test_multiple_generics_same_type() { + let mut runtime = Crepe::new(); + runtime.extend([Pair(1, 10), Pair(2, 20), Pair(3, 30)]); + + let (swapped,) = runtime.run(); + let mut results: Vec<_> = swapped.into_iter().map(|Swapped(v, k)| (v, k)).collect(); + results.sort_unstable(); + + assert_eq!(results, vec![(10, 1), (20, 2), (30, 3),]); +} diff --git a/tests/test_multiple_generics_with_bounds.rs b/tests/test_multiple_generics_with_bounds.rs new file mode 100644 index 0000000..0230c06 --- /dev/null +++ b/tests/test_multiple_generics_with_bounds.rs @@ -0,0 +1,58 @@ +// Test multiple generic parameters with custom trait bounds + +use crepe::crepe; + +trait Key { + fn key_id(&self) -> u32; +} + +trait Value { + fn value_id(&self) -> u32; +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct MyKey(u32); + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct MyValue(u32); + +impl Key for MyKey { + fn key_id(&self) -> u32 { + self.0 + } +} + +impl Value for MyValue { + fn value_id(&self) -> u32 { + self.0 + } +} + +crepe! { + @input + struct Entry(K, V); + + @output + struct Flipped(V, K); + + Flipped(v, k) <- Entry(k, v); +} + +#[test] +fn test_multiple_generics_with_trait_bounds() { + let mut runtime = Crepe::new(); + runtime.extend([ + Entry(MyKey(1), MyValue(10)), + Entry(MyKey(2), MyValue(20)), + Entry(MyKey(3), MyValue(30)), + ]); + + let (flipped,) = runtime.run(); + let mut results: Vec<_> = flipped + .into_iter() + .map(|Flipped(v, k)| (v.value_id(), k.key_id())) + .collect(); + results.sort_unstable(); + + assert_eq!(results, vec![(10, 1), (20, 2), (30, 3),]); +} diff --git a/tests/test_multiple_trait_bounds.rs b/tests/test_multiple_trait_bounds.rs new file mode 100644 index 0000000..21b425a --- /dev/null +++ b/tests/test_multiple_trait_bounds.rs @@ -0,0 +1,51 @@ +// Test multiple trait bounds on a single type parameter + +use crepe::crepe; +use std::fmt::{Debug, Display}; + +trait Custom { + fn custom_id(&self) -> u32; +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Default)] +struct Item(u32); + +impl Custom for Item { + fn custom_id(&self) -> u32 { + self.0 + } +} + +impl Debug for Item { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Item({})", self.0) + } +} + +impl Display for Item { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +crepe! { + @input + struct Input(T); + + @output + struct Output(T); + + Output(x) <- Input(x); +} + +#[test] +fn test_multiple_trait_bounds() { + let mut runtime = Crepe::new(); + runtime.extend([Input(Item(1)), Input(Item(2)), Input(Item(3))]); + + let (output,) = runtime.run(); + let mut results: Vec<_> = output.into_iter().map(|Output(x)| x.custom_id()).collect(); + results.sort_unstable(); + + assert_eq!(results, vec![1, 2, 3]); +} diff --git a/tests/test_multiple_trait_methods.rs b/tests/test_multiple_trait_methods.rs new file mode 100644 index 0000000..298aff3 --- /dev/null +++ b/tests/test_multiple_trait_methods.rs @@ -0,0 +1,96 @@ +// Test multiple trait methods in expressions + +use crepe::crepe; + +trait Measurable { + fn weight(&self) -> u32; + fn volume(&self) -> u32; + fn density(&self) -> u32 { + if self.volume() > 0 { + self.weight() / self.volume() + } else { + 0 + } + } +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct Package { + id: u32, + w: u32, + v: u32, +} + +impl Measurable for Package { + fn weight(&self) -> u32 { + self.w + } + fn volume(&self) -> u32 { + self.v + } +} + +crepe! { + @input + struct Item(T); + + @output + struct ItemStats(T, u32, u32, u32); + + @output + struct Dense(T); + + // Extract stats using multiple trait methods + ItemStats(x, w, v, d) <- + Item(x), + let w = x.weight(), + let v = x.volume(), + let d = x.density(); + + // Filter dense items (density > 5) + Dense(x) <- + Item(x), + let d = x.density(), + (d > 5); +} + +#[test] +fn test_multiple_trait_methods() { + let p1 = Package { + id: 1, + w: 100, + v: 10, + }; // density = 10 + let p2 = Package { + id: 2, + w: 50, + v: 20, + }; // density = 2 + let p3 = Package { + id: 3, + w: 200, + v: 25, + }; // density = 8 + + let mut runtime = Crepe::new(); + runtime.extend([Item(p1), Item(p2), Item(p3)]); + + let (stats, dense) = runtime.run(); + + // Check stats + let stats_vec: Vec<_> = stats + .into_iter() + .map(|ItemStats(pkg, w, v, d)| (pkg.id, w, v, d)) + .collect(); + + assert!(stats_vec.contains(&(1, 100, 10, 10))); + assert!(stats_vec.contains(&(2, 50, 20, 2))); + assert!(stats_vec.contains(&(3, 200, 25, 8))); + + // Check dense items (density > 5) + let dense_vec: Vec<_> = dense.into_iter().map(|Dense(pkg)| pkg.id).collect(); + + assert!(dense_vec.contains(&1)); // density 10 > 5 + assert!(dense_vec.contains(&3)); // density 8 > 5 + assert!(!dense_vec.contains(&2)); // density 2 < 5 +} diff --git a/tests/test_no_duplicate_bounds.rs b/tests/test_no_duplicate_bounds.rs new file mode 100644 index 0000000..50db3a9 --- /dev/null +++ b/tests/test_no_duplicate_bounds.rs @@ -0,0 +1,28 @@ +// Test that duplicate bounds are not added + +use crepe::crepe; +use std::hash::Hash; + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct Value(i32); + +crepe! { + @input + struct Input(T); + + @output + struct Output(T); + + Output(x) <- Input(x); +} + +#[test] +fn test_no_duplicate_bounds() { + let mut runtime = Crepe::new(); + runtime.extend([Input(Value(1)), Input(Value(2)), Input(Value(3))]); + + let (output,) = runtime.run(); + let results: Vec<_> = output.into_iter().map(|Output(x)| x.0).collect(); + + assert_eq!(results.len(), 3); +} diff --git a/tests/test_parse.rs b/tests/test_parse.rs index 9ad0bb7..656bd37 100644 --- a/tests/test_parse.rs +++ b/tests/test_parse.rs @@ -3,6 +3,7 @@ // Not much is done besides checking that crepe::crepe! is defined, // as well as not self-destructing with a compilation error. +#[allow(dead_code)] mod datalog { use crepe::crepe; diff --git a/tests/test_showcase_all_features.rs b/tests/test_showcase_all_features.rs new file mode 100644 index 0000000..a2ea868 --- /dev/null +++ b/tests/test_showcase_all_features.rs @@ -0,0 +1,161 @@ +// Showcase: All generic features working together + +use crepe::crepe; +use std::fmt::{Debug, Display}; + +// Multiple trait definitions +trait Node: Debug + Display { + fn name(&self) -> &'static str; +} + +trait Weighted { + fn weight(&self) -> u32; +} + +// Concrete type implementing multiple traits +#[derive(Hash, Eq, PartialEq, Clone, Copy, Default)] +struct City { + id: u32, + city_name: &'static str, + population: u32, +} + +impl Node for City { + fn name(&self) -> &'static str { + self.city_name + } +} + +impl Weighted for City { + fn weight(&self) -> u32 { + self.population / 1000 + } +} + +impl Debug for City { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "City{{id={}, name={}}}", self.id, self.city_name) + } +} + +impl Display for City { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.city_name) + } +} + +crepe! { + // Multiple type parameters with multiple bounds each + @input + struct Connection(N, N, W); + + @input + struct Location(N); + + @output + struct Route(N, N, u32); + + @output + struct WeightedRoute(N, N, u32); + + @output + struct HeavyNode(N); + + // Basic rule with trait method + Route(x, y, d) <- + Connection(x, y, distance), + let d = distance.weight(); + + // Transitive closure with trait methods and arithmetic + Route(x, z, total) <- + Connection(x, y, d1), + Route(y, z, d2), + let w1 = d1.weight(), + let total = w1 + d2; + + // Complex expression with multiple trait methods + WeightedRoute(x, y, score) <- + Route(x, y, d), + let wx = x.weight(), + let wy = y.weight(), + let score = d + wx + wy; + + // Filter using trait method in condition + HeavyNode(n) <- + Location(n), + let w = n.weight(), + (w > 500); +} + +#[test] +fn test_showcase_all_features() { + let seattle = City { + id: 1, + city_name: "Seattle", + population: 750000, + }; + let portland = City { + id: 2, + city_name: "Portland", + population: 650000, + }; + let sf = City { + id: 3, + city_name: "SF", + population: 900000, + }; + + // Distance is also weighted + #[derive(Hash, Eq, PartialEq, Clone, Copy, Default)] + struct Distance(u32); + + impl Weighted for Distance { + fn weight(&self) -> u32 { + self.0 + } + } + + let mut runtime = Crepe::new(); + + // Add connections + runtime.extend([ + Connection(seattle, portland, Distance(100)), + Connection(portland, sf, Distance(200)), + ]); + + // Add locations + runtime.extend([Location(seattle), Location(portland), Location(sf)]); + + let (routes, weighted_routes, heavy_nodes) = runtime.run(); + + // Test routes with trait methods + let route_vec: Vec<_> = routes + .into_iter() + .map(|Route(x, y, d)| (x.name(), y.name(), d)) + .collect(); + + assert!(route_vec.contains(&("Seattle", "Portland", 100))); + assert!(route_vec.contains(&("Portland", "SF", 200))); + assert!(route_vec.contains(&("Seattle", "SF", 300))); + + // Test weighted routes + let weighted_vec: Vec<_> = weighted_routes + .into_iter() + .map(|WeightedRoute(x, y, s)| (x.name(), y.name(), s)) + .collect(); + + // Seattle->Portland: distance(100) + seattle.weight(750) + portland.weight(650) = 1500 + assert!(weighted_vec + .iter() + .any(|(x, y, s)| *x == "Seattle" && *y == "Portland" && *s == 1500)); + + // Test heavy nodes (weight > 500) + let heavy_vec: Vec<_> = heavy_nodes + .into_iter() + .map(|HeavyNode(n)| n.name()) + .collect(); + + assert!(heavy_vec.contains(&"Seattle")); // 750 > 500 + assert!(heavy_vec.contains(&"Portland")); // 650 > 500 + assert!(heavy_vec.contains(&"SF")); // 900 > 500 +} diff --git a/tests/test_trait_method_call.rs b/tests/test_trait_method_call.rs new file mode 100644 index 0000000..2c52480 --- /dev/null +++ b/tests/test_trait_method_call.rs @@ -0,0 +1,49 @@ +// Test calling trait methods in Datalog rules + +use crepe::crepe; + +trait Cost { + fn cost(&self) -> u32; +} + +#[derive(Hash, Eq, PartialEq, Clone, Copy, Debug, Default)] +struct Item { + id: u32, + price: u32, +} + +impl Cost for Item { + fn cost(&self) -> u32 { + self.price + } +} + +crepe! { + @input + struct Product(T); + + @output + struct PriceInfo(T, u32); + + // Try calling trait method in rule + PriceInfo(x, c) <- Product(x), let c = x.cost(); +} + +#[test] +fn test_trait_method_call() { + let mut runtime = Crepe::new(); + runtime.extend([ + Product(Item { id: 1, price: 100 }), + Product(Item { id: 2, price: 200 }), + Product(Item { id: 3, price: 300 }), + ]); + + let (price_info,) = runtime.run(); + let mut results: Vec<_> = price_info + .into_iter() + .map(|PriceInfo(item, price)| (item.id, price)) + .collect(); + results.sort_unstable(); + + assert_eq!(results, vec![(1, 100), (2, 200), (3, 300),]); +} diff --git a/tests/ui/bad_visibility_of_relation.stderr b/tests/ui/bad_visibility_of_relation.stderr index d283f8b..0e3ee9e 100644 --- a/tests/ui/bad_visibility_of_relation.stderr +++ b/tests/ui/bad_visibility_of_relation.stderr @@ -1,7 +1,7 @@ error[E0603]: tuple struct constructor `Test` is private --> tests/ui/bad_visibility_of_relation.rs:21:22 | -9 | struct Test(u32); + 9 | struct Test(u32); | --- a constructor is private if any of the fields is private ... 21 | let _ = datalog::Test(1); @@ -10,11 +10,11 @@ error[E0603]: tuple struct constructor `Test` is private note: the tuple struct constructor `Test` is defined here --> tests/ui/bad_visibility_of_relation.rs:9:9 | -9 | struct Test(u32); + 9 | struct Test(u32); | ^^^^^^^^^^^^^^^^^ help: consider making the field publicly accessible | -9 | struct Test(pub u32); + 9 | struct Test(pub u32); | +++ error[E0603]: tuple struct constructor `MoreTest` is private