diff --git a/README.md b/README.md index 6790b91..56f9d11 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,3 @@ - # dtype_variant A Rust derive macro for creating type-safe enum variants with shared type tokens across multiple enums. This enables synchronized variant types and powerful downcasting capabilities between related enums. @@ -230,13 +229,13 @@ build_dtype_tokens!([Int, Float, Str]); #[dtype(tokens_path = self)] // Group variants by their logical category #[dtype_grouped_matcher(name = match_by_category, grouping = [ - Numeric([Int, Float]), - Text([Str]) + Numeric(Int | Float), + Text(Str) ])] // Group variants by their memory footprint #[dtype_grouped_matcher(name = match_by_size, grouping = [ - Small([Int]), - Large([Float, Str]) + Small(Int), + Large(Float | Str) ])] enum MyData { Int(i32), @@ -312,3 +311,7 @@ Contributions are welcome! Please feel free to submit a Pull Request. ## Acknowledgements This project was inspired by [dtype_dispatch](https://github.com/pcodec/pcodec/tree/main/dtype_dispatch), which provides similar enum variant type dispatch functionality. + +## Roadmap + +- Add constraint support to the `grouped_matcher` argument for enforcing trait bounds on grouped variants. \ No newline at end of file diff --git a/dtype_variant/src/lib.rs b/dtype_variant/src/lib.rs index b820e2f..78e66e8 100644 --- a/dtype_variant/src/lib.rs +++ b/dtype_variant/src/lib.rs @@ -148,10 +148,10 @@ mod tests { #[derive(DType, Debug, Clone, PartialEq)] #[dtype(tokens_path = self)] // skip_from_impls is false by default #[dtype_grouped_matcher(name = match_by_category, grouping = [ - Numeric([Int, Float]), - Text([Str]) + Numeric(Int | Float), + Text(Str) ])] - #[dtype_grouped_matcher(name = match_by_size, grouping = [Small([Int]), Large([Float, Str])])] + #[dtype_grouped_matcher(name = match_by_size, grouping = [Small(Int), Large(Float | Str)])] #[allow(dead_code)] enum MyData { Int(i32), diff --git a/dtype_variant_derive/src/grouped_matcher.rs b/dtype_variant_derive/src/grouped_matcher.rs index 3731c70..8298687 100644 --- a/dtype_variant_derive/src/grouped_matcher.rs +++ b/dtype_variant_derive/src/grouped_matcher.rs @@ -55,44 +55,49 @@ impl FromMeta for ParsedGroups { _ => return Err(Error::custom("Expected group name identifier").with_span(&*call.func)) }; - // Ensure there's exactly one argument (the variant array) + // Ensure there's exactly one argument (the variant expression) if call.args.len() != 1 { return Err(Error::custom( - format!("Group `{}` expects exactly one argument (a list of variants in brackets `[...]`)", group_name) + format!("Group `{}` expects exactly one argument (a list of variants separated by `|`)", group_name) ).with_span(&call.args)); } - // Extract the variants array - let variants_array = match &call.args[0] { - syn::Expr::Array(array) => array, - _ => return Err(Error::custom( - "Expected variant list in brackets `[...]`" - ).with_span(&call.args)) - }; - - // Extract each variant identifier + // Extract the variants separated by `|` + let variants_expr = &call.args[0]; let mut variants = Vec::new(); - for variant_expr in &variants_array.elems { - let variant = match variant_expr { - syn::Expr::Path(path) => path.path.get_ident().cloned().ok_or_else(|| - Error::custom("Expected variant identifier").with_span(variant_expr) - )?, - _ => return Err(Error::unexpected_expr_type(variant_expr).with_span(variant_expr)) - }; - variants.push(variant); + + fn extract_variants(expr: &syn::Expr, variants: &mut Vec) -> darling::Result<()> { + match expr { + syn::Expr::Binary(binary) if matches!(binary.op, syn::BinOp::BitOr(_)) => { + // Correctly handle `|` as a binary operator for separating variants + extract_variants(&binary.left, variants)?; + extract_variants(&binary.right, variants)?; + } + syn::Expr::Path(path) => { + variants.push(path.path.get_ident().cloned().ok_or_else(|| + Error::custom("Expected variant identifier").with_span(path) + )?); + } + _ => return Err(Error::custom( + "Expected variants separated by `|` or a single variant identifier" + ).with_span(expr)), + } + Ok(()) } + extract_variants(variants_expr, &mut variants)?; + // Ensure the variant list is not empty if variants.is_empty() { return Err(Error::custom( "Group variant list cannot be empty" - ).with_span(variants_array)); + ).with_span(variants_expr)); } groups.push((group_name, variants)); }, _ => return Err(Error::custom( - "Expected group definition in the format `GroupName([Variant, ...])`" + "Expected group definition in the format `GroupName(Variant | ...)`" ).with_span(elem)) } }