Skip to content

Commit c7f2b92

Browse files
committed
Fix Part access on parenthesized expressions, E^(x/c) and polynomial power integration
- Add ParenExtended grammar rule so (expr)[[index]] parses correctly - Extend try_match_linear_arg to handle var/const (e.g. x/2) with proper Rationals - Enable closed-form poly × E^(cx) integration for fractional coefficients - Expand polynomial powers (e.g. (x+1)^2) before integrating term-by-term - Add 12 unit tests covering all three fixes
1 parent 8af1bf5 commit c7f2b92

File tree

5 files changed

+345
-62
lines changed

5 files changed

+345
-62
lines changed

src/functions/calculus_ast.rs

Lines changed: 222 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,40 @@ fn differentiate(expr: &Expr, var: &str) -> Result<Expr, InterpreterError> {
904904
fn make_divided(expr: Expr, divisor: Expr) -> Expr {
905905
match &divisor {
906906
Expr::Integer(1) => expr,
907+
// expr / (a/b) → expr * b/a = (b * expr) / a
908+
Expr::BinaryOp {
909+
op: crate::syntax::BinaryOperator::Divide,
910+
left: num,
911+
right: den,
912+
} => {
913+
let result = Expr::BinaryOp {
914+
op: crate::syntax::BinaryOperator::Divide,
915+
left: Box::new(Expr::BinaryOp {
916+
op: crate::syntax::BinaryOperator::Times,
917+
left: den.clone(),
918+
right: Box::new(expr),
919+
}),
920+
right: num.clone(),
921+
};
922+
simplify(result)
923+
}
924+
// expr / Rational[a, b] → (b * expr) / a
925+
Expr::FunctionCall { name, args }
926+
if name == "Rational" && args.len() == 2 =>
927+
{
928+
let num = &args[0]; // a
929+
let den = &args[1]; // b
930+
let result = Expr::BinaryOp {
931+
op: crate::syntax::BinaryOperator::Divide,
932+
left: Box::new(Expr::BinaryOp {
933+
op: crate::syntax::BinaryOperator::Times,
934+
left: Box::new(den.clone()),
935+
right: Box::new(expr),
936+
}),
937+
right: Box::new(num.clone()),
938+
};
939+
simplify(result)
940+
}
907941
_ => Expr::BinaryOp {
908942
op: crate::syntax::BinaryOperator::Divide,
909943
left: Box::new(expr),
@@ -1047,6 +1081,50 @@ fn try_match_linear_arg(expr: &Expr, var: &str) -> Option<Expr> {
10471081
None
10481082
}
10491083
}
1084+
// x/c form: var/const → coefficient is 1/const (i.e., Rational[1,c] for integer c)
1085+
Expr::BinaryOp {
1086+
op: crate::syntax::BinaryOperator::Divide,
1087+
left,
1088+
right,
1089+
} => {
1090+
if matches!(left.as_ref(), Expr::Identifier(name) if name == var)
1091+
&& is_constant_wrt(right, var)
1092+
{
1093+
// coefficient = 1/right — use division_ast for proper Rational creation
1094+
if let Ok(result) = crate::functions::math_ast::divide_ast(&[
1095+
Expr::Integer(1),
1096+
*right.clone(),
1097+
]) {
1098+
Some(result)
1099+
} else {
1100+
Some(Expr::BinaryOp {
1101+
op: crate::syntax::BinaryOperator::Divide,
1102+
left: Box::new(Expr::Integer(1)),
1103+
right: right.clone(),
1104+
})
1105+
}
1106+
} else if is_constant_wrt(right, var) {
1107+
// (expr)/const where expr might be a*x → coefficient is a/const
1108+
if let Some(inner_coeff) = try_match_linear_arg(left, var) {
1109+
if let Ok(result) = crate::functions::math_ast::divide_ast(&[
1110+
inner_coeff.clone(),
1111+
*right.clone(),
1112+
]) {
1113+
Some(result)
1114+
} else {
1115+
Some(Expr::BinaryOp {
1116+
op: crate::syntax::BinaryOperator::Divide,
1117+
left: Box::new(inner_coeff),
1118+
right: right.clone(),
1119+
})
1120+
}
1121+
} else {
1122+
None
1123+
}
1124+
} else {
1125+
None
1126+
}
1127+
}
10501128
// FunctionCall("Times", [coeff, var]) form
10511129
Expr::FunctionCall { name, args } if name == "Times" => {
10521130
// Find the variable factor and collect the rest as coefficient
@@ -2095,15 +2173,21 @@ fn try_integrate_poly_times_const_exp(
20952173
) -> Option<Expr> {
20962174
use crate::syntax::BinaryOperator::*;
20972175

2098-
let log_base = Expr::FunctionCall {
2099-
name: "Log".to_string(),
2100-
args: vec![base.clone()],
2176+
// For base E, Log[E] = 1, so rate = coeff directly.
2177+
// For other bases, rate = coeff * Log[base].
2178+
let rate = if matches!(base, Expr::Constant(c) if c == "E") {
2179+
simplify(coeff.clone())
2180+
} else {
2181+
let log_base = Expr::FunctionCall {
2182+
name: "Log".to_string(),
2183+
args: vec![base.clone()],
2184+
};
2185+
simplify(Expr::BinaryOp {
2186+
op: Times,
2187+
left: Box::new(coeff.clone()),
2188+
right: Box::new(log_base),
2189+
})
21012190
};
2102-
let rate = simplify(Expr::BinaryOp {
2103-
op: Times,
2104-
left: Box::new(coeff.clone()),
2105-
right: Box::new(log_base),
2106-
});
21072191

21082192
// Collect derivatives of poly until we reach 0
21092193
let mut derivs = vec![poly.clone()];
@@ -2122,74 +2206,139 @@ fn try_integrate_poly_times_const_exp(
21222206
}
21232207
}
21242208

2209+
// For numeric rates (e.g., base E with fractional coeff), compute each term
2210+
// directly with 1/rate^(k+1) to get clean integer coefficients.
2211+
// For symbolic rates (non-E bases involving Log), use the common-denominator
2212+
// form: (exponential * Σ P^(k)(x)*rate^(n-1-k)) / rate^n.
2213+
let is_numeric_rate = matches!(&rate, Expr::Integer(_))
2214+
|| matches!(&rate, Expr::FunctionCall { name, .. } if name == "Rational");
2215+
21252216
let n = derivs.len();
21262217

2127-
// Build numerator: Σ_{k=0}^{n-1} (-1)^k * P^(k)(x) * rate^(n-1-k)
2128-
let mut num_terms = Vec::new();
2129-
for (k, deriv) in derivs.iter().enumerate() {
2130-
let rate_power = n as i128 - 1 - k as i128;
2131-
let rate_factor = if rate_power == 0 {
2132-
Expr::Integer(1)
2133-
} else if rate_power == 1 {
2134-
rate.clone()
2135-
} else {
2136-
Expr::BinaryOp {
2137-
op: Power,
2138-
left: Box::new(rate.clone()),
2139-
right: Box::new(Expr::Integer(rate_power)),
2218+
if is_numeric_rate {
2219+
// Direct approach: Σ (-1)^k * P^(k)(x) / rate^(k+1)
2220+
let inv_rate =
2221+
crate::functions::math_ast::divide_ast(&[Expr::Integer(1), rate.clone()])
2222+
.unwrap_or_else(|_| Expr::BinaryOp {
2223+
op: Divide,
2224+
left: Box::new(Expr::Integer(1)),
2225+
right: Box::new(rate.clone()),
2226+
});
2227+
2228+
let mut num_terms = Vec::new();
2229+
for (k, deriv) in derivs.iter().enumerate() {
2230+
let k1 = k as i128 + 1;
2231+
let inv_rate_factor = if k1 == 1 {
2232+
inv_rate.clone()
2233+
} else {
2234+
crate::functions::math_ast::power_two(&inv_rate, &Expr::Integer(k1))
2235+
.unwrap_or_else(|_| Expr::BinaryOp {
2236+
op: Power,
2237+
left: Box::new(inv_rate.clone()),
2238+
right: Box::new(Expr::Integer(k1)),
2239+
})
2240+
};
2241+
2242+
let mut term = simplify(Expr::BinaryOp {
2243+
op: Times,
2244+
left: Box::new(deriv.clone()),
2245+
right: Box::new(inv_rate_factor),
2246+
});
2247+
2248+
if k % 2 == 1 {
2249+
term = simplify(Expr::BinaryOp {
2250+
op: Times,
2251+
left: Box::new(Expr::Integer(-1)),
2252+
right: Box::new(term),
2253+
});
21402254
}
2255+
2256+
num_terms.push(term);
2257+
}
2258+
2259+
let numerator = if num_terms.len() == 1 {
2260+
num_terms.into_iter().next().unwrap()
2261+
} else {
2262+
let combined = Expr::FunctionCall {
2263+
name: "Plus".to_string(),
2264+
args: num_terms,
2265+
};
2266+
crate::functions::polynomial_ast::expand_and_combine(&combined)
21412267
};
21422268

2143-
let mut term = simplify(Expr::BinaryOp {
2269+
let result = simplify(Expr::BinaryOp {
21442270
op: Times,
2145-
left: Box::new(deriv.clone()),
2146-
right: Box::new(rate_factor),
2271+
left: Box::new(exponential.clone()),
2272+
right: Box::new(numerator),
21472273
});
21482274

2149-
if k % 2 == 1 {
2150-
term = simplify(Expr::BinaryOp {
2275+
Some(result)
2276+
} else {
2277+
// Common-denominator approach: (exponential * Σ P^(k)(x)*rate^(n-1-k)) / rate^n
2278+
let mut num_terms = Vec::new();
2279+
for (k, deriv) in derivs.iter().enumerate() {
2280+
let rate_power = n as i128 - 1 - k as i128;
2281+
let rate_factor = if rate_power == 0 {
2282+
Expr::Integer(1)
2283+
} else if rate_power == 1 {
2284+
rate.clone()
2285+
} else {
2286+
Expr::BinaryOp {
2287+
op: Power,
2288+
left: Box::new(rate.clone()),
2289+
right: Box::new(Expr::Integer(rate_power)),
2290+
}
2291+
};
2292+
2293+
let mut term = simplify(Expr::BinaryOp {
21512294
op: Times,
2152-
left: Box::new(Expr::Integer(-1)),
2153-
right: Box::new(term),
2295+
left: Box::new(deriv.clone()),
2296+
right: Box::new(rate_factor),
21542297
});
2298+
2299+
if k % 2 == 1 {
2300+
term = simplify(Expr::BinaryOp {
2301+
op: Times,
2302+
left: Box::new(Expr::Integer(-1)),
2303+
right: Box::new(term),
2304+
});
2305+
}
2306+
2307+
num_terms.push(term);
21552308
}
21562309

2157-
num_terms.push(term);
2158-
}
2310+
let numerator = if num_terms.len() == 1 {
2311+
num_terms.into_iter().next().unwrap()
2312+
} else {
2313+
let combined = Expr::FunctionCall {
2314+
name: "Plus".to_string(),
2315+
args: num_terms,
2316+
};
2317+
crate::functions::polynomial_ast::expand_and_combine(&combined)
2318+
};
21592319

2160-
let numerator = if num_terms.len() == 1 {
2161-
num_terms.into_iter().next().unwrap()
2162-
} else {
2163-
let combined = Expr::FunctionCall {
2164-
name: "Plus".to_string(),
2165-
args: num_terms,
2320+
let denom = if n == 1 {
2321+
rate
2322+
} else {
2323+
Expr::BinaryOp {
2324+
op: Power,
2325+
left: Box::new(rate),
2326+
right: Box::new(Expr::Integer(n as i128)),
2327+
}
21662328
};
2167-
crate::functions::polynomial_ast::expand_and_combine(&combined)
2168-
};
21692329

2170-
// Denominator: rate^n
2171-
let denom = if n == 1 {
2172-
rate
2173-
} else {
2174-
Expr::BinaryOp {
2175-
op: Power,
2176-
left: Box::new(rate),
2177-
right: Box::new(Expr::Integer(n as i128)),
2178-
}
2179-
};
2330+
let result_num = simplify(Expr::BinaryOp {
2331+
op: Times,
2332+
left: Box::new(exponential.clone()),
2333+
right: Box::new(numerator),
2334+
});
21802335

2181-
// Result: (exponential * numerator) / denom
2182-
let result_num = simplify(Expr::BinaryOp {
2183-
op: Times,
2184-
left: Box::new(exponential.clone()),
2185-
right: Box::new(numerator),
2186-
});
2187-
2188-
Some(Expr::BinaryOp {
2189-
op: Divide,
2190-
left: Box::new(result_num),
2191-
right: Box::new(denom),
2192-
})
2336+
Some(Expr::BinaryOp {
2337+
op: Divide,
2338+
left: Box::new(result_num),
2339+
right: Box::new(denom),
2340+
})
2341+
}
21932342
}
21942343

21952344
/// Try integration by parts: ∫ u dv = u*v - ∫ v du
@@ -2227,8 +2376,9 @@ fn try_integration_by_parts(factors: &[&Expr], var: &str) -> Option<Expr> {
22272376
.map(|(_, f)| *f)
22282377
.collect();
22292378

2230-
// Special case: polynomial × constant-base exponential (non-E base)
2379+
// Special case: polynomial × constant-base exponential (including E base)
22312380
// Use closed-form formula: ∫ P(x)*a^(cx) dx = a^(cx) * Σ (-1)^k P^(k)(x) / (c*Log[a])^(k+1)
2381+
// For a=E, rate = c*Log[E] = c, giving direct polynomial-times-E^(cx) integration.
22322382
if dv_factors.len() == 1
22332383
&& let Expr::BinaryOp {
22342384
op: crate::syntax::BinaryOperator::Power,
@@ -2237,7 +2387,6 @@ fn try_integration_by_parts(factors: &[&Expr], var: &str) -> Option<Expr> {
22372387
} = dv_factors[0]
22382388
&& is_constant_wrt(base, var)
22392389
&& !is_constant_wrt(exp, var)
2240-
&& !matches!(base.as_ref(), Expr::Constant(c) if c == "E")
22412390
&& let Some(coeff) = try_match_linear_arg(exp, var)
22422391
&& is_polynomial_in(u, var)
22432392
{
@@ -2608,6 +2757,17 @@ fn integrate(expr: &Expr, var: &str) -> Option<Expr> {
26082757
{
26092758
return Some(result);
26102759
}
2760+
// ∫ f(x)^n dx where n is a positive integer: try expanding
2761+
if let Expr::Integer(n) = right.as_ref()
2762+
&& *n >= 2
2763+
&& !is_constant_wrt(left, var)
2764+
{
2765+
let expanded =
2766+
crate::functions::polynomial_ast::expand_and_combine(expr);
2767+
if !expr_str_eq(&expanded, expr) {
2768+
return integrate(&expanded, var);
2769+
}
2770+
}
26112771
None
26122772
}
26132773
_ => None,

src/syntax.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,6 +1958,25 @@ pub fn pair_to_expr(pair: Pair<Rule>) -> Expr {
19581958
}
19591959
result
19601960
}
1961+
Rule::ParenExtended => {
1962+
// (expr)[[index]] -> Part[expr, index]
1963+
let inner_pairs: Vec<_> = pair.into_inner().collect();
1964+
// First inner pair is the expression, then PartIndexSuffix elements
1965+
let base_expr = pair_to_expr(inner_pairs[0].clone());
1966+
let part_indices: Vec<Expr> = inner_pairs
1967+
.iter()
1968+
.filter(|p| matches!(p.as_rule(), Rule::PartIndexSuffix))
1969+
.flat_map(|p| p.clone().into_inner().map(pair_to_expr))
1970+
.collect();
1971+
let mut result = base_expr;
1972+
for idx in &part_indices {
1973+
result = Expr::Part {
1974+
expr: Box::new(result),
1975+
index: Box::new(idx.clone()),
1976+
};
1977+
}
1978+
result
1979+
}
19611980
Rule::Increment => {
19621981
// x++ -> Increment[x]
19631982
let inner = pair.into_inner().next().unwrap();

0 commit comments

Comments
 (0)