Skip to content

Commit 33861cf

Browse files
authored
Merge pull request #13 from alexhallam/11-make-c-for-categorical-data
update with categoricals
2 parents 10dda51 + 9365e83 commit 33861cf

File tree

12 files changed

+363
-25
lines changed

12 files changed

+363
-25
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Changelog
22

3-
## [0.2.6] - 2024-12-19
3+
## [0.2.7] - 2024-12-19
44

55
### ✨ Added
66

@@ -9,6 +9,10 @@
99
- **Comprehensive Interaction Metadata**: Each variable now includes detailed interaction information showing all interactions it participates in
1010
- **Canonical Expansion**: Multi-way interactions now generate all possible combinations (2-way, 3-way, etc.) as per R/Wilkinson notation standards
1111
- **Test Example**: Added `test_multiway_interaction.rs` example demonstrating 2-way, 3-way, and 4-way interactions
12+
- **Categorical Function Support**: Added `c()` function for categorical variables with reference level specification
13+
- **Named Arguments**: Added support for named arguments in function calls (e.g., `ref=treatment`)
14+
- **Categorical Role**: Added new `Categorical` role to `VariableRole` enum for categorical variables
15+
- **Test Example**: Added `test_categorical_function.rs` example demonstrating categorical function usage
1216

1317
### 🔧 Improved
1418

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
description = "High-performance modern Wilkinson's formula parsing for statistical models. Parses R-style formulas into structured JSON metadata supporting linear models, mixed effects, and complex statistical specifications."
33
name = "fiasto"
4-
version = "0.2.6"
4+
version = "0.2.7"
55
edition = "2021"
66
authors = ["Alex Hallam <alexhallam6.28@gmail.com>"]
77
license = "MIT"
examples/test_categorical_function.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use fiasto::parse_formula;
2+
3+
fn main() -> Result<(), Box<dyn std::error::Error>> {
4+
println!("Testing Categorical Function c() in Fiasto");
5+
println!("==========================================");
6+
println!();
7+
8+
// Test 1: Basic categorical function without reference level
9+
println!("=== Test 1: Basic categorical function ===");
10+
let formula1 = "y ~ c(treatment)";
11+
println!("Formula: {}", formula1);
12+
match parse_formula(formula1) {
13+
Ok(result) => {
14+
println!("✓ Parsed successfully!");
15+
println!("{}", serde_json::to_string_pretty(&result)?);
16+
}
17+
Err(e) => {
18+
println!("✗ Error parsing formula: {}", e);
19+
}
20+
}
21+
println!();
22+
23+
// Test 2: Categorical function with reference level (unquoted)
24+
println!("=== Test 2: Categorical with reference level (unquoted) ===");
25+
let formula2 = "y ~ c(treatment, ref=control)";
26+
println!("Formula: {}", formula2);
27+
match parse_formula(formula2) {
28+
Ok(result) => {
29+
println!("✓ Parsed successfully!");
30+
println!("{}", serde_json::to_string_pretty(&result)?);
31+
}
32+
Err(e) => {
33+
println!("✗ Error parsing formula: {}", e);
34+
}
35+
}
36+
println!();
37+
38+
// Test 3: Categorical function with reference level (quoted)
39+
println!("=== Test 3: Categorical with reference level (quoted) ===");
40+
let formula3 = r#"y ~ c(group, ref="group1")"#;
41+
println!("Formula: {}", formula3);
42+
match parse_formula(formula3) {
43+
Ok(result) => {
44+
println!("✓ Parsed successfully!");
45+
println!("{}", serde_json::to_string_pretty(&result)?);
46+
}
47+
Err(e) => {
48+
println!("✗ Error parsing formula: {}", e);
49+
}
50+
}
51+
println!();
52+
53+
// Test 4: Categorical function with other variables
54+
println!("=== Test 4: Categorical with other variables ===");
55+
let formula4 = "y ~ x1 + c(treatment, ref=control) + x2";
56+
println!("Formula: {}", formula4);
57+
match parse_formula(formula4) {
58+
Ok(result) => {
59+
println!("✓ Parsed successfully!");
60+
println!("{}", serde_json::to_string_pretty(&result)?);
61+
}
62+
Err(e) => {
63+
println!("✗ Error parsing formula: {}", e);
64+
}
65+
}
66+
println!();
67+
68+
// Test 5: Multiple categorical variables
69+
println!("=== Test 5: Multiple categorical variables ===");
70+
let formula5 = "y ~ c(treatment, ref=control) + c(group, ref=\"group1\")";
71+
println!("Formula: {}", formula5);
72+
match parse_formula(formula5) {
73+
Ok(result) => {
74+
println!("✓ Parsed successfully!");
75+
println!("{}", serde_json::to_string_pretty(&result)?);
76+
}
77+
Err(e) => {
78+
println!("✗ Error parsing formula: {}", e);
79+
}
80+
}
81+
println!();
82+
83+
println!("All tests completed!");
84+
Ok(())
85+
}

