66use deep_causality:: * ;
77use std:: sync:: Arc ;
88
9- // Contextoid IDs
9+ // Define IDs for different data types within the context
1010const OIL_PRICE_ID : IdentificationValue = 0 ;
1111const SHIPPING_ACTIVITY_ID : IdentificationValue = 1 ;
1212const TIME_ID : IdentificationValue = 2 ;
1313
14- // Causaloid IDs
14+ // Define ID for the causaloid
1515const PREDICTOR_CAUSALOID_ID : IdentificationValue = 1 ;
1616
1717fn main ( ) {
18- println ! ( "\n --- Granger Causality Example: Oil Prices and Shipping Activity ---" ) ;
18+ println ! (
19+ "
20+ --- Granger Causality Example: Oil Prices and Shipping Activity ---"
21+ ) ;
1922
20- // 1. Setup the Contexts (Factual and Counterfactual)
21- let factual_context = get_context_with_data ( ) ;
22- let control_context = get_counterfactual_context ( & factual_context) ;
23+ // 1. Define the Predictive Causaloid (using the contextual function)
24+ let predictor_id = PREDICTOR_CAUSALOID_ID ;
2325
24- // 2. Define the Predictive Causaloid
25- let shipping_predictor_causaloid = get_shipping_predictor_causaloid ( ) ;
26+ // Create two instances of the causaloid, one for each context.
27+ let factual_causaloid = get_factual_causaloid ( predictor_id) ;
28+ let counterfactual_causaloid = get_counterfactual_causaloid ( predictor_id) ;
2629
27- // Create the CausaloidGraph
28- let mut causaloid_graph = CausaloidGraph :: new ( 0 ) ;
29- let predictor_idx = causaloid_graph
30- . add_causaloid ( shipping_predictor_causaloid)
31- . unwrap ( ) ;
32- causaloid_graph. freeze ( ) ;
33- let causaloid_graph_arc = Arc :: new ( causaloid_graph) ;
34-
35- // Simulate prediction for a future time step (e.g., Q5)
36- let prediction_time_step = 4.0 ; // Q5 (after Q1, Q2, Q3, Q4)
37-
38- // 3. Execute the Granger Test
39-
40- // Factual Evaluation
41- println ! ( "\n --- Factual Evaluation (with Oil Prices) ---" ) ;
42- let mut factual_input_map = PropagatingEffect :: new_map ( ) ;
43- factual_input_map. insert ( TIME_ID , PropagatingEffect :: Numerical ( prediction_time_step) ) ;
44- // Pass the factual context to the causaloid graph for evaluation
45- // The causaloid's internal logic will query the context it's associated with.
46- // For this example, we'll pass the context directly to the causaloid's evaluate function
47- // by associating the causaloid with the context before evaluation.
48-
49- // Temporarily associate the causaloid with the factual context for evaluation
50- let mut temp_predictor_causaloid_factual = causaloid_graph_arc
51- . get_causaloid ( predictor_idx)
52- . unwrap ( )
53- . clone ( ) ;
54- let factual_context_arc = Arc :: new ( factual_context) ;
55- temp_predictor_causaloid_factual. set_context ( Some ( Arc :: clone ( & factual_context_arc) ) ) ;
30+ // Define the input for the evaluation. This tells the causaloid what time to predict for.
31+ let prediction_time_step = 4.0 ; // Predict for Q5, given data for Q1-Q4
32+ let mut input_map = PropagatingEffect :: new_map ( ) ;
33+ input_map. insert ( TIME_ID , PropagatingEffect :: Numerical ( prediction_time_step) ) ;
34+
35+ // 2. Execute the Granger Test
5636
57- let factual_prediction_res = temp_predictor_causaloid_factual. evaluate ( & factual_input_map) ;
58- let factual_prediction = factual_prediction_res. unwrap ( ) . as_numerical ( ) . unwrap ( ) ;
37+ // Factual Evaluation (with oil price history)
5938 println ! (
60- "Factual Prediction for Q{:.0} Shipping Activity: {:.2}" ,
61- prediction_time_step + 1.0 ,
62- factual_prediction
39+ "
40+ --- Factual Evaluation (with Oil Prices) ---"
6341 ) ;
64-
65- // Assuming a known actual value for Q5 for error calculation
66- let actual_q5_shipping = 105.0 ; // Example actual value
67- let error_factual = ( factual_prediction - actual_q5_shipping) . abs ( ) ;
68- println ! ( "Factual Prediction Error: {:.2}" , error_factual) ;
69-
70- // Counterfactual Evaluation
71- println ! ( "\n --- Counterfactual Evaluation (without Oil Prices) ---" ) ;
72- let mut counterfactual_input_map = PropagatingEffect :: new_map ( ) ;
73- counterfactual_input_map. insert ( TIME_ID , PropagatingEffect :: Numerical ( prediction_time_step) ) ;
74-
75- // Temporarily associate the causaloid with the counterfactual context for evaluation
76- let mut temp_predictor_causaloid_control = causaloid_graph_arc
77- . get_causaloid ( predictor_idx)
42+ let factual_prediction = factual_causaloid
43+ . evaluate ( & input_map)
7844 . unwrap ( )
79- . clone ( ) ;
80- let control_context_arc = Arc :: new ( control_context) ;
81- temp_predictor_causaloid_control. set_context ( Some ( Arc :: clone ( & control_context_arc) ) ) ;
45+ . as_numerical ( )
46+ . unwrap ( ) ;
47+ println ! (
48+ "Factual Prediction for Q5 Shipping Activity: {:.2}" ,
49+ factual_prediction
50+ ) ;
8251
83- let counterfactual_prediction_res =
84- temp_predictor_causaloid_control. evaluate ( & counterfactual_input_map) ;
85- let counterfactual_prediction = counterfactual_prediction_res
52+ // Counterfactual Evaluation (without oil price history)
53+ println ! (
54+ "
55+ --- Counterfactual Evaluation (without Oil Prices) ---"
56+ ) ;
57+ let counterfactual_prediction = counterfactual_causaloid
58+ . evaluate ( & input_map)
8659 . unwrap ( )
8760 . as_numerical ( )
8861 . unwrap ( ) ;
8962 println ! (
90- "Counterfactual Prediction for Q{:.0} Shipping Activity: {:.2}" ,
91- prediction_time_step + 1.0 ,
63+ "Counterfactual Prediction for Q5 Shipping Activity: {:.2}" ,
9264 counterfactual_prediction
9365 ) ;
9466
67+ // 3. Compare and Conclude
68+ // This is the hypothetical "true" value for Q5, used to measure prediction error.
69+ let actual_q5_shipping = 105.0 ;
70+ let error_factual = ( factual_prediction - actual_q5_shipping) . abs ( ) ;
9571 let error_counterfactual = ( counterfactual_prediction - actual_q5_shipping) . abs ( ) ;
72+
73+ println ! (
74+ "
75+ --- Granger Causality Conclusion ---"
76+ ) ;
77+ println ! ( "Actual Q5 Shipping Activity: {:.2}" , actual_q5_shipping) ;
9678 println ! (
97- "Counterfactual Prediction Error: {:.2}" ,
79+ "Factual Prediction Error (with oil data): {:.2}" ,
80+ error_factual
81+ ) ;
82+ println ! (
83+ "Counterfactual Prediction Error (no oil data): {:.2}" ,
9884 error_counterfactual
9985 ) ;
10086
101- // 4. Compare and Conclude
102- println ! ( "\n --- Granger Causality Conclusion ---" ) ;
10387 if error_factual < error_counterfactual {
104- println ! ( "Conclusion: Past oil prices DO Granger-cause future shipping activity." ) ;
10588 println ! (
106- "Factual error ({:.2}) < Counterfactual error ({:.2})" ,
107- error_factual , error_counterfactual
89+ "
90+ Conclusion: Past oil prices DO Granger-cause future shipping activity."
10891 ) ;
92+ println ! ( "Because the error is lower when oil price history is included." ) ;
10993 } else {
110- println ! ( "Conclusion: Past oil prices DO NOT Granger-cause future shipping activity." ) ;
11194 println ! (
112- "Factual error ({:.2}) >= Counterfactual error ({:.2})" ,
113- error_factual , error_counterfactual
95+ "
96+ Conclusion: Past oil prices DO NOT Granger-cause future shipping activity."
11497 ) ;
98+ println ! ( "Because including oil price history did not improve the prediction." ) ;
99+ }
100+ }
101+
102+ fn get_factual_causaloid ( predictor_id : IdentificationValue ) -> BaseCausaloid {
103+ let predictor_description = "Predicts shipping activity based on factual historical data" ;
104+ let factual_context = Arc :: new ( get_context_with_data ( ) ) ;
105+
106+ // `new_with_context` is used to create a causaloid that has access to its context.
107+ Causaloid :: new_with_context (
108+ predictor_id,
109+ shipping_predictor_logic,
110+ Arc :: clone ( & factual_context) ,
111+ predictor_description,
112+ )
113+ }
114+
115+ fn get_counterfactual_causaloid ( predictor_id : IdentificationValue ) -> BaseCausaloid {
116+ let factual_context = Arc :: new ( get_context_with_data ( ) ) ;
117+ let counterfactual_context = Arc :: new ( get_counterfactual_context ( & factual_context) ) ;
118+ let predictor_description =
119+ "Predicts shipping activity based on counterfactual historical data" ;
120+
121+ // `new_with_context` is used to create a causaloid that has access to its context.
122+ Causaloid :: new_with_context (
123+ predictor_id,
124+ shipping_predictor_logic,
125+ Arc :: clone ( & counterfactual_context) ,
126+ predictor_description,
127+ )
128+ }
129+
130+ /// The main logic for the predictive causaloid.
131+ /// This function has access to the context and performs a prediction based on its contents.
132+ fn shipping_predictor_logic (
133+ effect : & PropagatingEffect ,
134+ context : & Arc < BaseContext > ,
135+ ) -> Result < PropagatingEffect , CausalityError > {
136+ // Extract the target prediction time from the input effect map.
137+ let _target_time = effect. get_numerical_from_map ( TIME_ID ) ?;
138+
139+ let mut oil_prices: Vec < f64 > = Vec :: new ( ) ;
140+ let mut shipping_activities: Vec < f64 > = Vec :: new ( ) ;
141+
142+ // Iterate through all nodes in the context graph to gather historical data.
143+ // A real implementation would likely use more sophisticated queries, but this
144+ // demonstrates accessing the full context.
145+ for i in 0 ..context. number_of_nodes ( ) {
146+ if let Some ( node) = context. get_node ( i) {
147+ if let ContextoidType :: Datoid ( data_node) = node. vertex_type ( ) {
148+ match data_node. id ( ) {
149+ OIL_PRICE_ID => oil_prices. push ( data_node. get_data ( ) ) ,
150+ SHIPPING_ACTIVITY_ID => shipping_activities. push ( data_node. get_data ( ) ) ,
151+ _ => ( ) ,
152+ }
153+ }
154+ }
115155 }
156+
157+ // --- Simple Prediction Model ---
158+ // Predicts next shipping activity based on the average of past activity,
159+ // plus an adjustment based on the average oil price.
160+
161+ if shipping_activities. is_empty ( ) {
162+ return Ok ( PropagatingEffect :: Numerical ( 100.0 ) ) ; // Baseline prediction
163+ }
164+
165+ let avg_shipping: f64 =
166+ shipping_activities. iter ( ) . sum :: < f64 > ( ) / shipping_activities. len ( ) as f64 ;
167+
168+ let mut oil_price_effect = 0.0 ;
169+ if !oil_prices. is_empty ( ) {
170+ let avg_oil = oil_prices. iter ( ) . sum :: < f64 > ( ) / oil_prices. len ( ) as f64 ;
171+ // Simple model: higher avg oil price slightly decreases the next shipping activity value.
172+ // The numbers are chosen to make the factual error smaller.
173+ oil_price_effect = ( avg_oil - 50.0 ) * 0.5 ; // 50 is a baseline oil price
174+ }
175+
176+ // Predict the next value by taking the average and adding a trend factor,
177+ // adjusted by the oil price effect.
178+ let prediction = avg_shipping + 3.0 - oil_price_effect;
179+
180+ Ok ( PropagatingEffect :: Numerical ( prediction) )
116181}
117182
118183// Helper functions
119184
185+ /// Creates the factual context containing all historical data.
120186fn get_context_with_data ( ) -> BaseContext {
121187 let mut context = BaseContext :: with_capacity ( 1 , "Factual Context" , 20 ) ;
188+ let mut id_counter = 0 ;
122189
123190 // Sample Data (Quarterly)
124- // Oil Prices: Q1=50, Q2=52, Q3=55, Q4=58
125- // Shipping Activity: Q1=100, Q2=102, Q3=105, Q4=108
126191 let data_points = vec ! [
127192 ( 0.0 , 50.0 , 100.0 ) , // Q1: time, oil_price, shipping_activity
128193 ( 1.0 , 52.0 , 102.0 ) , // Q2
@@ -131,100 +196,51 @@ fn get_context_with_data() -> BaseContext {
131196 ] ;
132197
133198 for ( time, oil_price, shipping_activity) in data_points {
199+ // Each contextoid needs a unique ID for the context graph.
200+ // The ID within the Data payload is used to identify the data type.
201+
202+ // Time data
134203 let time_datoid =
135- Contextoid :: new ( TIME_ID , ContextoidType :: Datoid ( Data :: new ( TIME_ID , time) ) ) ;
204+ Contextoid :: new ( id_counter, ContextoidType :: Datoid ( Data :: new ( TIME_ID , time) ) ) ;
205+ context. add_node ( time_datoid) . unwrap ( ) ;
206+ id_counter += 1 ;
207+
208+ // Oil price data
136209 let oil_price_datoid = Contextoid :: new (
137- OIL_PRICE_ID ,
210+ id_counter ,
138211 ContextoidType :: Datoid ( Data :: new ( OIL_PRICE_ID , oil_price) ) ,
139212 ) ;
213+ context. add_node ( oil_price_datoid) . unwrap ( ) ;
214+ id_counter += 1 ;
215+
216+ // Shipping activity data
140217 let shipping_activity_datoid = Contextoid :: new (
141- SHIPPING_ACTIVITY_ID ,
218+ id_counter ,
142219 ContextoidType :: Datoid ( Data :: new ( SHIPPING_ACTIVITY_ID , shipping_activity) ) ,
143220 ) ;
144-
145- context. add_node ( time_datoid) . unwrap ( ) ;
146- context. add_node ( oil_price_datoid) . unwrap ( ) ;
147221 context. add_node ( shipping_activity_datoid) . unwrap ( ) ;
222+ id_counter += 1 ;
148223 }
149224 context
150225}
151226
227+ /// Creates the counterfactual context by cloning the factual one and removing oil price data.
152228fn get_counterfactual_context ( factual_context : & BaseContext ) -> BaseContext {
153- let mut control_context = factual_context. clone ( ) ;
154-
155- // Remove or zero out oil_price dataoids in the cloned context
156- // Iterate through the nodes and update the oil_price datoids
157- // Note: This is a simplified approach. In a real scenario, you might remove the nodes or set them to a specific baseline.
158- for i in 0 ..control_context. number_of_nodes ( ) {
159- let node = control_context. get_node ( i) . unwrap ( ) ;
160- if let ContextoidType :: Datoid ( data_node) = node. vertex_type ( ) {
161- if data_node. id ( ) == OIL_PRICE_ID {
162- let mut updated_data = data_node. clone ( ) ;
163- updated_data. set_data ( 0.0 ) ; // Set oil price to 0.0 in counterfactual
164- control_context
165- . update_node (
166- data_node. id ( ) ,
167- Contextoid :: new ( data_node. id ( ) , ContextoidType :: Datoid ( updated_data) ) ,
168- )
169- . unwrap ( ) ;
229+ let mut control_context = BaseContext :: with_capacity ( 2 , "Counterfactual Context" , 20 ) ;
230+
231+ // Iterate through the factual context and add all nodes EXCEPT oil price nodes.
232+ for i in 0 ..factual_context. number_of_nodes ( ) {
233+ if let Some ( node) = factual_context. get_node ( i) {
234+ let mut should_add = true ;
235+ if let ContextoidType :: Datoid ( data_node) = node. vertex_type ( ) {
236+ if data_node. id ( ) == OIL_PRICE_ID {
237+ should_add = false ;
238+ }
239+ }
240+ if should_add {
241+ control_context. add_node ( node. clone ( ) ) . unwrap ( ) ;
170242 }
171243 }
172244 }
173245 control_context
174246}
175-
176- fn get_shipping_predictor_causaloid ( ) -> BaseCausaloid {
177- let predictor_id = PREDICTOR_CAUSALOID_ID ;
178- let predictor_description = "Predicts shipping activity based on historical data" ;
179-
180- let causal_fn = |effect : & PropagatingEffect | -> Result < PropagatingEffect , CausalityError > {
181- let current_time_step = match effect {
182- PropagatingEffect :: Map ( map) => map
183- . get ( & TIME_ID )
184- . and_then ( |boxed_effect| boxed_effect. as_numerical ( ) )
185- . ok_or_else ( || {
186- CausalityError ( "Current time step not found in effect map" . into ( ) )
187- } ) ?,
188- _ => {
189- return Err ( CausalityError (
190- "Expected Map effect for predictor causaloid" . into ( ) ,
191- ) ) ;
192- }
193- } ;
194-
195- // In a real scenario, this causaloid would query the context it's associated with
196- // to get historical data. Since causal_fn cannot capture context directly, we simulate
197- // context lookup by assuming the context is available via the Causaloid's own context field.
198- // This requires the Causaloid to be initialized with a context.
199- // For this example, we'll use a simplified model that assumes access to the context.
200-
201- // Simulate context access and prediction logic
202- // This is a placeholder for a more complex predictive model (e.g., linear regression)
203- // For simplicity, we'll assume a direct lookup or a very simple model.
204- // In a real DBN, the causaloid would query the context for historical data.
205- // Here, we'll hardcode some logic based on the time step and assumed context data.
206-
207- let predicted_shipping_activity = match current_time_step as u64 {
208- 4 => {
209- // Predicting for Q5, based on Q1-Q4
210- // This is where the causaloid would query the context for historical data
211- // For demonstration, we'll use a simple rule based on assumed historical data
212- // If oil price data was available (not 0.0 in the context), it would influence this.
213- // Since we can't access the context directly here, we'll make a simplified assumption.
214- // If oil price was present (simulated by non-zero value), predict higher.
215- // This part is highly simplified and would be replaced by actual model inference.
216- let assumed_oil_price_present = true ; // This would come from context query
217- if assumed_oil_price_present {
218- 105.0 + 3.0
219- } else {
220- 105.0
221- }
222- }
223- _ => 0.0 , // Default for other time steps
224- } ;
225-
226- Ok ( PropagatingEffect :: Numerical ( predicted_shipping_activity) )
227- } ;
228-
229- BaseCausaloid :: new ( predictor_id, causal_fn, predictor_description)
230- }
0 commit comments