Skip to content

Commit 030859d

Browse files
committed
reworked evaluate_probabilistic_propagation and evaluate_mixed_propagation in causable_reasoning.rs
Fixed downstream tests. Signed-off-by: Marvin Hansen <[email protected]>
1 parent 56fe027 commit 030859d

File tree

8 files changed

+108
-102
lines changed

8 files changed

+108
-102
lines changed

deep_causality/src/traits/causable/causable_reasoning.rs

Lines changed: 31 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -121,79 +121,61 @@ where
121121
}
122122

123123
// Convert final probability to a deterministic outcome based on a threshold.
124-
if cumulative_prob > 0.5 {
125-
Ok(PropagatingEffect::Deterministic(true))
126-
} else {
127-
Ok(PropagatingEffect::Deterministic(false))
128-
}
124+
Ok(PropagatingEffect::Probabilistic(cumulative_prob))
129125
}
130126

131127
/// Evaluates a linear chain of causes that may contain a mix of deterministic and
132-
/// probabilistic effects, aggregating them into a final effect.
128+
/// probabilistic effects, aggregating them into a final deterministic outcome.
129+
///
130+
/// This method converts all effects (`Deterministic`, `Probabilistic`, `Numerical`)
131+
/// into a numerical value (where true=1.0, false=0.0) and aggregates them by
132+
/// multiplication. The final cumulative probability is then compared against a
133+
/// threshold (0.5) to produce a final `Deterministic(true)` or `Deterministic(false)`.
134+
///
135+
/// This approach is robust, order-independent, and provides a consistent result.
133136
///
134137
/// # Arguments
135138
/// * `effect` - A single `PropagatingEffect` object that all causes will use.
136139
///
137140
/// # Errors
138-
/// Returns a `CausalityError` if a `ContextualLink` is encountered.
141+
/// Returns a `CausalityError` if a `ContextualLink` is encountered, as it cannot be
142+
/// converted to a numerical probability.
139143
fn evaluate_mixed_propagation(
140144
&self,
141145
effect: &PropagatingEffect,
142146
_logic: &AggregateLogic,
143147
) -> Result<PropagatingEffect, CausalityError> {
144-
// The chain starts as deterministically true. It can transition to probabilistic.
145-
let mut aggregated_effect = PropagatingEffect::Deterministic(true);
148+
// Start with 1.0, the multiplicative identity, to aggregate all effects numerically.
149+
let mut cumulative_prob: NumericalValue = 1.0;
146150

147151
for cause in self.get_all_items() {
148152
let current_effect = cause.evaluate(effect)?;
149153

150-
// Update the aggregated effect based on the current effect.
151-
aggregated_effect = match (aggregated_effect, current_effect) {
152-
// Deterministic false breaks the chain.
153-
(_, PropagatingEffect::Deterministic(false)) => {
154-
return Ok(PropagatingEffect::Deterministic(false));
155-
}
156-
157-
// ContextualLink is invalid in this context.
158-
(_, PropagatingEffect::ContextualLink(_, _)) => {
159-
return Err(CausalityError(
160-
"Encountered a ContextualLink in a mixed-chain evaluation.".into(),
161-
));
162-
}
163-
164-
// If the chain is deterministic and the new effect is true, it remains deterministic true.
165-
(
166-
PropagatingEffect::Deterministic(true),
167-
PropagatingEffect::Deterministic(true),
168-
) => PropagatingEffect::Deterministic(true),
169-
170-
// If the chain is deterministic and the new effect is probabilistic, the chain becomes probabilistic.
171-
(PropagatingEffect::Deterministic(true), PropagatingEffect::Probabilistic(p)) => {
172-
PropagatingEffect::Probabilistic(p)
173-
}
174-
175-
// If the chain is already probabilistic and the new effect is true, the probability is unchanged.
176-
(PropagatingEffect::Probabilistic(p), PropagatingEffect::Deterministic(true)) => {
177-
PropagatingEffect::Probabilistic(p)
178-
}
179-
180-
// If the chain is probabilistic and the new effect is also probabilistic, multiply them.
181-
(PropagatingEffect::Probabilistic(p1), PropagatingEffect::Probabilistic(p2)) => {
182-
PropagatingEffect::Probabilistic(p1 * p2)
183-
}
184-
185-
// Other combinations should not be possible due to the guards above.
186-
(agg, curr) => {
154+
// Convert every effect to a numerical probability to ensure consistent, order-independent aggregation.
155+
let current_prob = match current_effect {
156+
PropagatingEffect::Deterministic(true) => 1.0,
157+
PropagatingEffect::Deterministic(false) => 0.0,
158+
PropagatingEffect::Probabilistic(p) => p,
159+
PropagatingEffect::Numerical(p) => p,
160+
// .
161+
_ => {
162+
// Other variants are not handled in this mode.
187163
return Err(CausalityError(format!(
188-
"Unhandled effect combination in mixed chain: Agg: {agg:?}, Curr: {curr:?}"
164+
"evaluate_mixed_propagation encountered an unsupported effect: {effect:?}. Only probabilistic, deterministic, or numerical effects are allowed."
189165
)));
190166
}
191167
};
168+
169+
cumulative_prob *= current_prob;
192170
}
193171

194-
Ok(aggregated_effect)
172+
// Convert the final aggregated probability to a deterministic outcome based on a standard threshold.
173+
if cumulative_prob > 0.5 {
174+
Ok(PropagatingEffect::Deterministic(true))
175+
} else {
176+
Ok(PropagatingEffect::Deterministic(false))
177+
}
195178
}
196-
197179
/// Generates an explanation by concatenating the `explain()` text of all causes.
198180
///
199181
/// Each explanation is formatted and separated by newlines.

deep_causality/src/utils_test/test_utils.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,18 @@ pub fn get_test_inf_vec() -> Vec<Inference> {
3434
Vec::from_iter([i1, i2])
3535
}
3636

37-
pub fn get_test_causality_vec() -> BaseCausaloidVec {
37+
pub fn get_deterministic_test_causality_vec() -> BaseCausaloidVec {
3838
let q1 = get_test_causaloid_deterministic();
3939
let q2 = get_test_causaloid_deterministic();
4040
let q3 = get_test_causaloid_deterministic();
4141
Vec::from_iter([q1, q2, q3])
4242
}
43+
pub fn get_probabilistic_test_causality_vec() -> BaseCausaloidVec {
44+
let q1 = get_test_causaloid_probabilistic();
45+
let q2 = get_test_causaloid_probabilistic();
46+
let q3 = get_test_causaloid_probabilistic();
47+
Vec::from_iter([q1, q2, q3])
48+
}
4349

4450
pub fn get_test_single_data(val: NumericalValue) -> PropagatingEffect {
4551
PropagatingEffect::Numerical(val)
@@ -87,7 +93,7 @@ pub fn get_test_causaloid_probabilistic() -> BaseCausaloid {
8793

8894
// If it's the Probabilistic, extract the inner value.
8995
PropagatingEffect::Probabilistic(val) => *val,
90-
96+
9197
// For any other type of effect, this function cannot proceed, so return an error.
9298
_ => return Err(CausalityError(
9399
"Causal function expected Numerical effect but received a different variant."

deep_causality/tests/extensions/causable/causable_arr_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub fn get_test_causality_array_mixed() -> [BaseCausaloid; 20] {
2424

2525
// Combine a1 and a2
2626
a1.into_iter()
27-
.chain(a2.into_iter())
27+
.chain(a2)
2828
.collect::<Vec<_>>()
2929
.try_into()
3030
.unwrap()

deep_causality/tests/extensions/causable/causable_btree_map_tests.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,22 @@ use deep_causality::*;
1212
type TestBTreeMap = BTreeMap<i8, BaseCausaloid>;
1313

1414
// Helper function to create a standard test BTreeMap.
15-
fn get_test_causality_btree_map() -> TestBTreeMap {
15+
fn get_test_causality_btree_map_deterministic() -> TestBTreeMap {
1616
BTreeMap::from([
1717
(1, get_test_causaloid_deterministic()),
1818
(2, get_test_causaloid_deterministic()),
1919
(3, get_test_causaloid_deterministic()),
2020
])
2121
}
2222

23+
fn get_test_causality_btree_map_probabilistic() -> TestBTreeMap {
24+
BTreeMap::from([
25+
(1, get_test_causaloid_probabilistic()),
26+
(2, get_test_causaloid_probabilistic()),
27+
(3, get_test_causaloid_probabilistic()),
28+
])
29+
}
30+
2331
// Helper to activate all causes in a collection for testing purposes.
2432
fn activate_all_causes(map: &TestBTreeMap) {
2533
// A value that ensures the default test causaloid (threshold 0.55) becomes active.
@@ -32,7 +40,7 @@ fn activate_all_causes(map: &TestBTreeMap) {
3240

3341
#[test]
3442
fn test_add() {
35-
let mut map = get_test_causality_btree_map();
43+
let mut map = get_test_causality_btree_map_deterministic();
3644
assert_eq!(3, map.len());
3745

3846
let q = get_test_causaloid_deterministic();
@@ -42,7 +50,7 @@ fn test_add() {
4250

4351
#[test]
4452
fn test_contains() {
45-
let mut map = get_test_causality_btree_map();
53+
let mut map = get_test_causality_btree_map_deterministic();
4654
assert_eq!(3, map.len());
4755
assert!(map.contains_key(&1));
4856

@@ -54,7 +62,7 @@ fn test_contains() {
5462

5563
#[test]
5664
fn test_remove() {
57-
let mut map = get_test_causality_btree_map();
65+
let mut map = get_test_causality_btree_map_deterministic();
5866
assert_eq!(3, map.len());
5967
assert!(map.contains_key(&1));
6068

@@ -65,7 +73,7 @@ fn test_remove() {
6573

6674
#[test]
6775
fn test_get_all_items() {
68-
let col = get_test_causality_btree_map();
76+
let col = get_test_causality_btree_map_deterministic();
6977
let all_items = col.get_all_items();
7078

7179
let exp_len = col.len();
@@ -75,7 +83,7 @@ fn test_get_all_items() {
7583

7684
#[test]
7785
fn test_evaluate_deterministic_propagation() {
78-
let map = get_test_causality_btree_map();
86+
let map = get_test_causality_btree_map_deterministic();
7987

8088
// Case 1: All succeed, chain should be deterministically true.
8189
let effect_success = PropagatingEffect::Numerical(0.99);
@@ -94,7 +102,7 @@ fn test_evaluate_deterministic_propagation() {
94102

95103
#[test]
96104
fn test_evaluate_probabilistic_propagation() {
97-
let map = get_test_causality_btree_map();
105+
let map = get_test_causality_btree_map_probabilistic();
98106

99107
// Case 1: All succeed (Deterministic(true) is treated as probability 1.0).
100108
// The cumulative probability should be 1.0.
@@ -115,7 +123,7 @@ fn test_evaluate_probabilistic_propagation() {
115123

116124
#[test]
117125
fn test_evaluate_mixed_propagation() {
118-
let map = get_test_causality_btree_map();
126+
let map = get_test_causality_btree_map_deterministic();
119127

120128
// Case 1: All succeed, chain remains deterministically true.
121129
let effect_success = PropagatingEffect::Numerical(0.99);
@@ -134,7 +142,7 @@ fn test_evaluate_mixed_propagation() {
134142

135143
#[test]
136144
fn test_explain() {
137-
let map = get_test_causality_btree_map();
145+
let map = get_test_causality_btree_map_deterministic();
138146
activate_all_causes(&map);
139147

140148
let single_explanation = "Causaloid: 1 'tests whether data exceeds threshold of 0.55' evaluated to: PropagatingEffect::Deterministic(true)";
@@ -148,18 +156,18 @@ fn test_explain() {
148156

149157
#[test]
150158
fn test_len() {
151-
let map = get_test_causality_btree_map();
159+
let map = get_test_causality_btree_map_deterministic();
152160
assert_eq!(3, map.len());
153161
}
154162

155163
#[test]
156164
fn test_is_empty() {
157-
let map = get_test_causality_btree_map();
165+
let map = get_test_causality_btree_map_deterministic();
158166
assert!(!map.is_empty());
159167
}
160168

161169
#[test]
162170
fn test_to_vec() {
163-
let map = get_test_causality_btree_map();
171+
let map = get_test_causality_btree_map_deterministic();
164172
assert_eq!(3, map.to_vec().len());
165173
}

deep_causality/tests/extensions/causable/causable_map_tests.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,22 @@ use deep_causality::*;
1212
type TestHashMap = HashMap<i8, BaseCausaloid>;
1313

1414
// Helper function to create a standard test HashMap.
15-
fn get_test_causality_map() -> TestHashMap {
15+
fn get_deterministic_test_causality_map() -> TestHashMap {
1616
HashMap::from([
1717
(1, get_test_causaloid_deterministic()),
1818
(2, get_test_causaloid_deterministic()),
1919
(3, get_test_causaloid_deterministic()),
2020
])
2121
}
2222

23+
fn get_probabilistic_test_causality_map() -> TestHashMap {
24+
HashMap::from([
25+
(1, get_test_causaloid_probabilistic()),
26+
(2, get_test_causaloid_probabilistic()),
27+
(3, get_test_causaloid_probabilistic()),
28+
])
29+
}
30+
2331
fn get_mixed_test_causality_map() -> TestHashMap {
2432
HashMap::from([
2533
(1, get_test_causaloid_deterministic_true()),
@@ -48,7 +56,7 @@ fn activate_all_causes(map: &TestHashMap) {
4856

4957
#[test]
5058
fn test_add() {
51-
let mut map = get_test_causality_map();
59+
let mut map = get_deterministic_test_causality_map();
5260
assert_eq!(3, map.len());
5361

5462
let q = get_test_causaloid_deterministic();
@@ -58,7 +66,7 @@ fn test_add() {
5866

5967
#[test]
6068
fn test_contains() {
61-
let mut map = get_test_causality_map();
69+
let mut map = get_deterministic_test_causality_map();
6270
assert_eq!(3, map.len());
6371
assert!(map.contains_key(&1));
6472

@@ -70,7 +78,7 @@ fn test_contains() {
7078

7179
#[test]
7280
fn test_remove() {
73-
let mut map = get_test_causality_map();
81+
let mut map = get_deterministic_test_causality_map();
7482
assert_eq!(3, map.len());
7583
assert!(map.contains_key(&1));
7684

@@ -81,7 +89,7 @@ fn test_remove() {
8189

8290
#[test]
8391
fn test_get_all_items() {
84-
let col = get_test_causality_map();
92+
let col = get_deterministic_test_causality_map();
8593
let all_items = col.get_all_items();
8694

8795
let exp_len = col.len();
@@ -91,7 +99,7 @@ fn test_get_all_items() {
9199

92100
#[test]
93101
fn test_evaluate_deterministic_propagation() {
94-
let map = get_test_causality_map();
102+
let map = get_deterministic_test_causality_map();
95103

96104
// Case 1: All succeed, chain should be deterministically true.
97105
let effect_success = PropagatingEffect::Numerical(0.99);
@@ -110,7 +118,7 @@ fn test_evaluate_deterministic_propagation() {
110118

111119
#[test]
112120
fn test_evaluate_probabilistic_propagation() {
113-
let map = get_test_causality_map();
121+
let map = get_probabilistic_test_causality_map();
114122

115123
// Case 1: All succeed (Deterministic(true) is treated as probability 1.0).
116124
// The cumulative probability should be 1.0.
@@ -138,8 +146,8 @@ fn test_evaluate_mixed_propagation() {
138146
let res_success = map
139147
.evaluate_mixed_propagation(&effect_success, &AggregateLogic::All)
140148
.unwrap();
141-
// This is false b/c the AggregateLogic fails i.e. not all causes evaluate
142-
assert_eq!(res_success, PropagatingEffect::Probabilistic(0.0));
149+
// All mixed cased evaluate
150+
assert_eq!(res_success, PropagatingEffect::Deterministic(true));
143151
}
144152

145153
#[test]
@@ -154,7 +162,7 @@ fn test_evaluate_mixed_propagation_err() {
154162

155163
#[test]
156164
fn test_explain() {
157-
let map = get_test_causality_map();
165+
let map = get_deterministic_test_causality_map();
158166
activate_all_causes(&map);
159167

160168
let single_explanation = "Causaloid: 1 'tests whether data exceeds threshold of 0.55' evaluated to: PropagatingEffect::Deterministic(true)";
@@ -167,18 +175,18 @@ fn test_explain() {
167175

168176
#[test]
169177
fn test_len() {
170-
let map = get_test_causality_map();
178+
let map = get_deterministic_test_causality_map();
171179
assert_eq!(3, map.len());
172180
}
173181

174182
#[test]
175183
fn test_is_empty() {
176-
let map = get_test_causality_map();
184+
let map = get_deterministic_test_causality_map();
177185
assert!(!map.is_empty());
178186
}
179187

180188
#[test]
181189
fn test_to_vec() {
182-
let map = get_test_causality_map();
190+
let map = get_deterministic_test_causality_map();
183191
assert_eq!(3, map.to_vec().len());
184192
}

0 commit comments

Comments
 (0)