Skip to content

Commit fb140d9

Browse files
committed
fix(deep_causality_uncertain): Added NodeID.
Signed-off-by: Marvin Hansen <[email protected]>
1 parent 4458b71 commit fb140d9

File tree

15 files changed

+472
-176
lines changed

15 files changed

+472
-176
lines changed

deep_causality_uncertain/src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
mod algos;
99
mod errors;
1010
mod traits;
11-
pub mod types;
11+
mod types;
1212

1313
// types
1414
pub use crate::algos::hypothesis::sprt_eval;
@@ -17,9 +17,8 @@ pub use crate::errors::UncertainError;
1717
// Traits
1818
pub use crate::traits::sampler::Sampler;
1919
pub use crate::types::cache::{GlobalSampleCache, SampledValue, with_global_cache};
20-
pub use crate::types::computation::{
21-
ArithmeticOperator, ComparisonOperator, ComputationNode, LogicalOperator,
22-
};
20+
pub use crate::types::computation::node::{ComputationNode, NodeId};
21+
pub use crate::types::computation::{ArithmeticOperator, ComparisonOperator, LogicalOperator};
2322
pub use crate::types::distribution::DistributionEnum;
2423
pub use crate::types::distribution_parameters::BernoulliParams;
2524
pub use crate::types::distribution_parameters::NormalDistributionParams;

deep_causality_uncertain/src/types/computation/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
pub mod node;
66
pub mod operator;
77

8-
pub use node::ComputationNode;
98
pub use operator::arithmetic_operator::ArithmeticOperator;
109
pub use operator::comparison_operator::ComparisonOperator;
1110
pub use operator::logical_operator::LogicalOperator;
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* SPDX-License-Identifier: MIT
3+
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
4+
*/
5+
6+
use crate::{ArithmeticOperator, ComparisonOperator, LogicalOperator, NodeId};
7+
use std::sync::Arc;
8+
use std::sync::atomic::AtomicUsize;
9+
10+
pub static NEXT_NODE_ID: AtomicUsize = AtomicUsize::new(0);
11+
12+
/// Represents a node in the computation graph. This is now a single, non-generic enum.
13+
#[derive(Clone)]
14+
pub enum ComputationNode {
15+
// Leaf nodes now contain the specific distribution type directly.
16+
LeafF64 {
17+
node_id: NodeId,
18+
dist: crate::DistributionEnum<f64>,
19+
},
20+
LeafBool {
21+
node_id: NodeId,
22+
dist: crate::DistributionEnum<bool>,
23+
},
24+
25+
ArithmeticOp {
26+
node_id: NodeId,
27+
op: ArithmeticOperator,
28+
lhs: Box<ComputationNode>,
29+
rhs: Box<ComputationNode>,
30+
},
31+
ComparisonOp {
32+
node_id: NodeId,
33+
op: ComparisonOperator,
34+
threshold: f64,
35+
operand: Box<ComputationNode>,
36+
},
37+
LogicalOp {
38+
node_id: NodeId,
39+
op: LogicalOperator,
40+
operands: Vec<Box<ComputationNode>>,
41+
},
42+
FunctionOp {
43+
node_id: NodeId,
44+
func: Arc<dyn Fn(f64) -> f64 + Send + Sync>,
45+
operand: Box<ComputationNode>,
46+
},
47+
FunctionOpBool {
48+
node_id: NodeId,
49+
func: Arc<dyn Fn(f64) -> bool + Send + Sync>,
50+
operand: Box<ComputationNode>,
51+
},
52+
NegationOp {
53+
node_id: NodeId,
54+
operand: Box<ComputationNode>,
55+
},
56+
ConditionalOp {
57+
node_id: NodeId,
58+
condition: Box<ComputationNode>,
59+
if_true: Box<ComputationNode>,
60+
if_false: Box<ComputationNode>,
61+
},
62+
}

deep_causality_uncertain/src/types/computation/node/mod.rs

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,8 @@
22
* SPDX-License-Identifier: MIT
33
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
44
*/
5-
use crate::ArithmeticOperator;
6-
use crate::ComparisonOperator;
7-
use crate::LogicalOperator;
8-
use std::sync::Arc;
5+
mod computation_node;
6+
mod node_id;
97

