Skip to content

Commit 5938b5e

Browse files
committed
Added CATE example.
Signed-off-by: Marvin Hansen <[email protected]>
1 parent 1e02eff commit 5938b5e

File tree

3 files changed

+220
-0
lines changed

3 files changed

+220
-0
lines changed

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/epp_cate/Cargo.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
3+
4+
[package]
5+
name = "example-cate"
6+
version = "0.1.0"
7+
edition = "2021"
8+
rust-version = "1.80"
9+
publish = false
10+
11+
[dependencies]
12+
deep_causality = { path = "../../deep_causality" }

examples/epp_cate/src/main.rs

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/*
2+
* SPDX-License-Identifier: MIT
3+
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
4+
*/
5+
6+
use deep_causality::*;
7+
use std::sync::Arc;
8+
9+
// Define IDs for different data types within the context
10+
const AGE_ID: IdentificationValue = 1;
11+
const INITIAL_BP_ID: IdentificationValue = 2;
12+
const DRUG_ADMINISTERED_ID: IdentificationValue = 3;
13+
14+
// Define ID for the causaloid
15+
const DRUG_EFFECT_CAUSALOID_ID: IdentificationValue = 10;
16+
17+
fn main() {
18+
println!("\n--- CATE Example: Effect of Medication on Blood Pressure for Patients > 65 ---");
19+
20+
// 1. Define the population of patients
21+
let patient_population = create_patient_population();
22+
println!(
23+
"Created a population of {} patients.",
24+
patient_population.len()
25+
);
26+
27+
// 2. Select the subgroup of interest (patients over 65)
28+
let subgroup: Vec<&BaseContext> = patient_population
29+
.iter()
30+
.filter(|ctx| {
31+
for i in 0..ctx.number_of_nodes() {
32+
if let Some(node) = ctx.get_node(i) {
33+
if let ContextoidType::Datoid(data_node) = node.vertex_type() {
34+
if data_node.id() == AGE_ID && data_node.get_data() > 65.0 {
35+
return true;
36+
}
37+
}
38+
}
39+
}
40+
false
41+
})
42+
.collect();
43+
println!(
44+
"Found {} patients in the subgroup (age > 65).",
45+
subgroup.len()
46+
);
47+
48+
// 3. Run parallel counterfactuals for the subgroup
49+
let mut ites: Vec<f64> = Vec::new(); // To store Individual Treatment Effects
50+
51+
for patient_context in subgroup {
52+
let initial_bp = get_patient_bp(patient_context).unwrap_or(140.0);
53+
54+
// --- Create Counterfactual Contexts ---
55+
let mut treatment_context = patient_context.clone();
56+
let drug_datoid = Contextoid::new(
57+
DRUG_ADMINISTERED_ID,
58+
ContextoidType::Datoid(Data::new(DRUG_ADMINISTERED_ID, 1.0)), // drug_administered = true
59+
);
60+
treatment_context.add_node(drug_datoid).unwrap();
61+
62+
let mut control_context = patient_context.clone();
63+
let no_drug_datoid = Contextoid::new(
64+
DRUG_ADMINISTERED_ID,
65+
ContextoidType::Datoid(Data::new(DRUG_ADMINISTERED_ID, 0.0)), // drug_administered = false
66+
);
67+
control_context.add_node(no_drug_datoid).unwrap();
68+
69+
// --- Instantiate Causaloids for each scenario ---
70+
let treatment_causaloid = Causaloid::new_with_context(
71+
DRUG_EFFECT_CAUSALOID_ID,
72+
drug_effect_logic,
73+
Arc::new(treatment_context),
74+
"Drug effect under treatment",
75+
);
76+
77+
let control_causaloid = Causaloid::new_with_context(
78+
DRUG_EFFECT_CAUSALOID_ID,
79+
drug_effect_logic,
80+
Arc::new(control_context),
81+
"Drug effect under control",
82+
);
83+
84+
// --- Evaluate Potential Outcomes ---
85+
// The input effect is the patient's initial BP.
86+
let input_effect = PropagatingEffect::Numerical(initial_bp);
87+
88+
let y1_effect = treatment_causaloid
89+
.evaluate(&input_effect)
90+
.unwrap()
91+
.as_numerical()
92+
.unwrap();
93+
let y0_effect = control_causaloid
94+
.evaluate(&input_effect)
95+
.unwrap()
96+
.as_numerical()
97+
.unwrap();
98+
99+
let y1 = initial_bp + y1_effect; // Potential outcome if treated
100+
let y0 = initial_bp + y0_effect; // Potential outcome if not treated
101+
102+
// --- Calculate and Store ITE ---
103+
let ite = y1 - y0;
104+
ites.push(ite);
105+
}
106+
107+
// 4. Aggregate and Conclude
108+
if !ites.is_empty() {
109+
let cate: f64 = ites.iter().sum::<f64>() / ites.len() as f64;
110+
println!("\n--- CATE Calculation Result ---");
111+
println!(
112+
"The Conditional Average Treatment Effect (CATE) for patients over 65 is: {:.2}",
113+
cate
114+
);
115+
} else {
116+
println!("\nNo patients found in the subgroup to calculate CATE.");
117+
}
118+
}
119+
120+
/// The causal logic for the drug's effect.
121+
/// This function checks the context to see if the drug was administered and returns the effect on blood pressure.
122+
fn drug_effect_logic(
123+
_effect: &PropagatingEffect, // We don't need the incoming effect for this simple model
124+
context: &Arc<BaseContext>,
125+
) -> Result<PropagatingEffect, CausalityError> {
126+
let mut drug_administered = false;
127+
128+
// Search the context for the DRUG_ADMINISTERED_ID flag.
129+
for i in 0..context.number_of_nodes() {
130+
if let Some(node) = context.get_node(i) {
131+
if let ContextoidType::Datoid(data_node) = node.vertex_type() {
132+
if data_node.id() == DRUG_ADMINISTERED_ID && data_node.get_data() == 1.0 {
133+
drug_administered = true;
134+
break;
135+
}
136+
}
137+
}
138+
}
139+
140+
if drug_administered {
141+
// If the drug was given, it causes a 10-point drop in blood pressure.
142+
Ok(PropagatingEffect::Numerical(-10.0))
143+
} else {
144+
// If no drug was given, there is no effect.
145+
Ok(PropagatingEffect::Numerical(0.0))
146+
}
147+
}
148+
149+
/// Creates a sample population of patients with different ages and blood pressures.
150+
fn create_patient_population() -> Vec<BaseContext> {
151+
let mut population = Vec::new();
152+
let mut patient_id_counter = 1;
153+
154+
// Tuples of (age, initial_bp)
155+
let patient_data = vec![
156+
(55.0, 145.0),
157+
(70.0, 150.0),
158+
(68.0, 155.0),
159+
(45.0, 130.0),
160+
(80.0, 160.0),
161+
(72.0, 148.0),
162+
(60.0, 140.0),
163+
];
164+
165+
for (age, bp) in patient_data {
166+
let mut context = BaseContext::with_capacity(patient_id_counter, "Patient", 5);
167+
patient_id_counter += 1;
168+
169+
let age_datoid = Contextoid::new(
170+
patient_id_counter,
171+
ContextoidType::Datoid(Data::new(AGE_ID, age)),
172+
);
173+
context.add_node(age_datoid).unwrap();
174+
patient_id_counter += 1;
175+
176+
let bp_datoid = Contextoid::new(
177+
patient_id_counter,
178+
ContextoidType::Datoid(Data::new(INITIAL_BP_ID, bp)),
179+
);
180+
context.add_node(bp_datoid).unwrap();
181+
patient_id_counter += 1;
182+
183+
population.push(context);
184+
}
185+
186+
population
187+
}
188+
189+
/// Helper to extract the initial blood pressure from a patient's context.
190+
fn get_patient_bp(context: &BaseContext) -> Option<f64> {
191+
for i in 0..context.number_of_nodes() {
192+
if let Some(node) = context.get_node(i) {
193+
if let ContextoidType::Datoid(data_node) = node.vertex_type() {
194+
if data_node.id() == INITIAL_BP_ID {
195+
return Some(data_node.get_data());
196+
}
197+
}
198+
}
199+
}
200+
None
201+
}

0 commit comments

Comments
 (0)