@@ -6,6 +6,13 @@ use std::hash::Hash;
66use itertools:: Itertools ;
77use serde:: { Deserialize , Serialize } ;
88
9+ use crate :: {
10+ adapter:: Adapter ,
11+ blocks:: BasicBlock ,
12+ expression:: { AlgebraicExpression , AlgebraicReference } ,
13+ SymbolicConstraint ,
14+ } ;
15+
916/// "Constraints" that were inferred from execution statistics. They hold empirically
1017/// (most of the time), but are not guaranteed to hold in all cases.
1118#[ derive( Serialize , Deserialize , Clone , Default , Debug ) ]
@@ -16,6 +23,9 @@ pub struct EmpiricalConstraints {
1623 /// For each basic block (identified by its starting PC), the equivalence classes of columns.
1724 /// Each equivalence class is a list of (instruction index in block, column index).
1825 pub equivalence_classes_by_block : BTreeMap < u64 , BTreeSet < BTreeSet < ( usize , usize ) > > > ,
26+ /// Count of how many times each program counter was executed in the sampled executions.
27+ /// This can be used to set a threshold for applying constraints only to frequently executed PCs.
28+ pub pc_counts : BTreeMap < u32 , u64 > ,
1929}
2030
2131/// Debug information mapping AIR ids to program counters and column names.
@@ -61,6 +71,32 @@ impl EmpiricalConstraints {
6171 } )
6272 . or_insert ( classes) ;
6373 }
74+
75+ // Combine pc counts
76+ for ( pc, count) in other. pc_counts {
77+ * self . pc_counts . entry ( pc) . or_insert ( 0 ) += count;
78+ }
79+ }
80+
81+ pub fn apply_pc_threshold ( & self , threshold : u64 ) -> Self {
82+ EmpiricalConstraints {
83+ column_ranges_by_pc : self
84+ . column_ranges_by_pc
85+ . iter ( )
86+ . filter ( |( pc, _) | self . pc_counts . get ( pc) . cloned ( ) . unwrap_or ( 0 ) >= threshold)
87+ . map ( |( pc, ranges) | ( * pc, ranges. clone ( ) ) )
88+ . collect ( ) ,
89+ equivalence_classes_by_block : self
90+ . equivalence_classes_by_block
91+ . iter ( )
92+ . filter ( |( & block_pc, _) | {
93+ // For equivalence classes, we check the pc_counts of the first instruction in the block
94+ self . pc_counts . get ( & ( block_pc as u32 ) ) . cloned ( ) . unwrap_or ( 0 ) >= threshold
95+ } )
96+ . map ( |( block_pc, classes) | ( * block_pc, classes. clone ( ) ) )
97+ . collect ( ) ,
98+ pc_counts : self . pc_counts . clone ( ) ,
99+ }
64100 }
65101}
66102
@@ -88,6 +124,104 @@ fn merge_maps<K: Ord, V: Eq + Debug>(map1: &mut BTreeMap<K, V>, map2: BTreeMap<K
88124 }
89125}
90126
127+ /// For any program line that was not executed at least this many times in the traces,
128+ /// discard any empirical constraints associated with it.
129+ const EXECUTION_COUNT_THRESHOLD : u64 = 100 ;
130+
131+ pub struct ConstraintGenerator < ' a , A : Adapter > {
132+ empirical_constraints : EmpiricalConstraints ,
133+ algebraic_references : BTreeMap < ( usize , usize ) , AlgebraicReference > ,
134+ block : & ' a BasicBlock < A :: Instruction > ,
135+ }
136+
137+ impl < ' a , A : Adapter > ConstraintGenerator < ' a , A > {
138+ pub fn new (
139+ empirical_constraints : & EmpiricalConstraints ,
140+ subs : & [ Vec < u64 > ] ,
141+ columns : impl Iterator < Item = AlgebraicReference > ,
142+ block : & ' a BasicBlock < A :: Instruction > ,
143+ ) -> Self {
144+ let reverse_subs = subs
145+ . iter ( )
146+ . enumerate ( )
147+ . flat_map ( |( instr_index, subs) | {
148+ subs. iter ( )
149+ . enumerate ( )
150+ . map ( move |( col_index, & poly_id) | ( poly_id, ( instr_index, col_index) ) )
151+ } )
152+ . collect :: < BTreeMap < _ , _ > > ( ) ;
153+ let algebraic_references = columns
154+ . map ( |r| ( * reverse_subs. get ( & r. id ) . unwrap ( ) , r. clone ( ) ) )
155+ . collect :: < BTreeMap < _ , _ > > ( ) ;
156+
157+ Self {
158+ empirical_constraints : empirical_constraints
159+ . apply_pc_threshold ( EXECUTION_COUNT_THRESHOLD ) ,
160+ algebraic_references,
161+ block,
162+ }
163+ }
164+
165+ fn get_algebraic_reference ( & self , instr_index : usize , col_index : usize ) -> AlgebraicReference {
166+ self . algebraic_references
167+ . get ( & ( instr_index, col_index) )
168+ . cloned ( )
169+ . unwrap_or_else ( || {
170+ panic ! (
171+ "Missing reference for (i: {}, col_index: {}, block_id: {})" ,
172+ instr_index, col_index, self . block. start_pc
173+ )
174+ } )
175+ }
176+
177+ pub fn range_constraints ( & self ) -> Vec < SymbolicConstraint < <A as Adapter >:: PowdrField > > {
178+ let mut constraints = Vec :: new ( ) ;
179+
180+ for i in 0 ..self . block . statements . len ( ) {
181+ let pc = ( self . block . start_pc + ( i * 4 ) as u64 ) as u32 ;
182+ let Some ( range_constraints) = self . empirical_constraints . column_ranges_by_pc . get ( & pc)
183+ else {
184+ continue ;
185+ } ;
186+ for ( col_index, range) in range_constraints. iter ( ) . enumerate ( ) {
187+ if range. 0 == range. 1 {
188+ let value = A :: PowdrField :: from ( range. 0 as u64 ) ;
189+ let reference = self . get_algebraic_reference ( i, col_index) ;
190+ let constraint = AlgebraicExpression :: Reference ( reference)
191+ - AlgebraicExpression :: Number ( value) ;
192+
193+ constraints. push ( SymbolicConstraint { expr : constraint } ) ;
194+ }
195+ }
196+ }
197+
198+ constraints
199+ }
200+
201+ pub fn equivalence_constraints ( & self ) -> Vec < SymbolicConstraint < <A as Adapter >:: PowdrField > > {
202+ let mut constraints = Vec :: new ( ) ;
203+
204+ if let Some ( equivalence_classes) = self
205+ . empirical_constraints
206+ . equivalence_classes_by_block
207+ . get ( & self . block . start_pc )
208+ {
209+ for equivalence_class in equivalence_classes {
210+ let first = equivalence_class. first ( ) . unwrap ( ) ;
211+ let first_ref = self . get_algebraic_reference ( first. 0 , first. 1 ) ;
212+ for other in equivalence_class. iter ( ) . skip ( 1 ) {
213+ let other_ref = self . get_algebraic_reference ( other. 0 , other. 1 ) ;
214+ let constraint = AlgebraicExpression :: Reference ( first_ref. clone ( ) )
215+ - AlgebraicExpression :: Reference ( other_ref. clone ( ) ) ;
216+ constraints. push ( SymbolicConstraint { expr : constraint } ) ;
217+ }
218+ }
219+ }
220+
221+ constraints
222+ }
223+ }
224+
91225/// Intersects multiple partitions of the same universe into a single partition.
92226/// In other words, two elements are in the same equivalence class in the resulting partition
93227/// if and only if they are in the same equivalence class in all input partitions.
0 commit comments