Skip to content

Commit 068350e

Browse files
authored
Merge pull request #298 from yfnaji/b-spline
B-Spline Interpolation Implementation
2 parents 0b8e1be + 0dc1977 commit 068350e

File tree

3 files changed

+253
-4
lines changed

3 files changed

+253
-4
lines changed

crates/RustQuant_error/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,14 @@ pub enum RustQuantError {
115115
/// Outside of interpolation range.
116116
#[error("Outside of interpolation range.")]
117117
OutsideOfRange,
118+
119+
/// Inconsistent B-Spline parameter lengths.
120+
#[error("For {0} control points and degree {1}, we need {0} + {1} + 1 ({2}) knots.")]
121+
BSplineInvalidParameters(usize, usize, usize),
122+
123+
/// Outside of B-Spline interpolation range.
124+
#[error("{0}")]
125+
BSplineOutsideOfRange(String),
118126
}
119127

120128
/// Curve error enum.
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2+
// RustQuant: A Rust library for quantitative finance tools.
3+
// Copyright (C) 2023 https://github.com/avhz
4+
// Dual licensed under Apache 2.0 and MIT.
5+
// See:
6+
// - LICENSE-APACHE.md
7+
// - LICENSE-MIT.md
8+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9+
10+
//! Module containing functionality for interpolation.
11+
12+
use crate::interpolation::{InterpolationIndex, InterpolationValue, Interpolator};
13+
use RustQuant_error::RustQuantError;
14+
15+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
16+
// STRUCTS & ENUMS
17+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
18+
19+
/// B-Spline Interpolator.
20+
pub struct BSplineInterpolator<IndexType, ValueType>
21+
where
22+
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
23+
ValueType: InterpolationValue,
24+
{
25+
/// Knots of the B-Spline.
26+
pub knots: Vec<IndexType>,
27+
28+
/// Control points of the B-Spline.
29+
pub control_points: Vec<ValueType>,
30+
31+
/// Degree of B-Spline.
32+
pub degree: usize,
33+
34+
/// Whether the interpolator has been fitted.
35+
pub fitted: bool,
36+
}
37+
38+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
39+
// IMPLEMENTATIONS, FUNCTIONS, AND MACROS
40+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
41+
42+
impl<IndexType, ValueType> BSplineInterpolator<IndexType, ValueType>
43+
where
44+
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
45+
ValueType: InterpolationValue,
46+
{
47+
/// Create a new BSplineInterpolator.
48+
///
49+
/// # Errors
50+
/// - `RustQuantError::UnequalLength` if ```xs.length() != ys.length()```.
51+
///
52+
/// # Panics
53+
/// Panics if NaN is in the index.
54+
pub fn new(
55+
mut knots: Vec<IndexType>,
56+
control_points: Vec<ValueType>,
57+
degree: usize
58+
) -> Result<BSplineInterpolator<IndexType, ValueType>, RustQuantError> {
59+
60+
if knots.len() != control_points.len() + degree + 1 {
61+
return Err(RustQuantError::BSplineInvalidParameters(
62+
control_points.len(), degree, control_points.len() + degree + 1,
63+
));
64+
}
65+
66+
knots.sort_by(|a, b| a.partial_cmp(b).unwrap());
67+
68+
Ok(Self {
69+
knots,
70+
control_points,
71+
degree,
72+
fitted: false,
73+
})
74+
}
75+
76+
/// Cox de Boor algorithm to evalute the spline curves.
77+
fn cox_de_boor(&self, point: IndexType, index: usize, degree: usize) -> ValueType {
78+
if degree == 0 {
79+
return if point.ge(&self.knots[index]) && point.lt(&self.knots[index + 1]) {
80+
ValueType::one()
81+
} else {
82+
ValueType::zero()
83+
}
84+
}
85+
86+
let mut left_term: ValueType = ValueType::zero();
87+
let mut right_term: ValueType = ValueType::zero();
88+
89+
if self.knots[index + degree] != self.knots[index] {
90+
left_term = ((point - self.knots[index]) / (self.knots[index + degree] - self.knots[index]))
91+
* self.cox_de_boor(point, index, degree - 1);
92+
}
93+
94+
if self.knots[index + degree + 1] != self.knots[index + 1] {
95+
right_term = ((self.knots[index + degree + 1] - point) / (self.knots[index + degree + 1] - self.knots[index + 1]))
96+
* self.cox_de_boor(point, index + 1, degree - 1);
97+
}
98+
left_term + right_term
99+
}
100+
}
101+
102+
impl<IndexType, ValueType> Interpolator<IndexType, ValueType>
103+
for BSplineInterpolator<IndexType, ValueType>
104+
where
105+
IndexType: InterpolationIndex<DeltaDiv = ValueType>,
106+
ValueType: InterpolationValue,
107+
{
108+
fn fit(&mut self) -> Result<(), RustQuantError> {
109+
110+
self.fitted = true;
111+
Ok(())
112+
}
113+
114+
fn range(&self) -> (IndexType, IndexType) {
115+
(*self.knots.first().unwrap(), *self.knots.last().unwrap())
116+
}
117+
118+
fn add_point(&mut self, point: (IndexType, ValueType)) {
119+
let idx = self.knots.partition_point(|&x| x < point.0);
120+
self.knots.insert(idx, point.0);
121+
self.control_points.insert(self.control_points.len(), point.1);
122+
}
123+
124+
125+
fn interpolate(&self, point: IndexType) -> Result<ValueType, RustQuantError> {
126+
if !(point.ge(&self.knots[self.degree]) && point.le(&self.knots[self.knots.len() - self.degree - 1])) {
127+
128+
let error_message: String = format!(
129+
"Point {} is outside of the interpolation range [{}, {}]",
130+
point,
131+
self.knots[self.degree],
132+
self.knots[self.knots.len() - self.degree - 1]
133+
);
134+
return Err(RustQuantError::BSplineOutsideOfRange(error_message));
135+
}
136+
137+
let mut value = ValueType::zero();
138+
for (index, control_point) in self.control_points.iter().enumerate() {
139+
value += self.cox_de_boor(point, index, self.degree) * (*control_point);
140+
}
141+
142+
Ok(value)
143+
}
144+
}
145+
146+
#[cfg(test)]
147+
mod tests_b_splines {
148+
use super::*;
149+
use RustQuant_utils::{assert_approx_equal, RUSTQUANT_EPSILON};
150+
151+
#[test]
152+
fn test_b_spline_uniform_knots() {
153+
let knots = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
154+
let control_points = vec![-1.0, 2.0, 0.0, -1.0];
155+
156+
let mut interpolator = BSplineInterpolator::new(knots, control_points, 2).unwrap();
157+
let _ = interpolator.fit();
158+
159+
assert_approx_equal!(
160+
1.375,
161+
interpolator.interpolate(2.5).unwrap(),
162+
RUSTQUANT_EPSILON
163+
);
164+
}
165+
166+
#[test]
167+
fn test_b_spline_non_uniform_knots() {
168+
let knots = vec![0.0, 1.0, 3.0, 4.0, 6.0, 7.0, 8.0, 10.0, 11.0];
169+
let control_points = vec![2.0, -1.0, 1.0, 0.0, 1.0];
170+
171+
let mut interpolator = BSplineInterpolator::new(knots, control_points, 3).unwrap();
172+
let _ = interpolator.fit();
173+
174+
assert_approx_equal!(
175+
0.058333333333333,
176+
interpolator.interpolate(5.0).unwrap(),
177+
RUSTQUANT_EPSILON
178+
);
179+
}
180+
181+
#[test]
182+
fn test_b_spline_dates() {
183+
let now = time::OffsetDateTime::now_utc();
184+
let knots: Vec<time::OffsetDateTime> = vec![
185+
now,
186+
now + time::Duration::days(1),
187+
now + time::Duration::days(2),
188+
now + time::Duration::days(3),
189+
now + time::Duration::days(4),
190+
now + time::Duration::days(5),
191+
now + time::Duration::days(6),
192+
];
193+
let control_points = vec![-1.0, 2.0, 0.0, -1.0];
194+
195+
let mut interpolator = BSplineInterpolator::new(
196+
knots.clone(), control_points, 2
197+
).unwrap();
198+
let _ = interpolator.fit();
199+
200+
assert_approx_equal!(
201+
1.375,
202+
interpolator
203+
.interpolate(knots[2] + time::Duration::hours(12))
204+
.unwrap(),
205+
RUSTQUANT_EPSILON
206+
);
207+
}
208+
209+
#[test]
210+
fn test_b_spline_inconsistent_parameters() {
211+
let knots = vec![0.0, 1.0, 2.0, 3.0, 4.0,];
212+
let control_points = vec![-1.0, 2.0, 0.0, -1.0];
213+
214+
match BSplineInterpolator::new(knots.clone(), control_points.clone(), 2) {
215+
Ok(_) => panic!("Constructor did not throw an error!"),
216+
Err(e) => assert_eq!(
217+
e.to_string(),
218+
"For 4 control points and degree 2, we need 4 + 2 + 1 (7) knots."
219+
)
220+
}
221+
}
222+
223+
#[test]
224+
fn test_b_spline_out_of_range() {
225+
let knots = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
226+
let control_points = vec![-1.0, 2.0, 0.0, -1.0];
227+
let mut interpolator = BSplineInterpolator::new(knots, control_points, 2).unwrap();
228+
let _ = interpolator.fit();
229+
230+
match interpolator.interpolate(5.5) {
231+
Ok(_) => panic!("Interpolation should have failed!"),
232+
Err(e) => assert_eq!(
233+
e.to_string(),
234+
"Point 5.5 is outside of the interpolation range [2, 4]"
235+
)
236+
}
237+
}
238+
}

