|
1 | 1 | use ordered_float::OrderedFloat; |
2 | | -use polars::prelude::*; |
3 | | -use std::collections::HashMap; |
4 | | -use std::time::Instant; |
5 | | - |
6 | | -use sde_sim_rs::filtration::Filtration; |
7 | 2 | use sde_sim_rs::proc::util::parse_equations; |
8 | 3 | use sde_sim_rs::sim::simulate; |
| 4 | +use std::collections::HashMap; |
| 5 | +use std::time::Instant; |
9 | 6 |
|
10 | 7 | fn main() { |
11 | | - // Simulation Parameters |
12 | | - let dt: f64 = 0.1; |
13 | | - let t_start: f64 = 0.0; |
14 | | - let t_end: f64 = 100.0; |
15 | | - let scenarios: i32 = 10000; |
16 | | - let initial_values = HashMap::from([("X1".to_string(), 1.0), ("X2".to_string(), 1.0)]); |
17 | | - let equations = [ |
18 | | - "dX1 = (0.005 * X1) * dt + (0.02 * X1) * dW1".to_string(), |
19 | | - "dX2 = (0.005 * X2) * dt + (0.02 * X1) * dW1 + (0.01 * X2) * dW2 + (1) * dJ1(0.5)" |
20 | | - .to_string(), |
| 8 | + // ────── configuration ────── |
| 9 | + let initial_values = HashMap::from([("X1".to_string(), 100.0), ("X2".to_string(), 0.0)]); |
| 10 | + |
| 11 | + let processes_equations = vec![ |
| 12 | + "dX1 = ( sin(t) ) * dt + (0.01 * X1) * dW1 + (0.001 * X1) * dN1(0.5 * cos(t))".to_string(), |
| 13 | + "X2 = max(X1 - 100.0, 0.0)".to_string(), |
21 | 14 | ]; |
22 | | - let scheme = "runge-kutta"; // "euler" or "runge-kutta" |
23 | | - let rng_method = "sobol"; // "pseudo" or "sobol" |
24 | 15 |
|
25 | | - // 1. Prepare Time Steps |
26 | | - let time_steps: Vec<OrderedFloat<f64>> = (0..) |
27 | | - .map(|i| OrderedFloat(t_start + i as f64 * dt)) |
28 | | - .take_while(|t| t.0 <= t_end) |
| 16 | + let scheme = "euler"; // other valid value: "runge-kutta" |
| 17 | + let rng_method = "pseudo"; // other valid value: "sobol" |
| 18 | + let scenarios: u64 = 10_000; |
| 19 | + |
| 20 | + // build a uniformly spaced time vector, identical to what the Python |
| 21 | + // wrapper accepts as `time_steps: Vec<f64>`. |
| 22 | + let dt = 0.1; |
| 23 | + let t_start = 0.0; |
| 24 | + let t_end = 100.0; |
| 25 | + let time_steps: Vec<f64> = (0..) |
| 26 | + .map(|i| t_start + i as f64 * dt) |
| 27 | + .take_while(|t| *t <= t_end) |
29 | 28 | .collect(); |
30 | 29 |
|
31 | | - // 2. Parse equations |
32 | | - let universe = |
33 | | - parse_equations(&equations, time_steps.clone()).expect("Failed to parse equations"); |
34 | | - |
35 | | - // 3. Initialize Filtration |
36 | | - let mut filtration = Filtration::new( |
37 | | - universe, |
38 | | - time_steps.clone(), |
39 | | - (1..=scenarios).collect(), |
40 | | - Some(initial_values), |
41 | | - ); |
42 | | - |
43 | | - // Run Simulation |
44 | | - let before = Instant::now(); |
45 | | - println!("Starting simulation with {} RNG...", rng_method); |
46 | | - simulate(&mut filtration, scheme, rng_method); |
47 | | - |
48 | | - let duration = before.elapsed(); |
49 | | - println!( |
50 | | - "Simulation completed in {:.4} seconds.\n", |
51 | | - duration.as_secs_f64() |
52 | | - ); |
53 | | - |
54 | | - let df: DataFrame = filtration.to_dataframe(); |
55 | | - println!("{}", df); |
56 | | - |
57 | | - assert!(duration.as_secs_f64() > 0.0); |
| 30 | + // convert the floats to `OrderedFloat` for internal use |
| 31 | + let ordered_steps: Vec<OrderedFloat<f64>> = |
| 32 | + time_steps.iter().copied().map(OrderedFloat).collect(); |
| 33 | + |
| 34 | + // parse the equations into a ProcessUniverse (same work done in Python) |
| 35 | + let universe = parse_equations(&processes_equations, ordered_steps.clone()) |
| 36 | + .expect("failed to parse process equations"); |
| 37 | + |
| 38 | + // run the actual simulation; this mirrors the body of `simulate_py` |
| 39 | + let start = Instant::now(); |
| 40 | + println!("running {} scenarios with {} rng...", scenarios, rng_method); |
| 41 | + let lf = simulate( |
| 42 | + &universe, |
| 43 | + ordered_steps.clone(), |
| 44 | + initial_values.clone(), |
| 45 | + scenarios, |
| 46 | + scheme, |
| 47 | + rng_method, |
| 48 | + ) |
| 49 | + .expect("failed to run simulation"); |
| 50 | + let df = lf.collect().expect("failed to collect results"); |
| 51 | + let elapsed = start.elapsed(); |
| 52 | + println!("completed in {:.3}s", elapsed.as_secs_f64()); |
| 53 | + |
| 54 | + // print a small portion of the output frame |
| 55 | + println!("{:#?}", df.head(Some(10))); |
58 | 56 | } |
0 commit comments