10-
/// Represents a node in the computation graph. This is now a single, non-generic enum.
11-
#[derive(Clone)]
12-
pub enum ComputationNode {
13-
// Leaf nodes now contain the specific distribution type directly.
14-
LeafF64(crate::DistributionEnum<f64>),
15-
LeafBool(crate::DistributionEnum<bool>),
16-
17-
ArithmeticOp {
18-
op: ArithmeticOperator,
19-
lhs: Box<ComputationNode>,
20-
rhs: Box<ComputationNode>,
21-
},
22-
ComparisonOp {
23-
op: ComparisonOperator,
24-
threshold: f64,
25-
operand: Box<ComputationNode>,
26-
},
27-
LogicalOp {
28-
op: LogicalOperator,
29-
operands: Vec<Box<ComputationNode>>,
30-
},
31-
FunctionOp {
32-
func: Arc<dyn Fn(f64) -> f64 + Send + Sync>,
33-
operand: Box<ComputationNode>,
34-
},
35-
FunctionOpBool {
36-
func: Arc<dyn Fn(f64) -> bool + Send + Sync>,
37-
operand: Box<ComputationNode>,
38-
},
39-
NegationOp {
40-
operand: Box<ComputationNode>,
41-
},
42-
ConditionalOp {
43-
condition: Box<ComputationNode>,
44-
if_true: Box<ComputationNode>,
45-
if_false: Box<ComputationNode>,
46-
},
47-
}
8+
pub use computation_node::ComputationNode;
9+
pub use node_id::NodeId;
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* SPDX-License-Identifier: MIT
3+
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
4+
*/
5+
6+
use crate::types::computation::node::computation_node::NEXT_NODE_ID;
7+
use std::sync::atomic::Ordering;
8+
9+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
10+
pub struct NodeId(usize);
11+
12+
impl NodeId {
13+
pub fn new() -> Self {
14+
NodeId(NEXT_NODE_ID.fetch_add(1, Ordering::Relaxed))
15+
}
16+
}
17+
18+
impl Default for NodeId {
19+
fn default() -> Self {
20+
Self::new()
21+
}
22+
}

deep_causality_uncertain/src/types/sampler/sequential_sampler.rs

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
* Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
44
*/
55

6-
use crate::{ComputationNode, LogicalOperator, SampledValue, Sampler, UncertainError};
6+
use crate::ComputationNode;
7+
use crate::types::computation::node::NodeId;
8+
use crate::{LogicalOperator, SampledValue, Sampler, UncertainError};
79
use std::collections::HashMap;
810
use std::sync::Arc;
911

@@ -29,10 +31,9 @@ impl Sampler for SequentialSampler {
2931
/// - `Ok(SampledValue)` containing the sampled value if the sampling is successful.
3032
/// - `Err(UncertainError)` if an error occurs during sampling (e.g., type mismatch, distribution error).
3133
fn sample(&self, root_node: &Arc<ComputationNode>) -> Result<SampledValue, UncertainError> {
32-
let mut context: HashMap<*const ComputationNode, SampledValue> = HashMap::new();
34+
let mut context: HashMap<NodeId, SampledValue> = HashMap::new(); // Changed key type
3335
// Call the internal method.
34-
let mut rng = rand::rng();
35-
self.evaluate_node(root_node, &mut context, &mut rng)
36+
self.evaluate_node(root_node, &mut context, &mut rand::rng()) // Changed rng
3637
}
3738
}
3839

