-
Notifications
You must be signed in to change notification settings - Fork 279
Expand file tree
/
Copy pathcircuit.rs
More file actions
161 lines (146 loc) · 5.42 KB
/
circuit.rs
File metadata and controls
161 lines (146 loc) · 5.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
//! ACE circuit emission for the DAG IR.
//!
//! The emitted circuit is a flat list of inputs, constants, and arithmetic
//! ops that matches the ACE chiplet execution model.
use std::collections::HashMap;
use miden_crypto::field::Field;
use crate::{
AceError, InputLayout,
dag::{AceDag, NodeId, NodeKind},
};
/// Binary arithmetic operation that an ACE circuit node can perform.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum AceOp {
    /// Addition: `lhs + rhs`.
    Add,
    /// Subtraction: `lhs - rhs`.
    Sub,
    /// Multiplication: `lhs * rhs`.
    Mul,
}
/// Reference to a value slot in an emitted ACE circuit.
///
/// The derived ordering compares the variant first
/// (`Input` < `Constant` < `Operation`) and then the index.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub(crate) enum AceNode {
    /// Index into the circuit's input vector.
    Input(usize),
    /// Index into the circuit's constant pool.
    Constant(usize),
    /// Index into the circuit's operation list (the result of that operation).
    Operation(usize),
}
/// A single arithmetic step: `op` applied to the values of `lhs` and `rhs`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct AceOpNode {
    /// Operation to apply.
    pub op: AceOp,
    /// Left operand.
    pub lhs: AceNode,
    /// Right operand (the subtrahend when `op` is `AceOp::Sub`).
    pub rhs: AceNode,
}
/// Emitted ACE circuit: input layout, constant pool, operation list and root.
///
/// This is the off-VM representation used by tests and tools.
#[derive(Clone, Debug)]
pub struct AceCircuit<EF> {
    /// Layout that `AceNode::Input` indices refer to; `eval` checks the
    /// input vector length against its `total_inputs`.
    pub(crate) layout: InputLayout,
    /// Constant pool indexed by `AceNode::Constant` (deduplicated by `emit_circuit`).
    pub(crate) constants: Vec<EF>,
    /// Operations in evaluation order, indexed by `AceNode::Operation`.
    pub(crate) operations: Vec<AceOpNode>,
    /// Node whose value `AceCircuit::eval` returns.
    pub(crate) root: AceNode,
}
impl<EF: Field> AceCircuit<EF> {
/// Return the input layout for this circuit.
pub fn layout(&self) -> &InputLayout {
&self.layout
}
/// Evaluate the circuit against the provided input vector.
pub fn eval(&self, inputs: &[EF]) -> Result<EF, AceError> {
if inputs.len() != self.layout.total_inputs {
return Err(AceError::InvalidInputLength {
expected: self.layout.total_inputs,
got: inputs.len(),
});
}
let mut op_values = vec![EF::ZERO; self.operations.len()];
for (idx, op) in self.operations.iter().enumerate() {
let lhs = self.node_value(op.lhs, inputs, &op_values);
let rhs = self.node_value(op.rhs, inputs, &op_values);
op_values[idx] = match op.op {
AceOp::Add => lhs + rhs,
AceOp::Sub => lhs - rhs,
AceOp::Mul => lhs * rhs,
};
}
Ok(self.node_value(self.root, inputs, &op_values))
}
/// Total number of nodes (inputs + constants + ops).
pub fn num_nodes(&self) -> usize {
self.layout.total_inputs + self.constants.len() + self.operations.len()
}
fn node_value(&self, node: AceNode, inputs: &[EF], op_values: &[EF]) -> EF {
match node {
AceNode::Input(index) => inputs[index],
AceNode::Constant(index) => self.constants[index],
AceNode::Operation(index) => op_values[index],
}
}
}
/// Emit an ACE circuit from the DAG and input layout.
pub fn emit_circuit<EF>(dag: &AceDag<EF>, layout: InputLayout) -> Result<AceCircuit<EF>, AceError>
where
EF: Field,
{
let mut constants = Vec::new();
let mut constant_map = HashMap::<EF, usize>::new();
let mut operations = Vec::new();
let mut node_map: Vec<Option<AceNode>> = vec![None; dag.nodes().len()];
for (idx, node) in dag.nodes().iter().enumerate() {
let ace_node = match node {
NodeKind::Input(key) => {
let input_idx = layout.index(*key).ok_or_else(|| AceError::InvalidInputLayout {
message: format!("missing input key in layout: {key:?}"),
})?;
AceNode::Input(input_idx)
},
NodeKind::Constant(value) => {
let const_idx = *constant_map.entry(*value).or_insert_with(|| {
constants.push(*value);
constants.len() - 1
});
AceNode::Constant(const_idx)
},
NodeKind::Add(a, b) => {
let lhs = lookup_node(&node_map, *a);
let rhs = lookup_node(&node_map, *b);
let op_idx = operations.len();
operations.push(AceOpNode { op: AceOp::Add, lhs, rhs });
AceNode::Operation(op_idx)
},
NodeKind::Sub(a, b) => {
let lhs = lookup_node(&node_map, *a);
let rhs = lookup_node(&node_map, *b);
let op_idx = operations.len();
operations.push(AceOpNode { op: AceOp::Sub, lhs, rhs });
AceNode::Operation(op_idx)
},
NodeKind::Mul(a, b) => {
let lhs = lookup_node(&node_map, *a);
let rhs = lookup_node(&node_map, *b);
let op_idx = operations.len();
operations.push(AceOpNode { op: AceOp::Mul, lhs, rhs });
AceNode::Operation(op_idx)
},
NodeKind::Neg(a) => {
let rhs = lookup_node(&node_map, *a);
let zero = *constant_map.entry(EF::ZERO).or_insert_with(|| {
constants.push(EF::ZERO);
constants.len() - 1
});
let op_idx = operations.len();
operations.push(AceOpNode {
op: AceOp::Sub,
lhs: AceNode::Constant(zero),
rhs,
});
AceNode::Operation(op_idx)
},
};
node_map[idx] = Some(ace_node);
}
let root = lookup_node(&node_map, dag.root());
Ok(AceCircuit { layout, constants, operations, root })
}
/// Resolve a DAG node id to the ACE node it was lowered to.
///
/// # Panics
///
/// Panics if `id` is out of range for `map`, or if that node has not been
/// lowered yet (the DAG's node list is not topologically ordered).
fn lookup_node(map: &[Option<AceNode>], id: NodeId) -> AceNode {
    match map[id.index()] {
        Some(lowered) => lowered,
        None => panic!("ACE DAG nodes must be topologically ordered"),
    }
}