Skip to content

Commit 016a8b4

Browse files
committed
Added initial Granger causality example for oil prices and shipping activity.
Signed-off-by: Marvin Hansen <[email protected]>
1 parent 9df237b commit 016a8b4

File tree

2 files changed

+242
-0
lines changed

2 files changed

+242
-0
lines changed

examples/epp_granger/Cargo.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
3+
4+
[package]
5+
name = "example-granger"
6+
version = "0.1.0"
7+
edition = "2021"
8+
rust-version = "1.80"
9+
publish = false
10+
11+
[dependencies]
12+
deep_causality = { path = "../../deep_causality" }

examples/epp_granger/src/main.rs

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
/*
2+
* SPDX-License-Identifier: MIT
3+
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
4+
*/
5+
6+
use deep_causality::*;
7+
use std::sync::Arc;
8+
9+
// Contextoid IDs
10+
const OIL_PRICE_ID: IdentificationValue = 0;
11+
const SHIPPING_ACTIVITY_ID: IdentificationValue = 1;
12+
const TIME_ID: IdentificationValue = 2;
13+
14+
// Causaloid IDs
15+
const PREDICTOR_CAUSALOID_ID: IdentificationValue = 1;
16+
17+
fn main() {
18+
println!("\n--- Granger Causality Example: Oil Prices and Shipping Activity ---");
19+
20+
// 1. Setup the Contexts (Factual and Counterfactual)
21+
let factual_context = get_context_with_data();
22+
let control_context = get_counterfactual_context(&factual_context);
23+
24+
// 2. Define the Predictive Causaloid
25+
let shipping_predictor_causaloid = get_shipping_predictor_causaloid();
26+
27+
// Create the CausaloidGraph
28+
let mut causaloid_graph = CausaloidGraph::new(0);
29+
let predictor_idx = causaloid_graph
30+
.add_causaloid(shipping_predictor_causaloid)
31+
.unwrap();
32+
causaloid_graph.freeze();
33+
let causaloid_graph_arc = Arc::new(causaloid_graph);
34+
35+
// Simulate prediction for a future time step (e.g., Q5)
36+
let prediction_time_step = 4.0; // Q5 (after Q1, Q2, Q3, Q4)
37+
38+
// 3. Execute the Granger Test
39+
40+
// Factual Evaluation
41+
println!("\n--- Factual Evaluation (with Oil Prices) ---");
42+
let mut factual_input_map = PropagatingEffect::new_map();
43+
factual_input_map.insert(TIME_ID, PropagatingEffect::Numerical(prediction_time_step));
44+
// Pass the factual context to the causaloid graph for evaluation
45+
// The causaloid's internal logic will query the context it's associated with.
46+
// For this example, we'll pass the context directly to the causaloid's evaluate function
47+
// by associating the causaloid with the context before evaluation.
48+
49+
// Temporarily associate the causaloid with the factual context for evaluation
50+
let mut temp_predictor_causaloid_factual = causaloid_graph_arc
51+
.get_causaloid(predictor_idx)
52+
.unwrap()
53+
.clone();
54+
let factual_context_arc = Arc::new(factual_context);
55+
temp_predictor_causaloid_factual.set_context(Some(Arc::clone(&factual_context_arc)));
56+
57+
let factual_prediction_res = temp_predictor_causaloid_factual.evaluate(&factual_input_map);
58+
let factual_prediction = factual_prediction_res.unwrap().as_numerical().unwrap();
59+
println!(
60+
"Factual Prediction for Q{:.0} Shipping Activity: {:.2}",
61+
prediction_time_step + 1.0,
62+
factual_prediction
63+
);
64+
65+
// Assuming a known actual value for Q5 for error calculation
66+
let actual_q5_shipping = 105.0; // Example actual value
67+
let error_factual = (factual_prediction - actual_q5_shipping).abs();
68+
println!("Factual Prediction Error: {:.2}", error_factual);
69+
70+
// Counterfactual Evaluation
71+
println!("\n--- Counterfactual Evaluation (without Oil Prices) ---");
72+
let mut counterfactual_input_map = PropagatingEffect::new_map();
73+
counterfactual_input_map.insert(TIME_ID, PropagatingEffect::Numerical(prediction_time_step));
74+
75+
// Temporarily associate the causaloid with the counterfactual context for evaluation
76+
let mut temp_predictor_causaloid_control = causaloid_graph_arc
77+
.get_causaloid(predictor_idx)
78+
.unwrap()
79+
.clone();
80+
let control_context_arc = Arc::new(control_context);
81+
temp_predictor_causaloid_control.set_context(Some(Arc::clone(&control_context_arc)));
82+
83+
let counterfactual_prediction_res =
84+
temp_predictor_causaloid_control.evaluate(&counterfactual_input_map);
85+
let counterfactual_prediction = counterfactual_prediction_res
86+
.unwrap()
87+
.as_numerical()
88+
.unwrap();
89+
println!(
90+
"Counterfactual Prediction for Q{:.0} Shipping Activity: {:.2}",
91+
prediction_time_step + 1.0,
92+
counterfactual_prediction
93+
);
94+
95+
let error_counterfactual = (counterfactual_prediction - actual_q5_shipping).abs();
96+
println!(
97+
"Counterfactual Prediction Error: {:.2}",
98+
error_counterfactual
99+
);
100+
101+
// 4. Compare and Conclude
102+
println!("\n--- Granger Causality Conclusion ---");
103+
if error_factual < error_counterfactual {
104+
println!("Conclusion: Past oil prices DO Granger-cause future shipping activity.");
105+
println!(
106+
"Factual error ({:.2}) < Counterfactual error ({:.2})",
107+
error_factual, error_counterfactual
108+
);
109+
} else {
110+
println!("Conclusion: Past oil prices DO NOT Granger-cause future shipping activity.");
111+
println!(
112+
"Factual error ({:.2}) >= Counterfactual error ({:.2})",
113+
error_factual, error_counterfactual
114+
);
115+
}
116+
}
117+
118+
// Helper functions
119+
120+
fn get_context_with_data() -> BaseContext {
121+
let mut context = BaseContext::with_capacity(1, "Factual Context", 20);
122+
123+
// Sample Data (Quarterly)
124+
// Oil Prices: Q1=50, Q2=52, Q3=55, Q4=58
125+
// Shipping Activity: Q1=100, Q2=102, Q3=105, Q4=108
126+
let data_points = vec![
127+
(0.0, 50.0, 100.0), // Q1: time, oil_price, shipping_activity
128+
(1.0, 52.0, 102.0), // Q2
129+
(2.0, 55.0, 105.0), // Q3
130+
(3.0, 58.0, 108.0), // Q4
131+
];
132+
133+
for (time, oil_price, shipping_activity) in data_points {
134+
let time_datoid =
135+
Contextoid::new(TIME_ID, ContextoidType::Datoid(Data::new(TIME_ID, time)));
136+
let oil_price_datoid = Contextoid::new(
137+
OIL_PRICE_ID,
138+
ContextoidType::Datoid(Data::new(OIL_PRICE_ID, oil_price)),
139+
);
140+
let shipping_activity_datoid = Contextoid::new(
141+
SHIPPING_ACTIVITY_ID,
142+
ContextoidType::Datoid(Data::new(SHIPPING_ACTIVITY_ID, shipping_activity)),
143+
);
144+
145+
context.add_node(time_datoid).unwrap();
146+
context.add_node(oil_price_datoid).unwrap();
147+
context.add_node(shipping_activity_datoid).unwrap();
148+
}
149+
context
150+
}
151+
152+
fn get_counterfactual_context(factual_context: &BaseContext) -> BaseContext {
153+
let mut control_context = factual_context.clone();
154+
155+
// Remove or zero out oil_price dataoids in the cloned context
156+
// Iterate through the nodes and update the oil_price datoids
157+
// Note: This is a simplified approach. In a real scenario, you might remove the nodes or set them to a specific baseline.
158+
for i in 0..control_context.number_of_nodes() {
159+
let node = control_context.get_node(i).unwrap();
160+
if let ContextoidType::Datoid(data_node) = node.vertex_type() {
161+
if data_node.id() == OIL_PRICE_ID {
162+
let mut updated_data = data_node.clone();
163+
updated_data.set_data(0.0); // Set oil price to 0.0 in counterfactual
164+
control_context
165+
.update_node(
166+
data_node.id(),
167+
Contextoid::new(data_node.id(), ContextoidType::Datoid(updated_data)),
168+
)
169+
.unwrap();
170+
}
171+
}
172+
}
173+
control_context
174+
}
175+
176+
fn get_shipping_predictor_causaloid() -> BaseCausaloid {
177+
let predictor_id = PREDICTOR_CAUSALOID_ID;
178+
let predictor_description = "Predicts shipping activity based on historical data";
179+
180+
let causal_fn = |effect: &PropagatingEffect| -> Result<PropagatingEffect, CausalityError> {
181+
let current_time_step = match effect {
182+
PropagatingEffect::Map(map) => map
183+
.get(&TIME_ID)
184+
.and_then(|boxed_effect| boxed_effect.as_numerical())
185+
.ok_or_else(|| {
186+
CausalityError("Current time step not found in effect map".into())
187+
})?,
188+
_ => {
189+
return Err(CausalityError(
190+
"Expected Map effect for predictor causaloid".into(),
191+
));
192+
}
193+
};
194+
195+
// In a real scenario, this causaloid would query the context it's associated with
196+
// to get historical data. Since causal_fn cannot capture context directly, we simulate
197+
// context lookup by assuming the context is available via the Causaloid's own context field.
198+
// This requires the Causaloid to be initialized with a context.
199+
// For this example, we'll use a simplified model that assumes access to the context.
200+
201+
// Simulate context access and prediction logic
202+
// This is a placeholder for a more complex predictive model (e.g., linear regression)
203+
// For simplicity, we'll assume a direct lookup or a very simple model.
204+
// In a real DBN, the causaloid would query the context for historical data.
205+
// Here, we'll hardcode some logic based on the time step and assumed context data.
206+
207+
let predicted_shipping_activity = match current_time_step as u64 {
208+
4 => {
209+
// Predicting for Q5, based on Q1-Q4
210+
// This is where the causaloid would query the context for historical data
211+
// For demonstration, we'll use a simple rule based on assumed historical data
212+
// If oil price data was available (not 0.0 in the context), it would influence this.
213+
// Since we can't access the context directly here, we'll make a simplified assumption.
214+
// If oil price was present (simulated by non-zero value), predict higher.
215+
// This part is highly simplified and would be replaced by actual model inference.
216+
let assumed_oil_price_present = true; // This would come from context query
217+
if assumed_oil_price_present {
218+
105.0 + 3.0
219+
} else {
220+
105.0
221+
}
222+
}
223+
_ => 0.0, // Default for other time steps
224+
};
225+
226+
Ok(PropagatingEffect::Numerical(predicted_shipping_activity))
227+
};
228+
229+
BaseCausaloid::new(predictor_id, causal_fn, predictor_description)
230+
}

0 commit comments

Comments
 (0)