@@ -63,44 +64,55 @@ impl SequentialSampler {
6364
fn evaluate_node(
6465
&self,
6566
node: &ComputationNode,
66-
context: &mut HashMap<*const ComputationNode, SampledValue>,
67+
context: &mut HashMap<NodeId, SampledValue>, // Changed key type
6768
rng: &mut impl rand::Rng,
6869
) -> Result<SampledValue, UncertainError> {
69-
// Use raw pointer as key for memoization
70-
let ptr: *const ComputationNode = node as *const ComputationNode;
70+
// Extract node_id from the current node
71+
let current_node_id = match node {
72+
ComputationNode::LeafF64 { node_id, .. } => *node_id,
73+
ComputationNode::LeafBool { node_id, .. } => *node_id,
74+
ComputationNode::ArithmeticOp { node_id, .. } => *node_id,
75+
ComputationNode::ComparisonOp { node_id, .. } => *node_id,
76+
ComputationNode::LogicalOp { node_id, .. } => *node_id,
77+
ComputationNode::FunctionOp { node_id, .. } => *node_id,
78+
ComputationNode::FunctionOpBool { node_id, .. } => *node_id,
79+
ComputationNode::NegationOp { node_id, .. } => *node_id,
80+
ComputationNode::ConditionalOp { node_id, .. } => *node_id,
81+
};
7182

72-
if let Some(value) = context.get(&ptr) {
83+
if let Some(value) = context.get(&current_node_id) {
84+
// Use node_id
7385
return Ok(*value);
7486
}
7587

7688
let result = match node {
77-
ComputationNode::LeafF64(dist) => SampledValue::Float(dist.sample(rng)?),
78-
ComputationNode::LeafBool(dist) => SampledValue::Bool(dist.sample(rng)?),
89+
ComputationNode::LeafF64 { dist, .. } => SampledValue::Float(dist.sample(rng)?),
90+
ComputationNode::LeafBool { dist, .. } => SampledValue::Bool(dist.sample(rng)?),
7991

80-
ComputationNode::ArithmeticOp { op, lhs, rhs } => {
92+
ComputationNode::ArithmeticOp { op, lhs, rhs, .. } => { // Extract op, lhs, rhs
8193
let lhs_val = self.evaluate_node(lhs, context, rng)?;
8294
let rhs_val = self.evaluate_node(rhs, context, rng)?;
8395
match (lhs_val, rhs_val) {
8496
(SampledValue::Float(l), SampledValue::Float(r)) => {
8597
SampledValue::Float(op.apply(l, r))
8698
}
87-
_ => panic!("Type error: Arithmetic op requires float inputs"),
99+
_ => return Err(UncertainError::UnsupportedTypeError(
100+
"Arithmetic op requires float inputs".into(),
101+
)),
88102
}
89103
}
90104

91-
ComputationNode::ComparisonOp {
92-
op,
93-
threshold,
94-
operand,
95-
} => {
105+
ComputationNode::ComparisonOp { op, threshold, operand, .. } => { // Extract op, threshold, operand
96106
let operand_val = self.evaluate_node(operand, context, rng)?;
97107
match operand_val {
98108
SampledValue::Float(o) => SampledValue::Bool(op.apply(o, *threshold)),
99-
_ => panic!("Type error: Comparison op requires float input"),
109+
_ => return Err(UncertainError::UnsupportedTypeError(
110+
"Comparison op requires float input".into(),
111+
)),
100112
}
101113
}
102114

103-
ComputationNode::LogicalOp { op, operands } => {
115+
ComputationNode::LogicalOp { op, operands, .. } => { // Extract op, operands
104116
let mut vals = Vec::with_capacity(operands.len());
105117
for operand_node in operands {
106118
match self.evaluate_node(operand_node, context, rng)? {
@@ -142,38 +154,47 @@ impl SequentialSampler {
142154
SampledValue::Bool(result)
143155
}
144156

145-
ComputationNode::FunctionOp { func, operand } => {
157+
ComputationNode::FunctionOp { func, operand, .. } => { // Extract func, operand
146158
let operand_val = self.evaluate_node(operand, context, rng)?;
147159
match operand_val {
148160
SampledValue::Float(o) => SampledValue::Float(func(o)),
149-
_ => panic!("Type error: Function op requires float input"),
161+
_ => return Err(UncertainError::UnsupportedTypeError(
162+
"Function op requires float input".into(),
163+
)),
150164
}
151165
}
152166

153-
ComputationNode::NegationOp { operand } => {
167+
ComputationNode::NegationOp { operand, .. } => { // Extract operand
154168
let operand_val = self.evaluate_node(operand, context, rng)?;
155169
match operand_val {
156170
SampledValue::Float(o) => SampledValue::Float(-o),
157-
_ => panic!("Type error: Negation op requires float input"),
171+
_ => return Err(UncertainError::UnsupportedTypeError(
172+
"Negation op requires float input".into(),
173+
)),
158174
}
159175
}
160176

161-
ComputationNode::FunctionOpBool { func, operand } => {
177+
ComputationNode::FunctionOpBool { func, operand, .. } => { // Extract func, operand
162178
let operand_val = self.evaluate_node(operand, context, rng)?;
163179
match operand_val {
164180
SampledValue::Float(o) => SampledValue::Bool(func(o)),
165-
_ => panic!("Type error: FunctionOpBool requires float input"),
181+
_ => return Err(UncertainError::UnsupportedTypeError(
182+
"FunctionOpBool requires float input".into(),
183+
)),
166184
}
167185
}
168186

169187
ComputationNode::ConditionalOp {
170188
condition,
171189
if_true,
172190
if_false,
191+
.. // Extract condition, if_true, if_false
173192
} => {
174193
let condition_val = match self.evaluate_node(condition, context, rng)? {
175194
SampledValue::Bool(b) => b,
176-
_ => panic!("Type error: Conditional condition must be boolean"),
195+
_ => return Err(UncertainError::UnsupportedTypeError(
196+
"Conditional condition must be boolean".into(),
197+
)),
177198
};
178199

179200
if condition_val {
@@ -184,7 +205,7 @@ impl SequentialSampler {
184205
}
185206
};
186207

187-
context.insert(ptr, result);
208+
context.insert(current_node_id, result); // Use node_id
188209
Ok(result)
189210
}
190211
}

0 commit comments

Comments
 (0)