Skip to content

Commit cf720fa

Browse files
committed
Refactored Granger causality example to improve dynamic updates.
Signed-off-by: Marvin Hansen <[email protected]>
1 parent 016a8b4 commit cf720fa

File tree

1 file changed

+170
-154
lines changed

1 file changed

+170
-154
lines changed

examples/epp_granger/src/main.rs

Lines changed: 170 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -6,123 +6,188 @@
66
use deep_causality::*;
77
use std::sync::Arc;
88

9-
// Contextoid IDs
9+
// Define IDs for different data types within the context
1010
const OIL_PRICE_ID: IdentificationValue = 0;
1111
const SHIPPING_ACTIVITY_ID: IdentificationValue = 1;
1212
const TIME_ID: IdentificationValue = 2;
1313

14-
// Causaloid IDs
14+
// Define ID for the causaloid
1515
const PREDICTOR_CAUSALOID_ID: IdentificationValue = 1;
1616

1717
fn main() {
18-
println!("\n--- Granger Causality Example: Oil Prices and Shipping Activity ---");
18+
println!(
19+
"
20+
--- Granger Causality Example: Oil Prices and Shipping Activity ---"
21+
);
1922

20-
// 1. Setup the Contexts (Factual and Counterfactual)
21-
let factual_context = get_context_with_data();
22-
let control_context = get_counterfactual_context(&factual_context);
23+
// 1. Define the Predictive Causaloid (using the contextual function)
24+
let predictor_id = PREDICTOR_CAUSALOID_ID;
2325

24-
// 2. Define the Predictive Causaloid
25-
let shipping_predictor_causaloid = get_shipping_predictor_causaloid();
26+
// Create two instances of the causaloid, one for each context.
27+
let factual_causaloid = get_factual_causaloid(predictor_id);
28+
let counterfactual_causaloid = get_counterfactual_causaloid(predictor_id);
2629

27-
// Create the CausaloidGraph
28-
let mut causaloid_graph = CausaloidGraph::new(0);
29-
let predictor_idx = causaloid_graph
30-
.add_causaloid(shipping_predictor_causaloid)
31-
.unwrap();
32-
causaloid_graph.freeze();
33-
let causaloid_graph_arc = Arc::new(causaloid_graph);
34-
35-
// Simulate prediction for a future time step (e.g., Q5)
36-
let prediction_time_step = 4.0; // Q5 (after Q1, Q2, Q3, Q4)
37-
38-
// 3. Execute the Granger Test
39-
40-
// Factual Evaluation
41-
println!("\n--- Factual Evaluation (with Oil Prices) ---");
42-
let mut factual_input_map = PropagatingEffect::new_map();
43-
factual_input_map.insert(TIME_ID, PropagatingEffect::Numerical(prediction_time_step));
44-
// Pass the factual context to the causaloid graph for evaluation
45-
// The causaloid's internal logic will query the context it's associated with.
46-
// For this example, we'll pass the context directly to the causaloid's evaluate function
47-
// by associating the causaloid with the context before evaluation.
48-
49-
// Temporarily associate the causaloid with the factual context for evaluation
50-
let mut temp_predictor_causaloid_factual = causaloid_graph_arc
51-
.get_causaloid(predictor_idx)
52-
.unwrap()
53-
.clone();
54-
let factual_context_arc = Arc::new(factual_context);
55-
temp_predictor_causaloid_factual.set_context(Some(Arc::clone(&factual_context_arc)));
30+
// Define the input for the evaluation. This tells the causaloid what time to predict for.
31+
let prediction_time_step = 4.0; // Predict for Q5, given data for Q1-Q4
32+
let mut input_map = PropagatingEffect::new_map();
33+
input_map.insert(TIME_ID, PropagatingEffect::Numerical(prediction_time_step));
34+
35+
// 2. Execute the Granger Test
5636

57-
let factual_prediction_res = temp_predictor_causaloid_factual.evaluate(&factual_input_map);
58-
let factual_prediction = factual_prediction_res.unwrap().as_numerical().unwrap();
37+
// Factual Evaluation (with oil price history)
5938
println!(
60-
"Factual Prediction for Q{:.0} Shipping Activity: {:.2}",
61-
prediction_time_step + 1.0,
62-
factual_prediction
39+
"
40+
--- Factual Evaluation (with Oil Prices) ---"
6341
);
64-
65-
// Assuming a known actual value for Q5 for error calculation
66-
let actual_q5_shipping = 105.0; // Example actual value
67-
let error_factual = (factual_prediction - actual_q5_shipping).abs();
68-
println!("Factual Prediction Error: {:.2}", error_factual);
69-
70-
// Counterfactual Evaluation
71-
println!("\n--- Counterfactual Evaluation (without Oil Prices) ---");
72-
let mut counterfactual_input_map = PropagatingEffect::new_map();
73-
counterfactual_input_map.insert(TIME_ID, PropagatingEffect::Numerical(prediction_time_step));
74-
75-
// Temporarily associate the causaloid with the counterfactual context for evaluation
76-
let mut temp_predictor_causaloid_control = causaloid_graph_arc
77-
.get_causaloid(predictor_idx)
42+
let factual_prediction = factual_causaloid
43+
.evaluate(&input_map)
7844
.unwrap()
79-
.clone();
80-
let control_context_arc = Arc::new(control_context);
81-
temp_predictor_causaloid_control.set_context(Some(Arc::clone(&control_context_arc)));
45+
.as_numerical()
46+
.unwrap();
47+
println!(
48+
"Factual Prediction for Q5 Shipping Activity: {:.2}",
49+
factual_prediction
50+
);
8251

83-
let counterfactual_prediction_res =
84-
temp_predictor_causaloid_control.evaluate(&counterfactual_input_map);
85-
let counterfactual_prediction = counterfactual_prediction_res
52+
// Counterfactual Evaluation (without oil price history)
53+
println!(
54+
"
55+
--- Counterfactual Evaluation (without Oil Prices) ---"
56+
);
57+
let counterfactual_prediction = counterfactual_causaloid
58+
.evaluate(&input_map)
8659
.unwrap()
8760
.as_numerical()
8861
.unwrap();
8962
println!(
90-
"Counterfactual Prediction for Q{:.0} Shipping Activity: {:.2}",
91-
prediction_time_step + 1.0,
63+
"Counterfactual Prediction for Q5 Shipping Activity: {:.2}",
9264
counterfactual_prediction
9365
);
9466

67+
// 3. Compare and Conclude
68+
// This is the hypothetical "true" value for Q5, used to measure prediction error.
69+
let actual_q5_shipping = 105.0;
70+
let error_factual = (factual_prediction - actual_q5_shipping).abs();
9571
let error_counterfactual = (counterfactual_prediction - actual_q5_shipping).abs();
72+
73+
println!(
74+
"
75+
--- Granger Causality Conclusion ---"
76+
);
77+
println!("Actual Q5 Shipping Activity: {:.2}", actual_q5_shipping);
9678
println!(
97-
"Counterfactual Prediction Error: {:.2}",
79+
"Factual Prediction Error (with oil data): {:.2}",
80+
error_factual
81+
);
82+
println!(
83+
"Counterfactual Prediction Error (no oil data): {:.2}",
9884
error_counterfactual
9985
);
10086

101-
// 4. Compare and Conclude
102-
println!("\n--- Granger Causality Conclusion ---");
10387
if error_factual < error_counterfactual {
104-
println!("Conclusion: Past oil prices DO Granger-cause future shipping activity.");
10588
println!(
106-
"Factual error ({:.2}) < Counterfactual error ({:.2})",
107-
error_factual, error_counterfactual
89+
"
90+
Conclusion: Past oil prices DO Granger-cause future shipping activity."
10891
);
92+
println!("Because the error is lower when oil price history is included.");
10993
} else {
110-
println!("Conclusion: Past oil prices DO NOT Granger-cause future shipping activity.");
11194
println!(
112-
"Factual error ({:.2}) >= Counterfactual error ({:.2})",
113-
error_factual, error_counterfactual
95+
"
96+
Conclusion: Past oil prices DO NOT Granger-cause future shipping activity."
11497
);
98+
println!("Because including oil price history did not improve the prediction.");
99+
}
100+
}
101+
102+
fn get_factual_causaloid(predictor_id: IdentificationValue) -> BaseCausaloid {
103+
let predictor_description = "Predicts shipping activity based on factual historical data";
104+
let factual_context = Arc::new(get_context_with_data());
105+
106+
// `new_with_context` is used to create a causaloid that has access to its context.
107+
Causaloid::new_with_context(
108+
predictor_id,
109+
shipping_predictor_logic,
110+
Arc::clone(&factual_context),
111+
predictor_description,
112+
)
113+
}
114+
115+
fn get_counterfactual_causaloid(predictor_id: IdentificationValue) -> BaseCausaloid {
116+
let factual_context = Arc::new(get_context_with_data());
117+
let counterfactual_context = Arc::new(get_counterfactual_context(&factual_context));
118+
let predictor_description =
119+
"Predicts shipping activity based on counterfactual historical data";
120+
121+
// `new_with_context` is used to create a causaloid that has access to its context.
122+
Causaloid::new_with_context(
123+
predictor_id,
124+
shipping_predictor_logic,
125+
Arc::clone(&counterfactual_context),
126+
predictor_description,
127+
)
128+
}
129+
130+
/// The main logic for the predictive causaloid.
131+
/// This function has access to the context and performs a prediction based on its contents.
132+
fn shipping_predictor_logic(
133+
effect: &PropagatingEffect,
134+
context: &Arc<BaseContext>,
135+
) -> Result<PropagatingEffect, CausalityError> {
136+
// Extract the target prediction time from the input effect map.
137+
let _target_time = effect.get_numerical_from_map(TIME_ID)?;
138+
139+
let mut oil_prices: Vec<f64> = Vec::new();
140+
let mut shipping_activities: Vec<f64> = Vec::new();
141+
142+
// Iterate through all nodes in the context graph to gather historical data.
143+
// A real implementation would likely use more sophisticated queries, but this
144+
// demonstrates accessing the full context.
145+
for i in 0..context.number_of_nodes() {
146+
if let Some(node) = context.get_node(i) {
147+
if let ContextoidType::Datoid(data_node) = node.vertex_type() {
148+
match data_node.id() {
149+
OIL_PRICE_ID => oil_prices.push(data_node.get_data()),
150+
SHIPPING_ACTIVITY_ID => shipping_activities.push(data_node.get_data()),
151+
_ => (),
152+
}
153+
}
154+
}
115155
}
156+
157+
// --- Simple Prediction Model ---
158+
// Predicts next shipping activity based on the average of past activity,
159+
// plus an adjustment based on the average oil price.
160+
161+
if shipping_activities.is_empty() {
162+
return Ok(PropagatingEffect::Numerical(100.0)); // Baseline prediction
163+
}
164+
165+
let avg_shipping: f64 =
166+
shipping_activities.iter().sum::<f64>() / shipping_activities.len() as f64;
167+
168+
let mut oil_price_effect = 0.0;
169+
if !oil_prices.is_empty() {
170+
let avg_oil = oil_prices.iter().sum::<f64>() / oil_prices.len() as f64;
171+
// Simple model: higher avg oil price slightly decreases the next shipping activity value.
172+
// The numbers are chosen to make the factual error smaller.
173+
oil_price_effect = (avg_oil - 50.0) * 0.5; // 50 is a baseline oil price
174+
}
175+
176+
// Predict the next value by taking the average and adding a trend factor,
177+
// adjusted by the oil price effect.
178+
let prediction = avg_shipping + 3.0 - oil_price_effect;
179+
180+
Ok(PropagatingEffect::Numerical(prediction))
116181
}
117182