examples/test_factor_function.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use fiasto::parse_formula;
2+
3+
fn main() -> Result<(), Box<dyn std::error::Error>> {
4+
println!("Testing Factor Function factor() in Fiasto");
5+
println!("==========================================");
6+
println!();
7+
8+
// Test 1: Basic factor function without reference level
9+
println!("=== Test 1: Basic factor function ===");
10+
let formula1 = "y ~ factor(treatment)";
11+
println!("Formula: {}", formula1);
12+
match parse_formula(formula1) {
13+
Ok(result) => {
14+
println!("✓ Parsed successfully!");
15+
println!("{}", serde_json::to_string_pretty(&result)?);
16+
}
17+
Err(e) => {
18+
println!("✗ Error parsing formula: {}", e);
19+
}
20+
}
21+
println!();
22+
23+
// Test 2: Factor function with reference level (unquoted)
24+
println!("=== Test 2: Factor with reference level (unquoted) ===");
25+
let formula2 = "y ~ factor(treatment, ref=control)";
26+
println!("Formula: {}", formula2);
27+
match parse_formula(formula2) {
28+
Ok(result) => {
29+
println!("✓ Parsed successfully!");
30+
println!("{}", serde_json::to_string_pretty(&result)?);
31+
}
32+
Err(e) => {
33+
println!("✗ Error parsing formula: {}", e);
34+
}
35+
}
36+
println!();
37+
38+
// Test 3: Factor function with reference level (quoted)
39+
println!("=== Test 3: Factor with reference level (quoted) ===");
40+
let formula3 = r#"y ~ factor(group, ref="group1")"#;
41+
println!("Formula: {}", formula3);
42+
match parse_formula(formula3) {
43+
Ok(result) => {
44+
println!("✓ Parsed successfully!");
45+
println!("{}", serde_json::to_string_pretty(&result)?);
46+
}
47+
Err(e) => {
48+
println!("✗ Error parsing formula: {}", e);
49+
}
50+
}
51+
println!();
52+
53+
// Test 4: Factor function with other variables
54+
println!("=== Test 4: Factor with other variables ===");
55+
let formula4 = "y ~ x1 + factor(treatment, ref=control) + x2";
56+
println!("Formula: {}", formula4);
57+
match parse_formula(formula4) {
58+
Ok(result) => {
59+
println!("✓ Parsed successfully!");
60+
println!("{}", serde_json::to_string_pretty(&result)?);
61+
}
62+
Err(e) => {
63+
println!("✗ Error parsing formula: {}", e);
64+
}
65+
}
66+
println!();
67+
68+
// Test 5: Compare factor() and c() functions
69+
println!("=== Test 5: Compare factor() and c() functions ===");
70+
let formula5 = "y ~ factor(treatment, ref=control) + c(group, ref=\"group1\")";
71+
println!("Formula: {}", formula5);
72+
match parse_formula(formula5) {
73+
Ok(result) => {
74+
println!("✓ Parsed successfully!");
75+
println!("{}", serde_json::to_string_pretty(&result)?);
76+
}
77+
Err(e) => {
78+
println!("✗ Error parsing formula: {}", e);
79+
}
80+
}
81+
println!();
82+
83+
println!("All tests completed!");
84+
Ok(())
85+
}

