Skip to content

Commit bb5876d

Browse files
committed
feat(cubesql): Add XIRR aggregate function
Signed-off-by: Alex Qyoun-ae <[email protected]>
1 parent 1d15182 commit bb5876d

File tree

6 files changed

+382
-0
lines changed

6 files changed

+382
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mod xirr;
2+
3+
pub use xirr::*;
Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::{
4+
arrow::{
5+
array::{ArrayRef, Date32Array, Float64Array, ListArray},
6+
compute::cast,
7+
datatypes::{DataType, Field, TimeUnit},
8+
},
9+
error::{DataFusionError, Result},
10+
logical_expr::{
11+
Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction,
12+
Signature, StateTypeFunction, TypeSignature, Volatility,
13+
},
14+
scalar::ScalarValue,
15+
};
16+
17+
pub const XIRR_UDAF_NAME: &str = "xirr";
18+
19+
/// Creates a XIRR Aggregate UDF.
20+
///
21+
/// Syntax:
22+
/// ```sql
23+
/// XIRR(<payment>, <date> [, <initial_guess> [, <on_error>]])
24+
/// ```
25+
///
26+
/// This function calculates internal rate of return for a series of cash flows (payments)
27+
/// that occur at irregular intervals.
28+
///
29+
/// The function takes two arguments:
30+
/// - `payment` (numeric): The cash flow amount. NULL values are considered 0.
31+
/// - `date` (datetime): The date of the payment. Time is ignored. Must never be NULL.
32+
/// - (optional) `initial_guess` (numeric): An initial guess for the rate of return. Must be
33+
/// greater than -1.0 and consistent across all rows. If NULL or omitted, a default value
34+
/// of 0.1 is used.
35+
/// - (optional) `on_error` (numeric): A value to return if the function cannot find a solution.
36+
/// If omitted, the function will yield an error when it cannot find a solution. Must be
37+
/// consistent across all rows.
38+
///
39+
/// The function always yields an error if:
40+
/// - There are no rows.
41+
/// - The `date` argument contains a NULL value.
42+
/// - The `initial_guess` argument is less than or equal to -1.0, or inconsistent across all rows.
43+
/// - The `on_error` argument is inconsistent across all rows.
44+
///
45+
/// The function returns `on_error` value (or yields an error if omitted) if:
46+
/// - The function cannot find a solution after a set number of iterations.
47+
/// - The calculation failed due to internal division by 0.
48+
pub fn create_xirr_udaf() -> AggregateUDF {
49+
let name = XIRR_UDAF_NAME;
50+
let type_signatures = {
51+
// Only types actually used by cubesql are included
52+
const NUMERIC_TYPES: &[DataType] = &[DataType::Float64, DataType::Int64, DataType::Int32];
53+
const DATETIME_TYPES: &[DataType] = &[
54+
DataType::Date32,
55+
DataType::Timestamp(TimeUnit::Nanosecond, None),
56+
DataType::Timestamp(TimeUnit::Millisecond, None),
57+
];
58+
let mut type_signatures = Vec::with_capacity(45);
59+
for payment_type in NUMERIC_TYPES {
60+
for date_type in DATETIME_TYPES {
61+
// Base signatures without `initial_guess` and `on_error` arguments
62+
type_signatures.push(TypeSignature::Exact(vec![
63+
payment_type.clone(),
64+
date_type.clone(),
65+
]));
66+
// Signatures with `initial_guess` argument; only [`DataType::Float64`] is accepted
67+
const INITIAL_GUESS_TYPE: DataType = DataType::Float64;
68+
type_signatures.push(TypeSignature::Exact(vec![
69+
payment_type.clone(),
70+
date_type.clone(),
71+
INITIAL_GUESS_TYPE,
72+
]));
73+
// Signatures with `initial_guess` and `on_error` arguments
74+
for on_error_type in NUMERIC_TYPES {
75+
type_signatures.push(TypeSignature::Exact(vec![
76+
payment_type.clone(),
77+
date_type.clone(),
78+
INITIAL_GUESS_TYPE,
79+
on_error_type.clone(),
80+
]));
81+
}
82+
}
83+
}
84+
type_signatures
85+
};
86+
let signature = Signature::one_of(
87+
type_signatures,
88+
Volatility::Volatile, // due to the usage of [`f64::powf`]
89+
);
90+
let return_type: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
91+
let accumulator: AccumulatorFunctionImplementation =
92+
Arc::new(|| Ok(Box::new(XirrAccumulator::new())));
93+
let state_type: StateTypeFunction = Arc::new(|_| {
94+
Ok(Arc::new(vec![
95+
DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
96+
DataType::List(Box::new(Field::new("item", DataType::Date32, true))),
97+
DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
98+
DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
99+
]))
100+
});
101+
AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type)
102+
}
103+
104+
#[derive(Debug)]
105+
struct XirrAccumulator {
106+
/// Pairs of (payment, date).
107+
pairs: Vec<(f64, i32)>,
108+
initial_guess: ValueState<f64>,
109+
on_error: ValueState<f64>,
110+
}
111+
112+
impl XirrAccumulator {
113+
fn new() -> Self {
114+
XirrAccumulator {
115+
pairs: vec![],
116+
initial_guess: ValueState::Unset,
117+
on_error: ValueState::Unset,
118+
}
119+
}
120+
121+
fn add_pair(&mut self, payment: Option<f64>, date: Option<i32>) -> Result<()> {
122+
let Some(date) = date else {
123+
return Err(DataFusionError::Execution(
124+
"One or more values for the `date` argument passed to XIRR is null".to_string(),
125+
));
126+
};
127+
// NULL payment value is treated as 0
128+
let payment = payment.unwrap_or(0.0);
129+
self.pairs.push((payment, date));
130+
Ok(())
131+
}
132+
133+
fn set_initial_guess(&mut self, initial_guess: Option<f64>) -> Result<()> {
134+
let ValueState::Set(current_initial_guess) = self.initial_guess else {
135+
self.initial_guess = ValueState::Set(initial_guess);
136+
return Ok(());
137+
};
138+
if current_initial_guess != initial_guess {
139+
return Err(DataFusionError::Execution(
140+
"The `initial_guess` argument passed to XIRR is inconsistent".to_string(),
141+
));
142+
}
143+
Ok(())
144+
}
145+
146+
fn set_on_error(&mut self, on_error: Option<f64>) -> Result<()> {
147+
let ValueState::Set(current_on_error) = self.on_error else {
148+
self.on_error = ValueState::Set(on_error);
149+
return Ok(());
150+
};
151+
if current_on_error != on_error {
152+
return Err(DataFusionError::Execution(
153+
"The `on_error` argument passed to XIRR is inconsistent".to_string(),
154+
));
155+
}
156+
Ok(())
157+
}
158+
159+
fn yield_no_solution(&self) -> Result<ScalarValue> {
160+
match self.on_error {
161+
ValueState::Unset => Err(DataFusionError::Execution(
162+
"The XIRR function couldn't find a solution".to_string(),
163+
)),
164+
ValueState::Set(on_error) => Ok(ScalarValue::Float64(on_error)),
165+
}
166+
}
167+
}
168+
169+
impl Accumulator for XirrAccumulator {
170+
fn state(&self) -> Result<Vec<ScalarValue>> {
171+
let (payments, dates): (Vec<_>, Vec<_>) = self
172+
.pairs
173+
.iter()
174+
.map(|(payment, date)| {
175+
let payment = ScalarValue::Float64(Some(*payment));
176+
let date = ScalarValue::Date32(Some(*date));
177+
(payment, date)
178+
})
179+
.unzip();
180+
let initial_guess = match self.initial_guess {
181+
ValueState::Unset => vec![],
182+
ValueState::Set(initial_guess) => vec![ScalarValue::Float64(initial_guess)],
183+
};
184+
let on_error = match self.on_error {
185+
ValueState::Unset => vec![],
186+
ValueState::Set(on_error) => vec![ScalarValue::Float64(on_error)],
187+
};
188+
Ok(vec![
189+
ScalarValue::List(Some(Box::new(payments)), Box::new(DataType::Float64)),
190+
ScalarValue::List(Some(Box::new(dates)), Box::new(DataType::Date32)),
191+
ScalarValue::List(Some(Box::new(initial_guess)), Box::new(DataType::Float64)),
192+
ScalarValue::List(Some(Box::new(on_error)), Box::new(DataType::Float64)),
193+
])
194+
}
195+
196+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
197+
let payments = cast(&values[0], &DataType::Float64)?;
198+
let payments = payments.as_any().downcast_ref::<Float64Array>().unwrap();
199+
let dates = cast(&values[1], &DataType::Date32)?;
200+
let dates = dates.as_any().downcast_ref::<Date32Array>().unwrap();
201+
for (payment, date) in payments.into_iter().zip(dates) {
202+
self.add_pair(payment, date)?;
203+
}
204+
let values_len = values.len();
205+
if values_len < 3 {
206+
return Ok(());
207+
}
208+
let initial_guesses = values[2].as_any().downcast_ref::<Float64Array>().unwrap();
209+
for initial_guess in initial_guesses {
210+
self.set_initial_guess(initial_guess)?;
211+
}
212+
if values_len < 4 {
213+
return Ok(());
214+
}
215+
let on_errors = cast(&values[3], &DataType::Float64)?;
216+
let on_errors = on_errors.as_any().downcast_ref::<Float64Array>().unwrap();
217+
for on_error in on_errors {
218+
self.set_on_error(on_error)?;
219+
}
220+
Ok(())
221+
}
222+
223+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
224+
let payments = states[0]
225+
.as_any()
226+
.downcast_ref::<ListArray>()
227+
.unwrap()
228+
.values();
229+
let payments = payments.as_any().downcast_ref::<Float64Array>().unwrap();
230+
let dates = states[1]
231+
.as_any()
232+
.downcast_ref::<ListArray>()
233+
.unwrap()
234+
.values();
235+
let dates = dates.as_any().downcast_ref::<Date32Array>().unwrap();
236+
for (payment, date) in payments.into_iter().zip(dates) {
237+
self.add_pair(payment, date)?;
238+
}
239+
let states_len = states.len();
240+
if states_len < 3 {
241+
return Ok(());
242+
}
243+
let initial_guesses = states[2]
244+
.as_any()
245+
.downcast_ref::<ListArray>()
246+
.unwrap()
247+
.values();
248+
let initial_guesses = initial_guesses
249+
.as_any()
250+
.downcast_ref::<Float64Array>()
251+
.unwrap();
252+
for initial_guess in initial_guesses {
253+
self.set_initial_guess(initial_guess)?;
254+
}
255+
if states_len < 4 {
256+
return Ok(());
257+
}
258+
let on_errors = states[3]
259+
.as_any()
260+
.downcast_ref::<ListArray>()
261+
.unwrap()
262+
.values();
263+
let on_errors = on_errors.as_any().downcast_ref::<Float64Array>().unwrap();
264+
for on_error in on_errors {
265+
self.set_on_error(on_error)?;
266+
}
267+
Ok(())
268+
}
269+
270+
fn evaluate(&self) -> Result<ScalarValue> {
271+
const MAX_ITERATIONS: usize = 100;
272+
const TOLERANCE: f64 = 1e-6;
273+
const DEFAULT_INITIAL_GUESS: f64 = 0.1;
274+
let Some(min_date) = self.pairs.iter().map(|(_, date)| *date).min() else {
275+
return Err(DataFusionError::Execution(
276+
"A result for XIRR couldn't be determined because the arguments are empty"
277+
.to_string(),
278+
));
279+
};
280+
let pairs = self
281+
.pairs
282+
.iter()
283+
.map(|(payment, date)| {
284+
let year_difference = (*date - min_date) as f64 / 365.0;
285+
(*payment, year_difference)
286+
})
287+
.collect::<Vec<_>>();
288+
let mut rate_of_return = self
289+
.initial_guess
290+
.to_value()
291+
.unwrap_or(DEFAULT_INITIAL_GUESS);
292+
if rate_of_return <= -1.0 {
293+
return Err(DataFusionError::Execution(
294+
"The `initial_guess` argument passed to the XIRR function must be greater than -1"
295+
.to_string(),
296+
));
297+
}
298+
for _ in 0..MAX_ITERATIONS {
299+
let mut net_present_value = 0.0;
300+
let mut derivative_value = 0.0;
301+
for (payment, year_difference) in &pairs {
302+
if *payment == 0.0 {
303+
continue;
304+
}
305+
let rate_positive = 1.0 + rate_of_return;
306+
let denominator = rate_positive.powf(*year_difference);
307+
net_present_value += *payment / denominator;
308+
derivative_value -= *year_difference * *payment / denominator / rate_positive;
309+
}
310+
if net_present_value.abs() < TOLERANCE {
311+
return Ok(ScalarValue::Float64(Some(rate_of_return)));
312+
}
313+
let rate_reduction = net_present_value / derivative_value;
314+
if rate_reduction.is_nan() {
315+
return self.yield_no_solution();
316+
}
317+
rate_of_return -= rate_reduction;
318+
}
319+
self.yield_no_solution()
320+
}
321+
}
322+
323+
#[derive(Debug)]
324+
enum ValueState<T: Copy> {
325+
Unset,
326+
Set(Option<T>),
327+
}
328+
329+
impl<T: Copy> ValueState<T> {
330+
fn to_value(&self) -> Option<T> {
331+
match self {
332+
ValueState::Unset => None,
333+
ValueState::Set(value) => *value,
334+
}
335+
}
336+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
mod common;
2+
mod extension;
23
mod pg_catalog;
34
mod redshift;
45
mod utils;
56

67
pub use common::*;
8+
pub use extension::*;
79
pub use pg_catalog::*;
810
pub use redshift::*;

rust/cubesql/cubesql/src/compile/query_engine.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ impl QueryEngine for SqlQueryEngine {
516516
// udaf
517517
ctx.register_udaf(create_measure_udaf());
518518
ctx.register_udaf(create_patch_measure_udaf());
519+
ctx.register_udaf(create_xirr_udaf());
519520

520521
// udtf
521522
ctx.register_udtf(create_generate_series_udtf());
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
source: cubesql/src/compile/test/test_udfs.rs
3+
expression: "execute_query(r#\"\n SELECT LEFT(XIRR(payment, date)::text, 10) AS xirr\n FROM (\n SELECT '2014-01-01'::date AS date, -10000.0 AS payment\n UNION ALL\n SELECT '2014-03-01'::date AS date, 2750.0 AS payment\n UNION ALL\n SELECT '2014-10-30'::date AS date, 4250.0 AS payment\n UNION ALL\n SELECT '2015-02-15'::date AS date, 3250.0 AS payment\n UNION ALL\n SELECT '2015-04-01'::date AS date, 2750.0 AS payment\n ) AS \"t\"\n \"#.to_string(),\nDatabaseProtocol::PostgreSQL).await?"
4+
---
5+
+------------+
6+
| xirr |
7+
+------------+
8+
| 0.37485859 |
9+
+------------+

0 commit comments

Comments
 (0)