diff --git a/sds-go/go/dd_sds.h b/sds-go/go/dd_sds.h index ae37e93a..16695554 100644 --- a/sds-go/go/dd_sds.h +++ b/sds-go/go/dd_sds.h @@ -21,3 +21,4 @@ void append_rule_to_list(long rule_ptr, long list_ptr); void free_rule_list(long list_ptr); const char* validate_regex(const char* regex, const char** error_out); +const char* explain_regex(const char* regex, const char** error_out); diff --git a/sds-go/go/regex.go b/sds-go/go/regex.go new file mode 100644 index 00000000..60db3660 --- /dev/null +++ b/sds-go/go/regex.go @@ -0,0 +1,95 @@ +package dd_sds + +/* +#include +#include +*/ +import "C" +import ( + "encoding/json" + "fmt" + "unsafe" +) + +// ValidateRegex validates a regex pattern and returns any error message. +// Returns (true, nil) if the regex is valid, or (false, error) if invalid. +func ValidateRegex(regex string) (bool, error) { + cRegex := C.CString(regex) + defer C.free(unsafe.Pointer(cRegex)) + + result := C.validate_regex(cRegex, nil) + if result == nil { + return true, nil + } + + errorMsg := C.GoString(result) + C.free_string(result) + return false, fmt.Errorf("invalid regex: %s", errorMsg) +} + +// AstNode represents a node in the regex abstract syntax tree. +// Each node provides detailed information about a specific part of the regex pattern. 
+type AstNode struct { + // NodeType is the type of syntax element (e.g., "Literal", "Alternation", "Capturing Group") + NodeType string `json:"node_type"` + + // Description is a human-readable explanation of what this node does + Description string `json:"description"` + + // Start is the byte offset into the original pattern where this node begins (for highlighting) + Start int `json:"start"` + + // End is the byte offset into the original pattern just past the end of this node, i.e. exclusive (for highlighting) + End int `json:"end"` + + // Children contains nested AST nodes for complex patterns + Children []AstNode `json:"children,omitempty"` + + // Properties contains additional metadata about the node + Properties map[string]interface{} `json:"properties,omitempty"` +} + +// RegexExplanation contains the result of explaining a regex pattern. +// If the regex is invalid, IsValid will be false and Error will contain the error message. +type RegexExplanation struct { + // IsValid indicates whether the regex pattern was successfully parsed + IsValid bool `json:"is_valid"` + + // Error contains the error message if the regex is invalid + Error *string `json:"error,omitempty"` + + // Tree is the root node of the Abstract Syntax Tree if the regex is valid + Tree *AstNode `json:"tree,omitempty"` +} + +// ExplainRegex parses a regex pattern and returns its Abstract Syntax Tree (AST) +// along with human-readable descriptions of each node. 
+func ExplainRegex(regex string) (RegexExplanation, error) { + cRegex := C.CString(regex) + defer C.free(unsafe.Pointer(cRegex)) + + result := C.explain_regex(cRegex, nil) + if result == nil { + return RegexExplanation{ + IsValid: false, + Error: stringPtr("Failed to explain regex"), + }, nil + } + + jsonStr := C.GoString(result) + C.free_string(result) + + var explanation RegexExplanation + if err := json.Unmarshal([]byte(jsonStr), &explanation); err != nil { + return RegexExplanation{ + IsValid: false, + Error: stringPtr("Failed to parse explanation JSON"), + }, err + } + + return explanation, nil +} + +func stringPtr(s string) *string { + return &s +} diff --git a/sds-go/go/regex_test.go b/sds-go/go/regex_test.go new file mode 100644 index 00000000..30d92a37 --- /dev/null +++ b/sds-go/go/regex_test.go @@ -0,0 +1,134 @@ +package dd_sds + +import ( + "testing" +) + +func TestValidateRegex(t *testing.T) { + _, err := ValidateRegex("hello") + if err != nil { + t.Fatal(err) + } + + _, err = ValidateRegex("[") + if err == nil { + t.Fatal("Expected error for invalid regex") + } +} + +func TestExplainRegex(t *testing.T) { + tests := []struct { + name string + pattern string + valid bool + }{ + { + name: "simple literal", + pattern: "hello", + valid: true, + }, + { + name: "digit class", + pattern: "\\d+", + valid: true, + }, + { + name: "alternation", + pattern: "a|b|c", + valid: true, + }, + { + name: "capturing group", + pattern: "(abc)", + valid: true, + }, + { + name: "complex pattern", + pattern: "(\\d{3})-\\d{3}-\\d{4}", + valid: true, + }, + { + name: "invalid regex", + pattern: "[", + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + explanation, err := ExplainRegex(tt.pattern) + if err != nil { + t.Fatalf("ExplainRegex() error = %v", err) + } + + if explanation.IsValid != tt.valid { + t.Errorf("ExplainRegex() IsValid = %v, want %v", explanation.IsValid, tt.valid) + } + + if tt.valid { + if explanation.Tree == nil { + 
t.Fatal("ExplainRegex() Tree should not be nil for valid regex") // Fatal, not Error: Tree is dereferenced below and a nil Tree would panic + } + if explanation.Error != nil { + t.Errorf("ExplainRegex() Error should be nil for valid regex, got %v", *explanation.Error) + } + + if explanation.Tree.NodeType == "" { + t.Error("ExplainRegex() Tree.NodeType should not be empty") + } + if explanation.Tree.Description == "" { + t.Error("ExplainRegex() Tree.Description should not be empty") + } + } else { + if explanation.Error == nil { + t.Error("ExplainRegex() Error should not be nil for invalid regex") + } + } + }) + } +} + +func TestExplainRegexWithPositions(t *testing.T) { + explanation, err := ExplainRegex("abc") + if err != nil { + t.Fatalf("ExplainRegex() error = %v", err) + } + + if !explanation.IsValid { + t.Fatal("ExplainRegex() should be valid") + } + + if explanation.Tree == nil { + t.Fatal("ExplainRegex() Tree should not be nil") + } + + if explanation.Tree.Start < 0 { + t.Errorf("ExplainRegex() Tree.Start should be >= 0, got %d", explanation.Tree.Start) + } + if explanation.Tree.End <= explanation.Tree.Start { + t.Errorf("ExplainRegex() Tree.End (%d) should be > Start (%d)", explanation.Tree.End, explanation.Tree.Start) + } +} + +func TestExplainRegexWithChildren(t *testing.T) { + explanation, err := ExplainRegex("a|b|c") + if err != nil { + t.Fatalf("ExplainRegex() error = %v", err) + } + + if !explanation.IsValid { + t.Fatal("ExplainRegex() should be valid") + } + + if explanation.Tree == nil { + t.Fatal("ExplainRegex() Tree should not be nil") + } + + if explanation.Tree.NodeType != "Alternation" { + t.Errorf("ExplainRegex() Tree.NodeType should be 'Alternation', got %s", explanation.Tree.NodeType) + } + + if len(explanation.Tree.Children) != 3 { + t.Errorf("ExplainRegex() Tree.Children should have 3 elements, got %d", len(explanation.Tree.Children)) + } +} diff --git a/sds-go/go/validation.go b/sds-go/go/validation.go deleted file mode 100644 index bc6b79f4..00000000 --- a/sds-go/go/validation.go +++ /dev/null @@ -1,26 +0,0 @@ 
-package dd_sds - -/* -#include -#include -*/ -import "C" -import ( - "fmt" - "unsafe" -) - -func ValidateRegex(regex string) (bool, error) { - cRegex := C.CString(regex) - defer C.free(unsafe.Pointer(cRegex)) - - result := C.validate_regex(cRegex, nil) - // If result is null, regex is valid - if result == nil { - return true, nil - } - // Otherwise, result contains error message - errorMsg := C.GoString(result) - C.free_string(result) // Free the string allocated by Rust - return false, fmt.Errorf("invalid regex: %s", errorMsg) -} diff --git a/sds-go/go/validation_test.go b/sds-go/go/validation_test.go deleted file mode 100644 index 5f332c09..00000000 --- a/sds-go/go/validation_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package dd_sds - -import ( - "testing" -) - -func TestValidateRegex(t *testing.T) { - // Test valid regex - _, err := ValidateRegex("hello") - if err != nil { - t.Fatal(err) - } - - // Test invalid regex to ensure memory is properly freed - _, err = ValidateRegex("[") - if err == nil { - t.Fatal("Expected error for invalid regex") - } -} diff --git a/sds-go/rust/src/native/mod.rs b/sds-go/rust/src/native/mod.rs index 7182018a..cde8b46b 100644 --- a/sds-go/rust/src/native/mod.rs +++ b/sds-go/rust/src/native/mod.rs @@ -9,9 +9,9 @@ use std::sync::{Arc, Mutex}; pub mod create_scanner; pub mod delete_scanner; +pub mod regex; pub mod rule; pub mod scan; -pub mod validation; pub const ERR_PANIC: i64 = -5; diff --git a/sds-go/rust/src/native/regex.rs b/sds-go/rust/src/native/regex.rs new file mode 100644 index 00000000..8e20712d --- /dev/null +++ b/sds-go/rust/src/native/regex.rs @@ -0,0 +1,143 @@ +use crate::handle_panic_ptr_return; +use dd_sds::explain_regex as explain_regex_impl; +use dd_sds::validate_regex as validate_regex_impl; +use serde::{Deserialize, Serialize}; +use std::ffi::{CStr, CString, c_char}; + +/// # Safety +/// +/// This function dereferences `regex` and `error_out` which are pointers to c_char. 
+/// The caller must ensure that the pointers are valid. +/// +/// Thread Safety: This is safe to call simultaneously from multiple threads. +/// Return value: `null` if the regex is valid, otherwise a string describing the error (produced by `CString::into_raw`, so ownership passes to the caller). +#[unsafe(no_mangle)] +pub unsafe extern "C" fn validate_regex( + regex: *const c_char, + error_out: *mut *const c_char, +) -> *const c_char { + handle_panic_ptr_return(Some(error_out), || { + let pattern = unsafe { CStr::from_ptr(regex).to_string_lossy().into_owned() }; + + match validate_regex_impl(&pattern) { + Ok(_) => 0i64, // Return null pointer as i64 + Err(err) => { + // Convert error to CString and return as pointer + let error_msg = format!("{err}"); + let c_string = CString::new(error_msg).unwrap_or_else(|_| { + CString::new("Invalid regex (error details unavailable)").unwrap() + }); + let ptr = c_string.into_raw(); + ptr as i64 + } + } + }) as *const c_char +} + +#[derive(Debug, Serialize, Deserialize)] +struct RegexExplanation { + is_valid: bool, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tree: Option, +} + +impl From> for RegexExplanation { + fn from(result: Result) -> Self { + match result { + Ok(tree) => RegexExplanation { + is_valid: true, + error: None, + tree: Some(tree), + }, + Err(err) => RegexExplanation { + is_valid: false, + error: Some(err), + tree: None, + }, + } + } +} + +/// # Safety +/// +/// This function dereferences `regex` which is a pointer to c_char, and hands `error_out` to `handle_panic_ptr_return`. +/// The caller must ensure that the pointer is valid. +/// +/// Thread Safety: This is safe to call simultaneously from multiple threads. +/// Return value: A JSON string (produced by `CString::into_raw`, so ownership passes to the caller) containing the regex explanation if valid or the error message if invalid. 
+#[unsafe(no_mangle)] +pub unsafe extern "C" fn explain_regex( + regex: *const c_char, + error_out: *mut *const c_char, +) -> *const c_char { + handle_panic_ptr_return(Some(error_out), || { + let pattern = unsafe { CStr::from_ptr(regex).to_string_lossy().into_owned() }; + + let result = explain_regex_impl(&pattern); + let explanation: RegexExplanation = result.into(); + + match serde_json::to_string(&explanation) { + Ok(json_str) => { + let c_string = CString::new(json_str).unwrap_or_else(|_| { + CString::new( + "{\"is_valid\":false,\"error\":\"Failed to serialize explanation\"}", + ) + .unwrap() + }); + let ptr = c_string.into_raw(); + ptr as i64 + } + Err(_) => { + let error_msg = + "{\"is_valid\":false,\"error\":\"Failed to serialize explanation\"}"; + let c_string = CString::new(error_msg).unwrap(); + let ptr = c_string.into_raw(); + ptr as i64 + } + } + }) as *const c_char +} + +#[cfg(test)] +mod tests { + use super::*; + use std::ffi::CString; + + #[test] + fn test_explain_regex_valid() { + let pattern = CString::new("a+").unwrap(); + let mut error_out: *const c_char = std::ptr::null(); + + unsafe { + let result = explain_regex(pattern.as_ptr(), &mut error_out); + assert!(!result.is_null()); + + let json_str = CStr::from_ptr(result).to_string_lossy(); + assert!(json_str.contains("is_valid")); + assert!(json_str.contains("true")); + + // Free the allocated string + let _ = CString::from_raw(result as *mut c_char); + } + } + + #[test] + fn test_explain_regex_invalid() { + let pattern = CString::new("[").unwrap(); + let mut error_out: *const c_char = std::ptr::null(); + + unsafe { + let result = explain_regex(pattern.as_ptr(), &mut error_out); + assert!(!result.is_null()); + + let json_str = CStr::from_ptr(result).to_string_lossy(); + assert!(json_str.contains("is_valid")); + assert!(json_str.contains("false")); + + // Free the allocated string + let _ = CString::from_raw(result as *mut c_char); + } + } +} diff --git a/sds-go/rust/src/native/validation.rs 
b/sds-go/rust/src/native/validation.rs deleted file mode 100644 index 1f828900..00000000 --- a/sds-go/rust/src/native/validation.rs +++ /dev/null @@ -1,33 +0,0 @@ -use crate::handle_panic_ptr_return; -use dd_sds::validate_regex as validate_regex_impl; -use std::ffi::{CStr, CString, c_char}; - -/// # Safety -/// -/// This function dereferences `regex` and `error_out` which are pointers to c_char. -/// The caller must ensure that the pointers are valid. -/// -/// Thread Safety: This is safe to call simultaneously from multiple threads. -/// Return value: `null` if the regex is valid, otherwise a string describing the error. -#[unsafe(no_mangle)] -pub unsafe extern "C" fn validate_regex( - regex: *const c_char, - error_out: *mut *const c_char, -) -> *const c_char { - handle_panic_ptr_return(Some(error_out), || { - let pattern = unsafe { CStr::from_ptr(regex).to_string_lossy().into_owned() }; - - match validate_regex_impl(&pattern) { - Ok(_) => 0i64, // Return null pointer as i64 - Err(err) => { - // Convert error to CString and return as pointer - let error_msg = format!("{err}"); - let c_string = CString::new(error_msg).unwrap_or_else(|_| { - CString::new("Invalid regex (error details unavailable)").unwrap() - }); - let ptr = c_string.into_raw(); - ptr as i64 - } - } - }) as *const c_char -} diff --git a/sds/src/lib.rs b/sds/src/lib.rs index 700aec0c..f2bdeb55 100644 --- a/sds/src/lib.rs +++ b/sds/src/lib.rs @@ -3,11 +3,11 @@ #![deny(clippy::print_stdout)] #![allow(clippy::new_without_default)] +mod ast_utils; mod encoding; mod event; mod match_action; -mod ast_utils; #[cfg(any(test, feature = "testing", feature = "bench"))] mod event_json; mod match_validation; @@ -42,6 +42,7 @@ pub use path::{Path, PathSegment}; pub use rule_match::{ReplacementType, RuleMatch}; pub use scanner::shared_pool::{SharedPool, SharedPoolGuard}; +pub use parser::explainer::{AstNode, explain_regex}; pub use scanner::suppression::Suppressions; pub use scanner::{ CompiledRule, MatchEmitter, 
RootCompiledRule, RootRuleConfig, RuleResult, RuleStatus, diff --git a/sds/src/parser/explainer.rs b/sds/src/parser/explainer.rs new file mode 100644 index 00000000..3cea46ba --- /dev/null +++ b/sds/src/parser/explainer.rs @@ -0,0 +1,1349 @@ +use crate::parser::ast::{ + AsciiClass, AsciiClassKind, AssertionType, Ast, BracketCharacterClass, + BracketCharacterClassItem, CharacterClass, Flag, Flags, Group, PerlCharacterClass, + QuantifierKind, UnicodePropertyClass, +}; +use crate::parser::regex_parser::parse_regex_pattern; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AstNode { + pub node_type: String, + pub description: String, + pub start: usize, + pub end: usize, + #[serde(skip_serializing_if = "Option::is_none")] + pub children: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option, +} + +/// Explains a regex pattern by parsing it and converting it to an AST with position tracking. +/// +/// # Arguments +/// * `pattern` - The regex pattern string to explain +/// +/// # Returns +/// * `Ok(AstNode)` - The root AST node with descriptions and positions if the pattern is valid +/// * `Err(String)` - An error message if the pattern is invalid +pub fn explain_regex(pattern: &str) -> Result { + parse_regex_pattern(pattern) + .map(|ast| { + let mut node = ast_to_node_with_tracking(&ast, pattern, 0); + // Fix the end position to match the actual pattern length + // This is necessary because ast_to_string doesn't preserve + // all escape sequences (e.g., \/ in bracket classes) + node.end = pattern.len(); + node + }) + .map_err(|err| format!("{:?}", err)) +} + +fn ast_to_string(ast: &Ast) -> String { + match ast { + Ast::Empty => String::new(), + Ast::Literal(lit) => { + if lit.escaped { + format!("\\{}", lit.c) + } else { + lit.c.to_string() + } + } + Ast::Concat(items) => items.iter().map(ast_to_string).collect(), + Ast::Alternation(alts) => 
alts.iter().map(ast_to_string).collect::>().join("|"), + Ast::Group(group) => match group.as_ref() { + Group::Capturing(g) => format!("({})", ast_to_string(&g.inner)), + Group::NonCapturing(g) => { + format!("(?{}:{})", format_flags(&g.flags), ast_to_string(&g.inner)) + } + Group::NamedCapturing(g) => format!("(?<{}>{})", g.name, ast_to_string(&g.inner)), + }, + Ast::Repetition(rep) => { + let quantifier = match &rep.quantifier.kind { + QuantifierKind::ZeroOrMore => "*".to_string(), + QuantifierKind::OneOrMore => "+".to_string(), + QuantifierKind::ZeroOrOne => "?".to_string(), + QuantifierKind::RangeExact(n) => format!("{{{}}}", n), + QuantifierKind::RangeMinMax(min, max) => format!("{{{},{}}}", min, max), + QuantifierKind::RangeMin(min) => format!("{{{},}}", min), + }; + let lazy = if rep.quantifier.lazy { "?" } else { "" }; + format!("{}{}{}", ast_to_string(&rep.inner), quantifier, lazy) + } + Ast::CharacterClass(class) => match class { + CharacterClass::Dot => ".".to_string(), + CharacterClass::Perl(perl) => match perl { + PerlCharacterClass::Digit => "\\d", + PerlCharacterClass::Space => "\\s", + PerlCharacterClass::Word => "\\w", + PerlCharacterClass::NonDigit => "\\D", + PerlCharacterClass::NonSpace => "\\S", + PerlCharacterClass::NonWord => "\\W", + } + .to_string(), + CharacterClass::Bracket(bracket) => format_bracket_character_class(bracket), + CharacterClass::HorizontalWhitespace => "\\h".to_string(), + CharacterClass::NotHorizontalWhitespace => "\\H".to_string(), + CharacterClass::VerticalWhitespace => "\\v".to_string(), + CharacterClass::NotVerticalWhitespace => "\\V".to_string(), + CharacterClass::UnicodeProperty(prop) => format_unicode_property(prop), + }, + Ast::Assertion(assertion) => match assertion { + AssertionType::WordBoundary => "\\b", + AssertionType::NotWordBoundary => "\\B", + AssertionType::StartLine => "^", + AssertionType::EndLine => "$", + AssertionType::StartText => "\\A", + AssertionType::EndText => "\\z", + 
AssertionType::EndTextOptionalNewline => "\\Z", + } + .to_string(), + Ast::Flags(flags) => format!("(?{})", format_flags(flags)), + } +} + +fn ast_to_node_with_tracking(ast: &Ast, pattern: &str, offset: usize) -> AstNode { + let node_str = ast_to_string(ast); + let node_len = node_str.len(); + let start = offset.min(pattern.len()); + let end = (offset + node_len).min(pattern.len()).max(start); + + // Note: The end position for groups is corrected in ast_to_node_with_range + + ast_to_node_with_range(ast, pattern, start, end, offset) +} + +fn find_matching_paren(pattern: &str, start: usize) -> Option { + let bytes = pattern.as_bytes(); + if start >= bytes.len() || bytes[start] != b'(' { + return None; + } + + let mut depth = 0; + let mut i = start; + + while i < bytes.len() { + let ch = bytes[i]; + + // Handle escape sequences + if ch == b'\\' && i + 1 < bytes.len() { + i += 2; // Skip the backslash and the next character + continue; + } + + // Check for brackets (character classes) + if ch == b'[' { + i += 1; + // Inside bracket class - find the closing ] + while i < bytes.len() { + if bytes[i] == b'\\' && i + 1 < bytes.len() { + i += 2; // Skip escaped character + continue; + } + if bytes[i] == b']' { + i += 1; + break; + } + i += 1; + } + continue; + } + + // Check for parentheses + match ch { + b'(' => depth += 1, + b')' => { + depth -= 1; + if depth == 0 { + return Some(i + 1); // +1 to include the ')' + } + } + _ => {} + } + + i += 1; + } + + None +} + +fn ast_to_node_with_range( + ast: &Ast, + pattern: &str, + start: usize, + end: usize, + offset: usize, +) -> AstNode { + // SAFETY FIRST: Validate and clamp all positions + let start = start.min(pattern.len()); + let end = end.min(pattern.len()).max(start); + + match ast { + Ast::Empty => AstNode { + node_type: "Empty".to_string(), + description: "matches an empty string".to_string(), + start, + end, + children: None, + properties: None, + }, + + Ast::Literal(lit) => { + let char_display = if lit.c.is_control() 
|| lit.c.is_whitespace() { + lit.c.escape_default().to_string() + } else { + lit.c.to_string() + }; + + let description = if lit.escaped { + format!( + "\\{} matches the character {} literally", + lit.c, char_display + ) + } else { + format!("matches the character {} literally", char_display) + }; + + AstNode { + node_type: "Literal".to_string(), + description, + start, + end, + children: None, + properties: Some(serde_json::json!({ + "character": lit.c.to_string(), + "escaped": lit.escaped, + })), + } + } + + Ast::Concat(items) => { + // For concatenation, track positions of each child + let mut children = Vec::new(); + let mut current_offset = offset; + + for item in items { + let child_str = ast_to_string(item); + let child_len = child_str.len(); + let child_start = current_offset.min(pattern.len()); + let child_end = (current_offset + child_len) + .min(pattern.len()) + .max(child_start); + + let child = + ast_to_node_with_range(item, pattern, child_start, child_end, current_offset); + + // Use the actual end position of the child (which may be corrected for groups) + // instead of the calculated length from ast_to_string + current_offset = child.end; + children.push(child); + } + + AstNode { + node_type: "Concatenation".to_string(), + description: format!("Concatenation of {} elements.", items.len()), + start, + end, + children: Some(children), + properties: None, + } + } + + Ast::Alternation(alts) => { + // For alternation, each alternative gets its own position + // Note: alternatives are laid out left to right, each separated from the next by one '|' + let mut children = Vec::new(); + let mut current_offset = offset; + + for (i, alt) in alts.iter().enumerate() { + let alt_str = ast_to_string(alt); + let alt_len = alt_str.len(); + let alt_start = current_offset.min(pattern.len()); + let alt_end = (current_offset + alt_len).min(pattern.len()).max(alt_start); + + let child = + ast_to_node_with_range(alt, pattern, alt_start, alt_end, current_offset); + + // Continue from the child's actual end (corrected for groups, as in the Concat arm), not the ast_to_string length + 
current_offset = child.end; + children.push(child); + if i < alts.len() - 1 { + current_offset += 1; // for the '|' + } + } + + AstNode { + node_type: "Alternation".to_string(), + description: "alternation - matches the expression before or after the |. Acts like a boolean OR".to_string(), + start, + end, + children: Some(children), + properties: None, + } + } + + Ast::Group(group) => { + let (node_type, description, inner, prefix_len) = match group.as_ref() { + Group::Capturing(g) => ( + "Capturing Group", + "capturing group - groups multiple tokens together and creates a capture group for extracting a substring or using a backreference".to_string(), + &g.inner, + 1, // "(" + ), + Group::NonCapturing(g) => { + let flags_str = format_flags(&g.flags); + ( + "Non-Capturing Group", + "non-capturing group - groups multiple tokens together without creating a capture group".to_string(), + &g.inner, + 3 + flags_str.len(), // "(?flags:" + ) + } + Group::NamedCapturing(g) => ( + "Named Capturing Group", + format!("named capturing group '{}' - groups multiple tokens together and creates a capture group that can be referenced by name", g.name), + &g.inner, + 4 + g.name.len(), // "(?" 
+ ), + }; + + // Inner content starts after the opening syntax + let inner_offset = offset + prefix_len; + let inner_str = ast_to_string(inner); + let inner_len = inner_str.len(); + let inner_start = inner_offset.min(pattern.len()); + let inner_end = (inner_offset + inner_len) + .min(pattern.len()) + .max(inner_start); + + let child = + ast_to_node_with_range(inner, pattern, inner_start, inner_end, inner_offset); + + // Fix the end position: ast_to_string doesn't preserve escape sequences, + // so we need to find the actual closing parenthesis + let actual_end = find_matching_paren(pattern, start).unwrap_or(end); + + AstNode { + node_type: node_type.to_string(), + description, + start, + end: actual_end, + children: Some(vec![child]), + properties: None, + } + } + + Ast::Repetition(rep) => { + // Inner element comes first, quantifier follows + let inner_str = ast_to_string(&rep.inner); + let inner_len = inner_str.len(); + let inner_start = offset.min(pattern.len()); + let inner_end = (offset + inner_len).min(pattern.len()).max(inner_start); + + let child = ast_to_node_with_range(&rep.inner, pattern, inner_start, inner_end, offset); + + let greedy_suffix = if rep.quantifier.lazy { + ", as few times as possible, expanding as needed (lazy)" + } else { + ", as many times as possible, giving back as needed (greedy)" + }; + + let (description, properties) = match &rep.quantifier.kind { + QuantifierKind::ZeroOrMore => ( + format!( + "matches the previous token between zero and unlimited times{}", + greedy_suffix + ), + serde_json::json!({"lazy": rep.quantifier.lazy}), + ), + QuantifierKind::OneOrMore => ( + format!( + "matches the previous token between one and unlimited times{}", + greedy_suffix + ), + serde_json::json!({"lazy": rep.quantifier.lazy}), + ), + QuantifierKind::ZeroOrOne => ( + format!( + "matches the previous token between zero and one times{}", + greedy_suffix + ), + serde_json::json!({"lazy": rep.quantifier.lazy}), + ), + QuantifierKind::RangeExact(n) => 
( + format!("matches the previous token exactly {} times", n), + serde_json::json!({"min": n, "max": n, "lazy": rep.quantifier.lazy}), + ), + QuantifierKind::RangeMinMax(min, max) => ( + format!( + "matches the previous token between {} and {} times{}", + min, max, greedy_suffix + ), + serde_json::json!({"min": min, "max": max, "lazy": rep.quantifier.lazy}), + ), + QuantifierKind::RangeMin(min) => ( + format!( + "matches the previous token between {} and unlimited times{}", + min, greedy_suffix + ), + serde_json::json!({"min": min, "max": null, "lazy": rep.quantifier.lazy}), + ), + }; + + let full_desc = description; + + AstNode { + node_type: "Repetition".to_string(), + description: full_desc, + start, + end, + children: Some(vec![child]), + properties: Some(properties), + } + } + + Ast::CharacterClass(class) => { + let (description, node_type) = match class { + CharacterClass::Dot => ( + "matches any character (except for line terminators)".to_string(), + "Dot".to_string(), + ), + CharacterClass::Perl(perl) => { + let (name, desc) = describe_perl_character_class(perl); + ( + format!("{} {}", name, desc), + "Perl Character Class".to_string(), + ) + } + CharacterClass::Bracket(bracket) => { + let desc = if bracket.negated { + "match a single character not present in the list" + } else { + "match a single character present in the list" + }; + (desc.to_string(), "Character Class".to_string()) + } + CharacterClass::HorizontalWhitespace => ( + "\\h matches any horizontal whitespace character (spaces and tabs)".to_string(), + "Horizontal Whitespace".to_string(), + ), + CharacterClass::NotHorizontalWhitespace => ( + "\\H matches any character that's not a horizontal whitespace character" + .to_string(), + "Not Horizontal Whitespace".to_string(), + ), + CharacterClass::VerticalWhitespace => ( + "\\v matches any vertical whitespace character (newlines)".to_string(), + "Vertical Whitespace".to_string(), + ), + CharacterClass::NotVerticalWhitespace => ( + "\\V matches any 
character that's not a vertical whitespace character" + .to_string(), + "Not Vertical Whitespace".to_string(), + ), + CharacterClass::UnicodeProperty(prop) => { + let (prefix, verb) = if prop.negate { + ("\\P", "not in") + } else { + ("\\p", "in") + }; + let desc = format!( + "{}{{{}}} matches any character {} the unicode category '{}'", + prefix, prop.name, verb, prop.name + ); + (desc, "Unicode Property".to_string()) + } + }; + + // For bracket character classes, add child nodes for each item + let children = match class { + CharacterClass::Bracket(bracket) => { + Some(create_bracket_class_children(bracket, pattern, offset + 1)) // +1 for '[' + } + _ => None, + }; + + AstNode { + node_type, + description, + start, + end, + children, + properties: None, + } + } + + Ast::Assertion(assertion) => { + let (symbol, desc) = match assertion { + AssertionType::WordBoundary => ( + "\\b", + "asserts position at a word boundary: (^\\w|\\w$|\\W\\w|\\w\\W)", + ), + AssertionType::NotWordBoundary => { + ("\\B", "asserts position at a non-word boundary") + } + AssertionType::StartLine => ("^", "asserts position at start of a line"), + AssertionType::EndLine => ("$", "asserts position at end of a line"), + AssertionType::StartText => ("\\A", "asserts position at start of the string"), + AssertionType::EndText => ("\\z", "asserts position at end of the string"), + AssertionType::EndTextOptionalNewline => ( + "\\Z", + "asserts position at the end of the string, or before the line terminator right at the end of the string (if any)", + ), + }; + + AstNode { + node_type: "Assertion".to_string(), + description: format!("{} {}", symbol, desc), + start, + end, + children: None, + properties: Some(serde_json::json!({ + "assertion_type": symbol, + })), + } + } + + Ast::Flags(flags) => { + let flags_str = format_flags(flags); + let flag_descriptions = describe_flags(flags); + AstNode { + node_type: "Flags".to_string(), + description: if flag_descriptions.is_empty() { + "match 
flags".to_string() + } else { + flag_descriptions + }, + start, + end, + children: None, + properties: Some(serde_json::json!({ + "flags": flags_str, + })), + } + } + } +} + +fn describe_perl_character_class(perl: &PerlCharacterClass) -> (&'static str, &'static str) { + match perl { + PerlCharacterClass::Digit => ("\\d", "matches a digit (equivalent to [0-9])"), + PerlCharacterClass::Space => ( + "\\s", + "matches any whitespace character (equivalent to [\\r\\n\\t\\f\\v ])", + ), + PerlCharacterClass::Word => ( + "\\w", + "matches any word character (equivalent to [a-zA-Z0-9_])", + ), + PerlCharacterClass::NonDigit => ( + "\\D", + "matches any character that's not a digit (equivalent to [^0-9])", + ), + PerlCharacterClass::NonSpace => ( + "\\S", + "matches any non-whitespace character (equivalent to [^\\r\\n\\t\\f\\v ])", + ), + PerlCharacterClass::NonWord => ( + "\\W", + "matches any non-word character (equivalent to [^a-zA-Z0-9_])", + ), + } +} + +fn describe_ascii_class_kind(kind: &AsciiClassKind) -> &'static str { + match kind { + AsciiClassKind::Alnum => "matches any alphanumeric character [a-zA-Z0-9]", + AsciiClassKind::Alpha => "matches any alphabetic character [a-zA-Z]", + AsciiClassKind::Ascii => "matches any ASCII character [\\x00-\\x7F]", + AsciiClassKind::Blank => "matches a space or tab [ \\t]", + AsciiClassKind::Cntrl => "matches any control character [\\x00-\\x1F\\x7F]", + AsciiClassKind::Digit => "matches any digit [0-9]", + AsciiClassKind::Graph => "matches any visible character (not whitespace) [!-~]", + AsciiClassKind::Lower => "matches any lowercase letter [a-z]", + AsciiClassKind::Print => "matches any printable character [ -~]", + AsciiClassKind::Punct => "matches any punctuation character", + AsciiClassKind::Space => "matches any whitespace character [ \\t\\r\\n\\v\\f]", + AsciiClassKind::Upper => "matches any uppercase letter [A-Z]", + AsciiClassKind::Word => "matches any word character [a-zA-Z0-9_]", + AsciiClassKind::Xdigit => "matches any 
hexadecimal digit [0-9A-Fa-f]", + } +} + +fn create_bracket_class_children( + bracket: &BracketCharacterClass, + pattern: &str, + mut offset: usize, +) -> Vec { + let mut children = Vec::new(); + + // Skip negation character if present + if bracket.negated { + offset += 1; // Skip '^' + } + + for (i, item) in bracket.items.iter().enumerate() { + let is_first = i == 0; + let is_last = i == bracket.items.len() - 1; + let item_str = format_bracket_item(item, is_first, is_last); + let item_len = item_str.len(); + + let start = offset.min(pattern.len()); + let end = (offset + item_len).min(pattern.len()).max(start); + + let (node_type, description) = match item { + BracketCharacterClassItem::Literal(c) => { + let char_display = if c.is_control() || c.is_whitespace() { + c.escape_default().to_string() + } else { + c.to_string() + }; + ( + "Literal".to_string(), + format!("matches the character {}", char_display), + ) + } + BracketCharacterClassItem::Range(start_char, end_char) => ( + "Character Range".to_string(), + format!( + "{}-{} matches a single character in the range between {} (index {}) and {} (index {})", + start_char, + end_char, + start_char, + *start_char as u32, + end_char, + *end_char as u32 + ), + ), + BracketCharacterClassItem::PerlCharacterClass(perl) => { + let (name, desc) = describe_perl_character_class(perl); + ( + "Perl Character Class".to_string(), + format!("{} {}", name, desc), + ) + } + BracketCharacterClassItem::UnicodeProperty(prop) => { + let (prefix, verb) = if prop.negate { + ("\\P", "not in") + } else { + ("\\p", "in") + }; + ( + "Unicode Property".to_string(), + format!( + "{}{{{}}} matches any character {} the unicode category '{}'", + prefix, prop.name, verb, prop.name + ), + ) + } + BracketCharacterClassItem::AsciiClass(ascii) => { + let kind_name = describe_ascii_class_kind(&ascii.kind); + let kind_str = format_ascii_class_name(&ascii.kind); + ( + "ASCII Class".to_string(), + if ascii.negated { + format!( + "[:^{}:] {}", + kind_str, + 
kind_name.replace("matches", "matches any character not matching") + ) + } else { + format!("[:{}:] {}", kind_str, kind_name) + }, + ) + } + BracketCharacterClassItem::HorizontalWhitespace => ( + "Horizontal Whitespace".to_string(), + "\\h matches any horizontal whitespace character (spaces and tabs)".to_string(), + ), + BracketCharacterClassItem::NotHorizontalWhitespace => ( + "Not Horizontal Whitespace".to_string(), + "\\H matches any character that's not a horizontal whitespace character" + .to_string(), + ), + BracketCharacterClassItem::VerticalWhitespace => ( + "Vertical Whitespace".to_string(), + "\\v matches any vertical whitespace character (newlines)".to_string(), + ), + BracketCharacterClassItem::NotVerticalWhitespace => ( + "Not Vertical Whitespace".to_string(), + "\\V matches any character that's not a vertical whitespace character".to_string(), + ), + }; + + children.push(AstNode { + node_type, + description, + start, + end, + children: None, + properties: None, + }); + + offset += item_len; + } + + children +} + +fn format_bracket_character_class(bracket: &BracketCharacterClass) -> String { + let mut result = String::from("["); + if bracket.negated { + result.push('^'); + } + for (i, item) in bracket.items.iter().enumerate() { + let is_first = i == 0; + let is_last = i == bracket.items.len() - 1; + result.push_str(&format_bracket_item(item, is_first, is_last)); + } + result.push(']'); + result +} + +fn format_bracket_item(item: &BracketCharacterClassItem, is_first: bool, is_last: bool) -> String { + match item { + BracketCharacterClassItem::Literal(c) => { + // Need to escape certain characters inside brackets + match c { + '\\' => "\\\\".to_string(), + ']' => "\\]".to_string(), + '^' if is_first => "\\^".to_string(), // Only escape ^ if it's first + '-' if !is_first && !is_last => "\\-".to_string(), // Don't escape - at start or end + _ => c.to_string(), + } + } + BracketCharacterClassItem::Range(start, end) => format!("{}-{}", start, end), + 
BracketCharacterClassItem::PerlCharacterClass(perl) => match perl { + PerlCharacterClass::Digit => "\\d", + PerlCharacterClass::Space => "\\s", + PerlCharacterClass::Word => "\\w", + PerlCharacterClass::NonDigit => "\\D", + PerlCharacterClass::NonSpace => "\\S", + PerlCharacterClass::NonWord => "\\W", + } + .to_string(), + BracketCharacterClassItem::UnicodeProperty(prop) => format_unicode_property(prop), + BracketCharacterClassItem::AsciiClass(ascii) => format_ascii_class(ascii), + BracketCharacterClassItem::HorizontalWhitespace => "\\h".to_string(), + BracketCharacterClassItem::NotHorizontalWhitespace => "\\H".to_string(), + BracketCharacterClassItem::VerticalWhitespace => "\\v".to_string(), + BracketCharacterClassItem::NotVerticalWhitespace => "\\V".to_string(), + } +} + +fn format_unicode_property(prop: &UnicodePropertyClass) -> String { + if prop.negate { + format!("\\P{{{}}}", prop.name) + } else { + format!("\\p{{{}}}", prop.name) + } +} + +fn format_ascii_class_name(kind: &AsciiClassKind) -> &'static str { + match kind { + AsciiClassKind::Alnum => "alnum", + AsciiClassKind::Alpha => "alpha", + AsciiClassKind::Ascii => "ascii", + AsciiClassKind::Blank => "blank", + AsciiClassKind::Cntrl => "cntrl", + AsciiClassKind::Digit => "digit", + AsciiClassKind::Graph => "graph", + AsciiClassKind::Lower => "lower", + AsciiClassKind::Print => "print", + AsciiClassKind::Punct => "punct", + AsciiClassKind::Space => "space", + AsciiClassKind::Upper => "upper", + AsciiClassKind::Word => "word", + AsciiClassKind::Xdigit => "xdigit", + } +} + +fn format_ascii_class(ascii: &AsciiClass) -> String { + let kind_str = format_ascii_class_name(&ascii.kind); + if ascii.negated { + format!("[:^{}:]", kind_str) + } else { + format!("[:{}:]", kind_str) + } +} + +fn format_flags(flags: &Flags) -> String { + let mut result = String::new(); + for flag in &flags.add { + result.push(match flag { + Flag::CaseInsensitive => 'i', + Flag::MultiLine => 'm', + Flag::DotMatchesNewLine => 's', + 
Flag::IgnoreWhitespace => 'x', + }); + } + if !flags.remove.is_empty() { + result.push('-'); + for flag in &flags.remove { + result.push(match flag { + Flag::CaseInsensitive => 'i', + Flag::MultiLine => 'm', + Flag::DotMatchesNewLine => 's', + Flag::IgnoreWhitespace => 'x', + }); + } + } + result +} + +fn describe_flags(flags: &Flags) -> String { + let mut descriptions = Vec::new(); + for flag in &flags.add { + descriptions.push(match flag { + Flag::CaseInsensitive => { + "i modifier: case insensitive. Letters match both upper and lower case" + } + Flag::MultiLine => "m modifier: multi line. ^ and $ match start/end of line", + Flag::DotMatchesNewLine => "s modifier: single line. Dot matches newline characters", + Flag::IgnoreWhitespace => "x modifier: extended. Spaces and text after # are ignored", + }); + } + if !flags.remove.is_empty() { + for flag in &flags.remove { + descriptions.push(match flag { + Flag::CaseInsensitive => "disable i modifier (case insensitive)", + Flag::MultiLine => "disable m modifier (multi line)", + Flag::DotMatchesNewLine => "disable s modifier (single line)", + Flag::IgnoreWhitespace => "disable x modifier (extended)", + }); + } + } + descriptions.join(", ") +} + +#[cfg(test)] +mod tests { + use super::*; + + // ==================== Basic Functionality Tests ==================== + + #[test] + fn test_valid_simple_pattern() { + let result = explain_regex("test"); + assert!(result.is_ok()); + let tree = result.unwrap(); + assert_eq!(tree.node_type, "Concatenation"); + } + + #[test] + fn test_invalid_pattern() { + let result = explain_regex("["); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(!err.is_empty()); + } + + #[test] + fn test_json_serialization() { + let result = explain_regex("foo+"); + assert!(result.is_ok()); + let tree = result.unwrap(); + let json = serde_json::to_string(&tree).unwrap(); + assert!(json.contains("node_type")); + assert!(json.contains("description")); + } + + // ==================== 
Position Accuracy Tests ==================== + + #[test] + fn test_literal_positions() { + let pattern = "abc"; + let tree = explain_regex(pattern).unwrap(); + + let children = tree.children.as_ref().unwrap(); + assert_eq!(children.len(), 3); + + // Each literal should have sequential positions + assert_eq!(children[0].start, 0); + assert_eq!(children[0].end, 1); // 'a' + + assert_eq!(children[1].start, 1); + assert_eq!(children[1].end, 2); // 'b' + + assert_eq!(children[2].start, 2); + assert_eq!(children[2].end, 3); // 'c' + } + + #[test] + fn test_alternation_positions() { + let pattern = "foo|bar"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Alternation"); + let children = tree.children.as_ref().unwrap(); + assert_eq!(children.len(), 2); + + // 'foo' at 0-3 + assert_eq!(children[0].start, 0); + assert_eq!(children[0].end, 3); + + // 'bar' at 4-7 (after '|') + assert_eq!(children[1].start, 4); + assert_eq!(children[1].end, 7); + } + + #[test] + fn test_group_positions() { + let pattern = "(test)"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Capturing Group"); + assert_eq!(tree.start, 0); + assert_eq!(tree.end, 6); // "(test)" + + // Inner content at 1-5 (between parens) + let inner = &tree.children.as_ref().unwrap()[0]; + assert_eq!(inner.start, 1); + assert_eq!(inner.end, 5); + } + + #[test] + fn test_repetition_positions() { + let pattern = r"\w{3}"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Repetition"); + assert_eq!(tree.start, 0); + assert_eq!(tree.end, 5); // "\w{3}" + + // Inner \w at 0-2 + let inner = &tree.children.as_ref().unwrap()[0]; + assert_eq!(inner.start, 0); + assert_eq!(inner.end, 2); + } + + #[test] + fn test_assertions_positions() { + let pattern = r"^test$"; + let tree = explain_regex(pattern).unwrap(); + + let children = tree.children.as_ref().unwrap(); + assert_eq!(children.len(), 6); // ^, t, e, s, t, $ + + // ^ at 0-1 + 
assert_eq!(children[0].node_type, "Assertion"); + assert_eq!(children[0].start, 0); + assert_eq!(children[0].end, 1); + + // $ at 5-6 + assert_eq!(children[5].node_type, "Assertion"); + assert_eq!(children[5].start, 5); + assert_eq!(children[5].end, 6); + } + + #[test] + fn test_nested_groups_positions() { + let pattern = "((x))"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.start, 0); + assert_eq!(tree.end, 5); + + let inner1 = &tree.children.as_ref().unwrap()[0]; + assert_eq!(inner1.start, 1); + assert_eq!(inner1.end, 4); + + let inner2 = &inner1.children.as_ref().unwrap()[0]; + assert_eq!(inner2.start, 2); + assert_eq!(inner2.end, 3); + } + + // ==================== Bracket Character Class Tests ==================== + + #[test] + fn test_bracket_class_reconstruction() { + let patterns = vec![ + ("[x-z]", 5), + ("[0-5]", 5), + ("[-xyz._]", 8), + ("[^abc]", 6), + (r"[\w\d]", 6), + ]; + + for (pattern, expected_len) in patterns { + let result = explain_regex(pattern); + assert!(result.is_ok(), "Pattern '{}' should be valid", pattern); + + let tree = result.unwrap(); + assert_eq!( + tree.end, expected_len, + "Pattern '{}' length mismatch", + pattern + ); + } + } + + #[test] + fn test_bracket_class_children() { + let pattern = "[-a-z0-9]"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Character Class"); + + // Should have children for each item + let children = tree.children.as_ref().unwrap(); + assert!(children.len() >= 2); // At least '-' and ranges + + // Verify each child has valid positions + for child in children { + assert!(child.start < child.end); + assert!(child.end <= pattern.len()); + } + } + + #[test] + fn test_bracket_class_with_ranges() { + let pattern = "[a-z0-9]"; + let tree = explain_regex(pattern).unwrap(); + let children = tree.children.as_ref().unwrap(); + + // Should have children for each range + for child in children { + validate_node_positions(child, pattern); + } + } + + // 
==================== Complex Patterns ==================== + + #[test] + fn test_phone_number_like_pattern() { + // Using explicit pattern, not actual phone number regex + let pattern = r"\d{3}-\d{3}"; + let tree = explain_regex(pattern).unwrap(); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_alternation_with_groups() { + let pattern = r"(x|y)|z"; + let tree = explain_regex(pattern).unwrap(); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_nested_repetitions() { + let pattern = r"(x+)+"; + let tree = explain_regex(pattern).unwrap(); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_multiple_assertions() { + let pattern = r"^\btest\b$"; + let tree = explain_regex(pattern).unwrap(); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_escaped_slash_in_bracket_class() { + // Test pattern with escaped forward slash in bracket class + let pattern = r"[\w.+\/=-]+"; + let result = explain_regex(pattern); + + assert!(result.is_ok(), "Pattern should be valid"); + let tree = result.unwrap(); + + // The full pattern should be captured + assert_eq!(tree.start, 0); + assert_eq!(tree.end, pattern.len()); + + // Walk through and validate all positions + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_complex_pattern_with_escape_sequences() { + // Test pattern with named capture groups and character classes + // This demonstrates: word boundaries, named groups, and character class ranges + let pattern = "\\b(?example[A-Z0-9]{2})\\b"; + let result = explain_regex(pattern); + + assert!(result.is_ok(), "Pattern should be valid: {:?}", result); + let tree = result.unwrap(); + + // Root should span entire pattern + assert_eq!(tree.start, 0, "Root should start at 0"); + assert_eq!(tree.end, pattern.len(), "Root should end at pattern length"); + + // Validate all positions recursively - this will check that groups end with ')' + validate_node_positions(&tree, pattern); + } + + // 
==================== Quantifier Tests ==================== + + #[test] + fn test_greedy_quantifiers() { + let patterns = vec!["x*", "x+", "x?", "x{2,5}", "x{2,}"]; + + for pattern in patterns { + let tree = explain_regex(pattern).unwrap(); + assert_eq!(tree.node_type, "Repetition"); + + // Greedy quantifiers have lazy: false in properties + if let Some(properties) = &tree.properties { + if let Some(is_lazy) = properties.get("lazy") { + assert_eq!( + is_lazy, + &serde_json::json!(false), + "Pattern '{}' should be greedy", + pattern + ); + } + } + } + } + + #[test] + fn test_lazy_quantifiers() { + let patterns = vec!["x*?", "x+?", "x??", "x{2,5}?", "x{2,}?"]; + + for pattern in patterns { + let tree = explain_regex(pattern).unwrap(); + assert!(tree.description.contains("lazy")); + } + } + + // ==================== Flags Tests ==================== + + #[test] + fn test_case_insensitive_flag() { + let pattern = "(?i)test"; + let tree = explain_regex(pattern).unwrap(); + let children = tree.children.as_ref().unwrap(); + + assert_eq!(children[0].node_type, "Flags"); + assert!(children[0].description.contains("case insensitive")); + } + + #[test] + fn test_multiple_flags() { + let pattern = "(?ims)test"; + let tree = explain_regex(pattern).unwrap(); + validate_node_positions(&tree, pattern); + } + + // ==================== Safety and Edge Cases ==================== + + #[test] + fn test_positions_never_exceed_bounds() { + let patterns = vec![ + "x", + "test", + r"\d+", + "foo|bar", + "(test)", + r"x{1,5}", + "[a-z]+", + r"\btest\b", + "(?:foo|bar)+", + r"x+?y*?z{2,5}?", + r"((x|(y|z))+\.){2}", + ]; + + for pattern in patterns { + if let Ok(tree) = explain_regex(pattern) { + validate_node_positions(&tree, pattern); + } + } + } + + #[test] + fn test_highlighting_safety() { + // Patterns that could potentially cause highlighting issues + let patterns = vec![ + "test", + r"\w+", + "x|y", + "(test)", + r"x{3}-y{4}", + "[a-z]+", + r"x+y*z?", + r"\btest\b", + "(?:foo)+", + 
r"(x|(y|z))", + "((((w))))", + ]; + + for pattern in patterns { + if let Ok(tree) = explain_regex(pattern) { + validate_node_positions(&tree, pattern); + } + } + } + + #[test] + fn test_no_overlapping_children() { + let pattern = "xyz"; + let tree = explain_regex(pattern).unwrap(); + let children = tree.children.unwrap(); + + // Positions should be sequential + for i in 0..children.len() - 1 { + assert!(children[i].end <= children[i + 1].start); + } + } + + #[test] + fn test_deeply_nested_groups() { + let pattern = "((((x))))"; + let tree = explain_regex(pattern).unwrap(); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_empty_alternation() { + let pattern = "x||y"; + let tree = explain_regex(pattern).unwrap(); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_unicode_property() { + let pattern = r"\p{L}+"; + let tree = explain_regex(pattern).unwrap(); + validate_node_positions(&tree, pattern); + } + + // ==================== Multi-byte Character Tests ==================== + + #[test] + fn test_emoji_in_literal() { + // Emoji are multi-byte UTF-8 characters + let pattern = "πŸ˜€"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Literal"); + assert_eq!(tree.start, 0); + assert_eq!(tree.end, pattern.len()); // Should be 4 bytes + + // Verify we can safely slice + let slice = &pattern[tree.start..tree.end]; + assert_eq!(slice, "πŸ˜€"); + } + + #[test] + fn test_emoji_in_concatenation() { + let pattern = "aπŸ˜€b"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Concatenation"); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_emoji_in_alternation() { + let pattern = "πŸ˜€|😎"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Alternation"); + let children = tree.children.as_ref().unwrap(); + assert_eq!(children.len(), 2); + + // Verify positions for each emoji + validate_node_positions(&tree, pattern); + } + + #[test] + fn 
test_emoji_in_group() { + let pattern = "(πŸ˜€+)"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Capturing Group"); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_mixed_ascii_and_emoji() { + let pattern = "testπŸ˜€abc"; + let tree = explain_regex(pattern).unwrap(); + + validate_node_positions(&tree, pattern); + + // Ensure no child spans beyond pattern length + fn check_bounds(node: &AstNode, pattern_len: usize) { + assert!( + node.end <= pattern_len, + "Node end {} exceeds pattern length {}", + node.end, + pattern_len + ); + if let Some(children) = &node.children { + for child in children { + check_bounds(child, pattern_len); + } + } + } + check_bounds(&tree, pattern.len()); + } + + #[test] + fn test_unicode_combining_characters() { + // Combining characters (e.g., Γ© = e + combining acute) + let pattern = "cafΓ©"; + let tree = explain_regex(pattern).unwrap(); + + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_chinese_characters() { + let pattern = "δ½ ε₯½"; + let tree = explain_regex(pattern).unwrap(); + + assert_eq!(tree.node_type, "Concatenation"); + validate_node_positions(&tree, pattern); + } + + #[test] + fn test_emoji_with_regex_operators() { + let pattern = "πŸ˜€+|😎*"; + let tree = explain_regex(pattern).unwrap(); + + validate_node_positions(&tree, pattern); + } + + // ==================== Helper Functions ==================== + + fn validate_node_positions(node: &AstNode, pattern: &str) { + assert!( + node.start <= node.end, + "Start ({}) should be <= end ({}) for {}", + node.start, + node.end, + node.node_type + ); + assert!( + node.end <= pattern.len(), + "End ({}) should be <= pattern length ({}) for {}", + node.end, + pattern.len(), + node.node_type + ); + + // Validate that group nodes end with ')' + if node.start < pattern.len() && node.end <= pattern.len() && node.start < node.end { + let content = &pattern[node.start..node.end]; + if node.node_type.contains("Group") { + 
assert!( + content.ends_with(')'), + "{} [{}..{}] should end with ')' but content is: {:?}", + node.node_type, + node.start, + node.end, + content + ); + } + } + + if let Some(children) = &node.children { + for child in children { + validate_node_positions(child, pattern); + } + } + } +} diff --git a/sds/src/parser/mod.rs b/sds/src/parser/mod.rs index b75cbf55..7a473d2a 100644 --- a/sds/src/parser/mod.rs +++ b/sds/src/parser/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod ast; pub(crate) mod error; +pub mod explainer; pub(crate) mod regex_parser; pub(crate) mod unicode_property_names;