Skip to content

Commit 69a6bf5

Browse files
Update deep_causality_uncertain/src/types/sampler/sequential_sampler.rs
Co-authored-by: qodo-merge-pro[bot] <151058649+qodo-merge-pro[bot]@users.noreply.github.com>
1 parent d15d405 commit 69a6bf5

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

deep_causality_uncertain/src/types/sampler/sequential_sampler.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,32 @@ impl SequentialSampler {
100100
}
101101

102102
ComputationNode::LogicalOp { op, operands } => {
103-
let evaluated_operands: Vec<bool> = operands
104-
.iter()
105-
.map(|operand_node| {
106-
match self.evaluate_node(operand_node, context, rng).unwrap() {
107-
SampledValue::Bool(b) => b,
108-
_ => panic!("Type error: Logical op requires boolean inputs"),
109-
}
110-
})
111-
.collect();
112-
103+
let mut vals = Vec::with_capacity(operands.len());
104+
for operand_node in operands {
105+
match self.evaluate_node(operand_node, context, rng)? {
106+
SampledValue::Bool(b) => vals.push(b),
107+
_ => return Err(UncertainError::TypeError("Logical op requires boolean inputs".into())),
108+
}
109+
}
113110
let result = match op {
114-
LogicalOperator::And => evaluated_operands[0] && evaluated_operands[1],
115-
LogicalOperator::Or => evaluated_operands[0] || evaluated_operands[1],
116-
LogicalOperator::Not => !evaluated_operands[0],
117-
LogicalOperator::NOR => !(evaluated_operands[0] || evaluated_operands[1]),
118-
LogicalOperator::XOR => evaluated_operands[0] ^ evaluated_operands[1],
111+
LogicalOperator::Not => {
112+
if vals.len() != 1 {
113+
return Err(UncertainError::TypeError("NOT expects exactly 1 operand".into()));
114+
}
115+
!vals[0]
116+
}
117+
LogicalOperator::And | LogicalOperator::Or | LogicalOperator::NOR | LogicalOperator::XOR => {
118+
if vals.len() != 2 {
119+
return Err(UncertainError::TypeError("Binary logical op expects exactly 2 operands".into()));
120+
}
121+
match op {
122+
LogicalOperator::And => vals[0] && vals[1],
123+
LogicalOperator::Or => vals[0] || vals[1],
124+
LogicalOperator::NOR => !(vals[0] || vals[1]),
125+
LogicalOperator::XOR => vals[0] ^ vals[1],
126+
LogicalOperator::Not => unreachable!(),
127+
}
128+
}
119129
};
120130
SampledValue::Bool(result)
121131
}

0 commit comments

Comments
 (0)