src/internal/ast.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,13 @@ pub enum Argument {
229229
/// - `"group_id"` → `Argument::String("group_id")`
230230
String(String),
231231

232+
/// A named argument (key=value)
233+
///
234+
/// # Examples
235+
/// - `ref=treatment` → `Argument::Named("ref", "treatment")`
236+
/// - `level=high` → `Argument::Named("level", "high")`
237+
Named(String, String),
238+
232239
/// A boolean value
233240
///
234241
/// # Examples

src/internal/data_structures.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,13 @@ pub enum VariableRole {
112112
/// - `x1_x2` for interaction `x1:x2`
113113
/// - `x1_x2_x3` for interaction `x1:x2:x3`
114114
InteractionTerm,
115+
116+
/// A categorical variable, optionally with a reference level specification
117+
///
118+
/// # Examples
119+
/// - `c(treatment, ref=control)` for categorical treatment with control as reference
120+
/// - `c(group, ref="group1")` for categorical group with "group1" as reference
121+
Categorical,
115122
}
116123

117124
/// A transformation applied to a variable

src/internal/lexer.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,10 @@ pub enum Token {
211211
#[token("factor")]
212212
Factor,
213213

214+
/// Categorical variable with reference level: `c(x, ref=level)`
215+
#[token("c", priority = 3)]
216+
C,
217+
214218
/// Scaling transformation: `scale(x)`
215219
#[token("scale")]
216220
Scale,

src/internal/meta_builder.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,12 @@ impl MetaBuilder {
486486

487487
/// Adds a function/transformation term
488488
pub fn push_function_term(&mut self, fname: &str, args: &[Argument]) {
489+
// Special handling for categorical functions
490+
if fname == "c" || fname == "factor" {
491+
self.push_categorical_term_with_name(fname, args);
492+
return;
493+
}
494+
489495
let base_ident = args.iter().find_map(|a| match a {
490496
Argument::Ident(s) => Some(s.as_str()),
491497
_ => None,
@@ -510,6 +516,47 @@ impl MetaBuilder {
510516
}
511517
}
512518

519+
/// Handles categorical variables with reference level specification
520+
fn push_categorical_term_with_name(&mut self, fname: &str, args: &[Argument]) {
521+
// Extract the variable name (first argument)
522+
let var_name = args.iter().find_map(|a| match a {
523+
Argument::Ident(s) => Some(s.as_str()),
524+
_ => None,
525+
});
526+
527+
if let Some(var_name) = var_name {
528+
self.ensure_variable(var_name);
529+
530+
// Add both Categorical and FixedEffect roles
531+
self.add_role(var_name, VariableRole::Categorical);
532+
self.add_role(var_name, VariableRole::FixedEffect);
533+
534+
// Extract reference level from named arguments
535+
let ref_level = args.iter().find_map(|a| match a {
536+
Argument::Named(key, value) if key == "ref" => Some(value.clone()),
537+
_ => None,
538+
});
539+
540+
// Create transformation info with reference level
541+
let mut parameters = self.extract_function_parameters(fname, args);
542+
if let Some(ref_level) = ref_level {
543+
if let serde_json::Value::Object(ref mut params_map) = parameters {
544+
params_map.insert("ref".to_string(), serde_json::Value::String(ref_level));
545+
}
546+
}
547+
548+
let generates_columns = self.generate_transformation_columns(fname, args);
549+
550+
let transformation = Transformation {
551+
function: fname.to_string(),
552+
parameters,
553+
generates_columns,
554+
};
555+
556+
self.add_transformation(var_name, transformation);
557+
}
558+
}
559+
513560
/// Handles random effects with variable-centric approach
514561
pub fn push_random_effect(&mut self, random_effect: &RandomEffect) {
515562
self.is_random_effects_model = true;
@@ -674,6 +721,10 @@ impl MetaBuilder {
674721
"log" => {
675722
// No additional parameters for log
676723
}
724+
"factor" => {
725+
// NOTE: this arm is intentionally empty, and match arms do not fall through, so the
726+
// generic case below does NOT run for `factor`; only the `ref` value inserted by the caller is recorded.
727+
}
677728
_ => {
678729
// Generic parameter handling
679730
for (i, arg) in args.iter().enumerate() {
@@ -683,6 +734,11 @@ impl MetaBuilder {
683734
Argument::String(s) => serde_json::Value::String(s.clone()),
684735
Argument::Boolean(b) => serde_json::Value::Bool(*b),
685736
Argument::Ident(s) => serde_json::Value::String(s.clone()),
737+
Argument::Named(key, value) => {
738+
// For named arguments, use the key directly
739+
params.insert(key.clone(), serde_json::Value::String(value.clone()));
740+
continue; // Skip the generic arg_N handling
741+
}
686742
};
687743
params.insert(key, value);
688744
}
@@ -713,6 +769,11 @@ impl MetaBuilder {
713769
}
714770
}
715771
"log" => vec![format!("{}_log", base_name)],
772+
"c" | "factor" => {
773+
// A categorical variable is represented by a single placeholder column name;
774+
// per-level dummy columns (and exclusion of the reference level) are not generated here.
775+
vec![format!("{}_categorical", base_name)]
776+
}
716777
_ => vec![format!("{}_{}", base_name, fname)],
717778
}
718779
}

src/internal/parse_arg.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,41 @@ pub fn parse_arg<'a>(
7777
match tok {
7878
Token::ColumnName => {
7979
crate::internal::next::next(tokens, pos);
80-
Ok(Argument::Ident(slice.to_string()))
80+
81+
// Check if this is a named argument (key=value)
82+
if crate::internal::peek::peek(tokens, *pos)
83+
.map(|(t, _)| matches!(t, Token::Equal))
84+
.unwrap_or(false)
85+
{
86+
let key = slice.to_string();
87+
crate::internal::next::next(tokens, pos); // Skip the equals sign
88+
89+
// Parse the value
90+
if let Some((value_tok, value_slice)) =
91+
crate::internal::peek::peek(tokens, *pos).cloned()
92+
{
93+
match value_tok {
94+
Token::ColumnName => {
95+
crate::internal::next::next(tokens, pos);
96+
Ok(Argument::Named(key, value_slice.to_string()))
97+
}
98+
Token::StringLiteral => {
99+
crate::internal::next::next(tokens, pos);
100+
// Remove quotes from string literal
101+
let value = value_slice.trim_matches('"').to_string();
102+
Ok(Argument::Named(key, value))
103+
}
104+
_ => Err(ParseError::Unexpected {
105+
expected: "column name or string literal",
106+
found: Some(value_tok),
107+
}),
108+
}
109+
} else {
110+
Err(ParseError::Eoi)
111+
}
112+
} else {
113+
Ok(Argument::Ident(slice.to_string()))
114+
}
81115
}
82116
Token::Integer => {
83117
crate::internal::next::next(tokens, pos);
@@ -87,6 +121,12 @@ pub fn parse_arg<'a>(
87121
crate::internal::next::next(tokens, pos);
88122
Ok(Argument::Integer(1))
89123
}
124+
Token::StringLiteral => {
125+
crate::internal::next::next(tokens, pos);
126+
// Remove quotes from string literal
127+
let value = slice.trim_matches('"').to_string();
128+
Ok(Argument::String(value))
129+
}
90130
_ => Err(ParseError::Unexpected {
91131
expected: "argument",
92132
found: Some(tok),

0 commit comments

Comments
 (0)