118183
// Helper functions
119184

185+
/// Creates the factual context containing all historical data.
120186
fn get_context_with_data() -> BaseContext {
121187
let mut context = BaseContext::with_capacity(1, "Factual Context", 20);
188+
let mut id_counter = 0;
122189

123190
// Sample Data (Quarterly)
124-
// Oil Prices: Q1=50, Q2=52, Q3=55, Q4=58
125-
// Shipping Activity: Q1=100, Q2=102, Q3=105, Q4=108
126191
let data_points = vec![
127192
(0.0, 50.0, 100.0), // Q1: time, oil_price, shipping_activity
128193
(1.0, 52.0, 102.0), // Q2
@@ -131,100 +196,51 @@ fn get_context_with_data() -> BaseContext {
131196
];
132197

133198
for (time, oil_price, shipping_activity) in data_points {
199+
// Each contextoid needs a unique ID for the context graph.
200+
// The ID within the Data payload is used to identify the data type.
201+
202+
// Time data
134203
let time_datoid =
135-
Contextoid::new(TIME_ID, ContextoidType::Datoid(Data::new(TIME_ID, time)));
204+
Contextoid::new(id_counter, ContextoidType::Datoid(Data::new(TIME_ID, time)));
205+
context.add_node(time_datoid).unwrap();
206+
id_counter += 1;
207+
208+
// Oil price data
136209
let oil_price_datoid = Contextoid::new(
137-
OIL_PRICE_ID,
210+
id_counter,
138211
ContextoidType::Datoid(Data::new(OIL_PRICE_ID, oil_price)),
139212
);
213+
context.add_node(oil_price_datoid).unwrap();
214+
id_counter += 1;
215+
216+
// Shipping activity data
140217
let shipping_activity_datoid = Contextoid::new(
141-
SHIPPING_ACTIVITY_ID,
218+
id_counter,
142219
ContextoidType::Datoid(Data::new(SHIPPING_ACTIVITY_ID, shipping_activity)),
143220
);
144-
145-
context.add_node(time_datoid).unwrap();
146-
context.add_node(oil_price_datoid).unwrap();
147221
context.add_node(shipping_activity_datoid).unwrap();
222+
id_counter += 1;
148223
}
149224
context
150225
}
151226

227+
/// Creates the counterfactual context by cloning the factual one and removing oil price data.
152228
fn get_counterfactual_context(factual_context: &BaseContext) -> BaseContext {
153-
let mut control_context = factual_context.clone();
154-
155-
// Remove or zero out oil_price dataoids in the cloned context
156-
// Iterate through the nodes and update the oil_price datoids
157-
// Note: This is a simplified approach. In a real scenario, you might remove the nodes or set them to a specific baseline.
158-
for i in 0..control_context.number_of_nodes() {
159-
let node = control_context.get_node(i).unwrap();
160-
if let ContextoidType::Datoid(data_node) = node.vertex_type() {
161-
if data_node.id() == OIL_PRICE_ID {
162-
let mut updated_data = data_node.clone();
163-
updated_data.set_data(0.0); // Set oil price to 0.0 in counterfactual
164-
control_context
165-
.update_node(
166-
data_node.id(),
167-
Contextoid::new(data_node.id(), ContextoidType::Datoid(updated_data)),
168-
)
169-
.unwrap();
229+
let mut control_context = BaseContext::with_capacity(2, "Counterfactual Context", 20);
230+
231+
// Iterate through the factual context and add all nodes EXCEPT oil price nodes.
232+
for i in 0..factual_context.number_of_nodes() {
233+
if let Some(node) = factual_context.get_node(i) {
234+
let mut should_add = true;
235+
if let ContextoidType::Datoid(data_node) = node.vertex_type() {
236+
if data_node.id() == OIL_PRICE_ID {
237+
should_add = false;
238+
}
239+
}
240+
if should_add {
241+
control_context.add_node(node.clone()).unwrap();
170242
}
171243
}
172244
}
173245
control_context
174246
}
175-
176-
fn get_shipping_predictor_causaloid() -> BaseCausaloid {
177-
let predictor_id = PREDICTOR_CAUSALOID_ID;
178-
let predictor_description = "Predicts shipping activity based on historical data";
179-
180-
let causal_fn = |effect: &PropagatingEffect| -> Result<PropagatingEffect, CausalityError> {
181-
let current_time_step = match effect {
182-
PropagatingEffect::Map(map) => map
183-
.get(&TIME_ID)
184-
.and_then(|boxed_effect| boxed_effect.as_numerical())
185-
.ok_or_else(|| {
186-
CausalityError("Current time step not found in effect map".into())
187-
})?,
188-
_ => {
189-
return Err(CausalityError(
190-
"Expected Map effect for predictor causaloid".into(),
191-
));
192-
}
193-
};
194-
195-
// In a real scenario, this causaloid would query the context it's associated with
196-
// to get historical data. Since causal_fn cannot capture context directly, we simulate
197-
// context lookup by assuming the context is available via the Causaloid's own context field.
198-
// This requires the Causaloid to be initialized with a context.
199-
// For this example, we'll use a simplified model that assumes access to the context.
200-
201-
// Simulate context access and prediction logic
202-
// This is a placeholder for a more complex predictive model (e.g., linear regression)
203-
// For simplicity, we'll assume a direct lookup or a very simple model.
204-
// In a real DBN, the causaloid would query the context for historical data.
205-
// Here, we'll hardcode some logic based on the time step and assumed context data.
206-
207-
let predicted_shipping_activity = match current_time_step as u64 {
208-
4 => {
209-
// Predicting for Q5, based on Q1-Q4
210-
// This is where the causaloid would query the context for historical data
211-
// For demonstration, we'll use a simple rule based on assumed historical data
212-
// If oil price data was available (not 0.0 in the context), it would influence this.
213-
// Since we can't access the context directly here, we'll make a simplified assumption.
214-
// If oil price was present (simulated by non-zero value), predict higher.
215-
// This part is highly simplified and would be replaced by actual model inference.
216-
let assumed_oil_price_present = true; // This would come from context query
217-
if assumed_oil_price_present {
218-
105.0 + 3.0
219-
} else {
220-
105.0
221-
}
222-
}
223-
_ => 0.0, // Default for other time steps
224-
};
225-
226-
Ok(PropagatingEffect::Numerical(predicted_shipping_activity))
227-
};
228-
229-
BaseCausaloid::new(predictor_id, causal_fn, predictor_description)
230-
}

0 commit comments

Comments
 (0)