crates/RustQuant_math/src/interpolation/mod.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// - LICENSE-MIT.md
88
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
99

10-
use std::ops::{Div, Mul, Sub};
10+
use std::ops::{Div, Mul, Sub, AddAssign};
1111
use RustQuant_error::RustQuantError;
1212

1313
pub mod linear_interpolator;
@@ -16,15 +16,18 @@ pub use linear_interpolator::*;
1616
pub mod exponential_interpolator;
1717
pub use exponential_interpolator::*;
1818

19+
pub mod b_splines;
20+
pub use b_splines::*;
21+
1922
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2023
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2124

2225
/// Trait describing requirements to be interpolated.
23-
pub trait InterpolationValue: num::Num + std::fmt::Debug + Copy + Clone + Sized {}
26+
pub trait InterpolationValue: num::Num + AddAssign + std::fmt::Debug + Copy + Clone + Sized {}
2427

2528
/// Trait describing requirements to be an index of interpolation.
2629
pub trait InterpolationIndex:
27-
Sub<Self, Output = Self::Delta> + PartialOrd + Copy + Clone + Sized
30+
Sub<Self, Output = Self::Delta> + PartialOrd + Copy + Clone + Sized + std::fmt::Display
2831
{
2932
/// Type of the difference of `Self` - `Self`
3033
type Delta: Div<Self::Delta, Output = Self::DeltaDiv>
@@ -60,7 +63,7 @@ where
6063
fn add_point(&mut self, point: (IndexType, ValueType));
6164
}
6265

63-
impl<T> InterpolationValue for T where T: num::Num + std::fmt::Debug + Copy + Clone + Sized {}
66+
impl<T> InterpolationValue for T where T: num::Num + AddAssign + std::fmt::Debug + Copy + Clone + Sized {}
6467

6568
macro_rules! impl_interpolation_index {
6669
($a:ty, $b:ty, $c:ty) => {

0 commit comments

Comments
 (0)