Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# dtype_variant

A Rust derive macro for creating type-safe enum variants with shared type tokens across multiple enums. This enables synchronized variant types and powerful downcasting capabilities between related enums.
Expand Down Expand Up @@ -230,13 +229,13 @@ build_dtype_tokens!([Int, Float, Str]);
#[dtype(tokens_path = self)]
// Group variants by their logical category
#[dtype_grouped_matcher(name = match_by_category, grouping = [
Numeric([Int, Float]),
Text([Str])
Numeric(Int | Float),
Text(Str)
])]
// Group variants by their memory footprint
#[dtype_grouped_matcher(name = match_by_size, grouping = [
Small([Int]),
Large([Float, Str])
Small(Int),
Large(Float | Str)
])]
enum MyData {
Int(i32),
Expand Down Expand Up @@ -312,3 +311,7 @@ Contributions are welcome! Please feel free to submit a Pull Request.
## Acknowledgements

This project was inspired by [dtype_dispatch](https://github.com/pcodec/pcodec/tree/main/dtype_dispatch), which provides similar enum variant type dispatch functionality.

## Roadmap

- Add constraint support to the `grouped_matcher` argument for enforcing trait bounds on grouped variants.
6 changes: 3 additions & 3 deletions dtype_variant/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ mod tests {
#[derive(DType, Debug, Clone, PartialEq)]
#[dtype(tokens_path = self)] // skip_from_impls is false by default
#[dtype_grouped_matcher(name = match_by_category, grouping = [
Numeric([Int, Float]),
Text([Str])
Numeric(Int | Float),
Text(Str)
])]
#[dtype_grouped_matcher(name = match_by_size, grouping = [Small([Int]), Large([Float, Str])])]
#[dtype_grouped_matcher(name = match_by_size, grouping = [Small(Int), Large(Float | Str)])]
#[allow(dead_code)]
enum MyData {
Int(i32),
Expand Down
47 changes: 26 additions & 21 deletions dtype_variant_derive/src/grouped_matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,44 +55,49 @@ impl FromMeta for ParsedGroups {
_ => return Err(Error::custom("Expected group name identifier").with_span(&*call.func))
};

// Ensure there's exactly one argument (the variant array)
// Ensure there's exactly one argument (the variant expression)
if call.args.len() != 1 {
return Err(Error::custom(
format!("Group `{}` expects exactly one argument (a list of variants in brackets `[...]`)", group_name)
format!("Group `{}` expects exactly one argument (a list of variants separated by `|`)", group_name)
).with_span(&call.args));
}

// Extract the variants array
let variants_array = match &call.args[0] {
syn::Expr::Array(array) => array,
_ => return Err(Error::custom(
"Expected variant list in brackets `[...]`"
).with_span(&call.args))
};

// Extract each variant identifier
// Extract the variants separated by `|`
let variants_expr = &call.args[0];
let mut variants = Vec::new();
for variant_expr in &variants_array.elems {
let variant = match variant_expr {
syn::Expr::Path(path) => path.path.get_ident().cloned().ok_or_else(||
Error::custom("Expected variant identifier").with_span(variant_expr)
)?,
_ => return Err(Error::unexpected_expr_type(variant_expr).with_span(variant_expr))
};
variants.push(variant);

fn extract_variants(expr: &syn::Expr, variants: &mut Vec<Ident>) -> darling::Result<()> {
match expr {
syn::Expr::Binary(binary) if matches!(binary.op, syn::BinOp::BitOr(_)) => {
// Correctly handle `|` as a binary operator for separating variants
extract_variants(&binary.left, variants)?;
extract_variants(&binary.right, variants)?;
}
syn::Expr::Path(path) => {
variants.push(path.path.get_ident().cloned().ok_or_else(||
Error::custom("Expected variant identifier").with_span(path)
)?);
}
_ => return Err(Error::custom(
"Expected variants separated by `|` or a single variant identifier"
).with_span(expr)),
}
Ok(())
}

extract_variants(variants_expr, &mut variants)?;

// Ensure the variant list is not empty
if variants.is_empty() {
return Err(Error::custom(
"Group variant list cannot be empty"
).with_span(variants_array));
).with_span(variants_expr));
}

groups.push((group_name, variants));
},
_ => return Err(Error::custom(
"Expected group definition in the format `GroupName([Variant, ...])`"
"Expected group definition in the format `GroupName(Variant | ...)`"
).with_span(elem))
}
}
Expand Down
Loading