33 * Copyright (c) "2025" . The DeepCausality Authors and Contributors. All Rights Reserved.
44 */
55
6- use crate :: { ComputationNode , LogicalOperator , SampledValue , Sampler , UncertainError } ;
6+ use crate :: ComputationNode ;
7+ use crate :: types:: computation:: node:: NodeId ;
8+ use crate :: { LogicalOperator , SampledValue , Sampler , UncertainError } ;
79use std:: collections:: HashMap ;
810use std:: sync:: Arc ;
911
@@ -29,10 +31,9 @@ impl Sampler for SequentialSampler {
2931 /// - `Ok(SampledValue)` containing the sampled value if the sampling is successful.
3032 /// - `Err(UncertainError)` if an error occurs during sampling (e.g., type mismatch, distribution error).
3133 fn sample ( & self , root_node : & Arc < ComputationNode > ) -> Result < SampledValue , UncertainError > {
32- let mut context: HashMap < * const ComputationNode , SampledValue > = HashMap :: new ( ) ;
34+ let mut context: HashMap < NodeId , SampledValue > = HashMap :: new ( ) ; // Changed key type
3335 // Call the internal method.
34- let mut rng = rand:: rng ( ) ;
35- self . evaluate_node ( root_node, & mut context, & mut rng)
36+ self . evaluate_node ( root_node, & mut context, & mut rand:: rng ( ) ) // Changed rng
3637 }
3738}
3839
@@ -63,44 +64,55 @@ impl SequentialSampler {
6364 fn evaluate_node (
6465 & self ,
6566 node : & ComputationNode ,
66- context : & mut HashMap < * const ComputationNode , SampledValue > ,
67+ context : & mut HashMap < NodeId , SampledValue > , // Changed key type
6768 rng : & mut impl rand:: Rng ,
6869 ) -> Result < SampledValue , UncertainError > {
69- // Use raw pointer as key for memoization
70- let ptr: * const ComputationNode = node as * const ComputationNode ;
70+ // Extract node_id from the current node
71+ let current_node_id = match node {
72+ ComputationNode :: LeafF64 { node_id, .. } => * node_id,
73+ ComputationNode :: LeafBool { node_id, .. } => * node_id,
74+ ComputationNode :: ArithmeticOp { node_id, .. } => * node_id,
75+ ComputationNode :: ComparisonOp { node_id, .. } => * node_id,
76+ ComputationNode :: LogicalOp { node_id, .. } => * node_id,
77+ ComputationNode :: FunctionOp { node_id, .. } => * node_id,
78+ ComputationNode :: FunctionOpBool { node_id, .. } => * node_id,
79+ ComputationNode :: NegationOp { node_id, .. } => * node_id,
80+ ComputationNode :: ConditionalOp { node_id, .. } => * node_id,
81+ } ;
7182
72- if let Some ( value) = context. get ( & ptr) {
83+ if let Some ( value) = context. get ( & current_node_id) {
84+ // Use node_id
7385 return Ok ( * value) ;
7486 }
7587
7688 let result = match node {
77- ComputationNode :: LeafF64 ( dist) => SampledValue :: Float ( dist. sample ( rng) ?) ,
78- ComputationNode :: LeafBool ( dist) => SampledValue :: Bool ( dist. sample ( rng) ?) ,
89+ ComputationNode :: LeafF64 { dist, .. } => SampledValue :: Float ( dist. sample ( rng) ?) ,
90+ ComputationNode :: LeafBool { dist, .. } => SampledValue :: Bool ( dist. sample ( rng) ?) ,
7991
80- ComputationNode :: ArithmeticOp { op, lhs, rhs } => {
92+ ComputationNode :: ArithmeticOp { op, lhs, rhs, .. } => { // Extract op, lhs, rhs
8193 let lhs_val = self . evaluate_node ( lhs, context, rng) ?;
8294 let rhs_val = self . evaluate_node ( rhs, context, rng) ?;
8395 match ( lhs_val, rhs_val) {
8496 ( SampledValue :: Float ( l) , SampledValue :: Float ( r) ) => {
8597 SampledValue :: Float ( op. apply ( l, r) )
8698 }
87- _ => panic ! ( "Type error: Arithmetic op requires float inputs" ) ,
99+ _ => return Err ( UncertainError :: UnsupportedTypeError (
100+ "Arithmetic op requires float inputs" . into ( ) ,
101+ ) ) ,
88102 }
89103 }
90104
91- ComputationNode :: ComparisonOp {
92- op,
93- threshold,
94- operand,
95- } => {
105+ ComputationNode :: ComparisonOp { op, threshold, operand, .. } => { // Extract op, threshold, operand
96106 let operand_val = self . evaluate_node ( operand, context, rng) ?;
97107 match operand_val {
98108 SampledValue :: Float ( o) => SampledValue :: Bool ( op. apply ( o, * threshold) ) ,
99- _ => panic ! ( "Type error: Comparison op requires float input" ) ,
109+ _ => return Err ( UncertainError :: UnsupportedTypeError (
110+ "Comparison op requires float input" . into ( ) ,
111+ ) ) ,
100112 }
101113 }
102114
103- ComputationNode :: LogicalOp { op, operands } => {
115+ ComputationNode :: LogicalOp { op, operands, .. } => { // Extract op, operands
104116 let mut vals = Vec :: with_capacity ( operands. len ( ) ) ;
105117 for operand_node in operands {
106118 match self . evaluate_node ( operand_node, context, rng) ? {
@@ -142,38 +154,47 @@ impl SequentialSampler {
142154 SampledValue :: Bool ( result)
143155 }
144156
145- ComputationNode :: FunctionOp { func, operand } => {
157+ ComputationNode :: FunctionOp { func, operand, .. } => { // Extract func, operand
146158 let operand_val = self . evaluate_node ( operand, context, rng) ?;
147159 match operand_val {
148160 SampledValue :: Float ( o) => SampledValue :: Float ( func ( o) ) ,
149- _ => panic ! ( "Type error: Function op requires float input" ) ,
161+ _ => return Err ( UncertainError :: UnsupportedTypeError (
162+ "Function op requires float input" . into ( ) ,
163+ ) ) ,
150164 }
151165 }
152166
153- ComputationNode :: NegationOp { operand } => {
167+ ComputationNode :: NegationOp { operand, .. } => { // Extract operand
154168 let operand_val = self . evaluate_node ( operand, context, rng) ?;
155169 match operand_val {
156170 SampledValue :: Float ( o) => SampledValue :: Float ( -o) ,
157- _ => panic ! ( "Type error: Negation op requires float input" ) ,
171+ _ => return Err ( UncertainError :: UnsupportedTypeError (
172+ "Negation op requires float input" . into ( ) ,
173+ ) ) ,
158174 }
159175 }
160176
161- ComputationNode :: FunctionOpBool { func, operand } => {
177+ ComputationNode :: FunctionOpBool { func, operand, .. } => { // Extract func, operand
162178 let operand_val = self . evaluate_node ( operand, context, rng) ?;
163179 match operand_val {
164180 SampledValue :: Float ( o) => SampledValue :: Bool ( func ( o) ) ,
165- _ => panic ! ( "Type error: FunctionOpBool requires float input" ) ,
181+ _ => return Err ( UncertainError :: UnsupportedTypeError (
182+ "FunctionOpBool requires float input" . into ( ) ,
183+ ) ) ,
166184 }
167185 }
168186
169187 ComputationNode :: ConditionalOp {
170188 condition,
171189 if_true,
172190 if_false,
191+ .. // Extract condition, if_true, if_false
173192 } => {
174193 let condition_val = match self . evaluate_node ( condition, context, rng) ? {
175194 SampledValue :: Bool ( b) => b,
176- _ => panic ! ( "Type error: Conditional condition must be boolean" ) ,
195+ _ => return Err ( UncertainError :: UnsupportedTypeError (
196+ "Conditional condition must be boolean" . into ( ) ,
197+ ) ) ,
177198 } ;
178199
179200 if condition_val {
@@ -184,7 +205,7 @@ impl SequentialSampler {
184205 }
185206 } ;
186207
187- context. insert ( ptr , result) ;
208+ context. insert ( current_node_id , result) ; // Use node_id
188209 Ok ( result)
189210 }
190211}
0 commit comments