Skip to content

Commit c8e61a0

Browse files
committed
Fix pattern matching to enforce consistent variable bindings (#65)
Pattern variables (e.g. x_) appearing multiple times in a pattern must bind to the same value. Previously, bindings were merged with extend() which silently allowed conflicting values, causing expressions like Sqrt[x] to incorrectly match x_ when it was already bound to x.
1 parent 986c16a commit c8e61a0

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

src/evaluator/pattern_matching.rs

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,30 @@
11
#[allow(unused_imports)]
22
use super::*;
33

4+
/// Merge new bindings into existing bindings, checking for consistency.
5+
/// If a variable name already has a binding, the new value must be
6+
/// structurally equal. Returns false if there is a conflict.
7+
fn merge_bindings(
8+
existing: &mut Vec<(String, Expr)>,
9+
new: Vec<(String, Expr)>,
10+
) -> bool {
11+
for (name, value) in new {
12+
if name.is_empty() {
13+
continue;
14+
}
15+
if let Some((_, existing_value)) = existing.iter().find(|(n, _)| *n == name)
16+
{
17+
if !expr_equal(existing_value, &value) {
18+
return false;
19+
}
20+
// Already bound to the same value, skip duplicate
21+
} else {
22+
existing.push((name, value));
23+
}
24+
}
25+
true
26+
}
27+
428
/// Perform nested access on an association: assoc["a", "b"] -> assoc["a"]["b"]
529
pub fn association_nested_access(
630
var_name: &str,
@@ -332,8 +356,9 @@ pub fn try_one_identity_match(
332356
// try matching the expression against it
333357
if required_indices.len() == 1 {
334358
let req_pat = &pat_args[required_indices[0]];
335-
if let Some(mut req_bindings) = match_pattern(expr, req_pat) {
336-
req_bindings.extend(bindings);
359+
if let Some(mut req_bindings) = match_pattern(expr, req_pat)
360+
&& merge_bindings(&mut req_bindings, bindings)
361+
{
337362
return Some(req_bindings);
338363
}
339364
}
@@ -1412,8 +1437,8 @@ fn match_args_with_sequences(
14121437
if let Some(mut bindings) = match_pattern(&expr_args[0], pat)
14131438
&& let Some(rest_bindings) =
14141439
match_args_with_sequences(&expr_args[1..], rest_pats)
1440+
&& merge_bindings(&mut bindings, rest_bindings)
14151441
{
1416-
bindings.extend(rest_bindings);
14171442
return Some(bindings);
14181443
}
14191444
None
@@ -1592,7 +1617,9 @@ pub fn match_pattern(
15921617
let mut bindings = Vec::new();
15931618
for (p, e) in pat_items.iter().zip(expr_items.iter()) {
15941619
if let Some(b) = match_pattern(e, p) {
1595-
bindings.extend(b);
1620+
if !merge_bindings(&mut bindings, b) {
1621+
return None;
1622+
}
15961623
} else {
15971624
return None;
15981625
}
@@ -1639,7 +1666,9 @@ pub fn match_pattern(
16391666
let mut bindings = Vec::new();
16401667
for (p, e) in pat_args.iter().zip(expr_args.iter()) {
16411668
if let Some(b) = match_pattern(e, p) {
1642-
bindings.extend(b);
1669+
if !merge_bindings(&mut bindings, b) {
1670+
return None;
1671+
}
16431672
} else {
16441673
return None;
16451674
}
@@ -1679,12 +1708,16 @@ pub fn match_pattern(
16791708
}
16801709
let mut bindings = Vec::new();
16811710
if let Some(b) = match_pattern(expr_left, pat_left) {
1682-
bindings.extend(b);
1711+
if !merge_bindings(&mut bindings, b) {
1712+
return None;
1713+
}
16831714
} else {
16841715
return None;
16851716
}
16861717
if let Some(b) = match_pattern(expr_right, pat_right) {
1687-
bindings.extend(b);
1718+
if !merge_bindings(&mut bindings, b) {
1719+
return None;
1720+
}
16881721
} else {
16891722
return None;
16901723
}

tests/interpreter_tests/syntax.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,39 @@ mod pattern_function {
811811
"f[a, b]^2"
812812
);
813813
}
814+
815+
#[test]
816+
fn pattern_variable_binding_consistency() {
817+
// Same named pattern variable must bind to the same value
818+
// f[x_, x_] should match f[a, a] but not f[a, b]
819+
assert_eq!(interpret("f[a, a] /. f[x_, x_] -> yes").unwrap(), "yes");
820+
assert_eq!(interpret("f[a, b] /. f[x_, x_] -> yes").unwrap(), "f[a, b]");
821+
}
822+
823+
#[test]
824+
fn pattern_variable_no_match_sqrt_vs_symbol() {
825+
// Regression test for issue #65:
826+
// x_ bound to Symbol x should not match Sqrt[x]
827+
assert_eq!(
828+
interpret(
829+
"Int[1/(x_*(a_+b_.*x_)),x_Symbol] := \
830+
-Log[(a+b*x)/x]/a /; FreeQ[{a,b},x]; \
831+
Int[1/(Sqrt[x]*(a + b*x)), x]"
832+
)
833+
.unwrap(),
834+
"Int[1/(Sqrt[x]*(a + b*x)), x]"
835+
);
836+
// But it should still match when x_ consistently binds to x
837+
assert_eq!(
838+
interpret(
839+
"Int[1/(x_*(a_+b_.*x_)),x_Symbol] := \
840+
-Log[(a+b*x)/x]/a /; FreeQ[{a,b},x]; \
841+
Int[1/(x*(a + b*x)), x]"
842+
)
843+
.unwrap(),
844+
"-(Log[(a + b*x)/x]/a)"
845+
);
846+
}
814847
}
815848

816849
mod none_symbol {

0 commit comments

Comments
 (0)