diff --git a/openvm/src/powdr_extension/executor/mod.rs b/openvm/src/powdr_extension/executor/mod.rs index ec2ce674b5..12f4f1c1c2 100644 --- a/openvm/src/powdr_extension/executor/mod.rs +++ b/openvm/src/powdr_extension/executor/mod.rs @@ -69,7 +69,7 @@ impl OriginalArenas { apc_call_count_estimate: impl Fn() -> usize, original_airs: &OriginalAirs, apc: &Arc>>, - ) { + ) -> &mut InitializedOriginalArenas { match self { OriginalArenas::Uninitialized => { *self = OriginalArenas::Initialized(InitializedOriginalArenas::new( @@ -77,43 +77,12 @@ impl OriginalArenas { original_airs, apc, )); + match self { + OriginalArenas::Initialized(i) => i, + _ => unreachable!(), + } } - OriginalArenas::Initialized(_) => {} - } - } - - /// Returns a mutable reference to the arenas. - /// - Panics if the arenas are not initialized. - pub fn arenas_mut(&mut self) -> &mut HashMap { - match self { - OriginalArenas::Uninitialized => panic!("original arenas are uninitialized"), - OriginalArenas::Initialized(initialized) => &mut initialized.arenas, - } - } - - /// Returns a reference to the arenas. - /// - Panics if the arenas are not initialized. - pub fn arenas(&self) -> &HashMap { - match self { - OriginalArenas::Uninitialized => panic!("original arenas are uninitialized"), - OriginalArenas::Initialized(initialized) => &initialized.arenas, - } - } - - /// Returns a mutable reference to the number of calls. - /// - Panics if the arenas are not initialized. - pub fn number_of_calls_mut(&mut self) -> &mut usize { - match self { - OriginalArenas::Uninitialized => panic!("original arenas are uninitialized"), - OriginalArenas::Initialized(initialized) => &mut initialized.number_of_calls, - } - } - - /// Returns the number of calls. If not initialized, `Preflight::execute` is never called, and thus return 0. - pub fn number_of_calls(&self) -> usize { - match self { - OriginalArenas::Uninitialized => 0, - OriginalArenas::Initialized(initialized) => initialized.number_of_calls, + OriginalArenas::Initialized(i) => i, } } } @@ -425,9 +394,10 @@ impl PreflightExecutor> for PowdrExecutor // Recover an estimate of how many times the APC is called in this segment based on the current ctx height and width let apc_call_count = || ctx.trace_buffer.len() / ctx.width; - original_arenas.ensure_initialized(apc_call_count, &self.air_by_opcode_id, &self.apc); + let original_arenas = + original_arenas.ensure_initialized(apc_call_count, &self.air_by_opcode_id, &self.apc); - let arenas = original_arenas.arenas_mut(); + let arenas = &mut original_arenas.arenas; // execute the original instructions one by one for instruction in self.apc.instructions() { @@ -458,7 +428,7 @@ impl PreflightExecutor> for PowdrExecutor } // Update the real number of calls to the APC - *original_arenas.number_of_calls_mut() += 1; + original_arenas.number_of_calls += 1; Ok(()) } @@ -500,9 +470,10 @@ impl PreflightExecutor for PowdrExecutor { buf.len() / bytes_per_row }; - original_arenas.ensure_initialized(apc_call_count, &self.air_by_opcode_id, &self.apc); + let original_arenas = + original_arenas.ensure_initialized(apc_call_count, &self.air_by_opcode_id, &self.apc); - let arenas = original_arenas.arenas_mut(); + let arenas = &mut original_arenas.arenas; // execute the original instructions one by one for instruction in self.apc.instructions() { @@ -533,7 +504,7 @@ impl PreflightExecutor for PowdrExecutor { } // Update the real number of calls to the APC - *original_arenas.number_of_calls_mut() += 1; + original_arenas.number_of_calls += 1; Ok(()) } diff --git a/openvm/src/powdr_extension/trace_generator/cpu/mod.rs b/openvm/src/powdr_extension/trace_generator/cpu/mod.rs index 5583c70a8a..e88869ef9c 100644 --- a/openvm/src/powdr_extension/trace_generator/cpu/mod.rs +++ b/openvm/src/powdr_extension/trace_generator/cpu/mod.rs @@ -94,16 +94,19 @@ impl PowdrTraceGeneratorCpu { pub fn generate_witness( &self, - mut original_arenas: OriginalArenas>, + original_arenas: OriginalArenas>, ) -> DenseMatrix { use powdr_autoprecompiles::trace_handler::{generate_trace, TraceData}; - let num_apc_calls = original_arenas.number_of_calls(); - if num_apc_calls == 0 { - // If the APC isn't called, early return with an empty trace. - let width = self.apc.machine().main_columns().count(); - return RowMajorMatrix::new(vec![], width); - } + let width = self.apc.machine().main_columns().count(); + + let original_arenas = match original_arenas { + OriginalArenas::Initialized(arenas) => arenas, + OriginalArenas::Uninitialized => { + // if the arenas are uninitialized, the apc was not called, so we return an empty trace + return RowMajorMatrix::new(vec![], width); + } + }; let chip_inventory = { let airs: AirInventory = @@ -119,7 +122,7 @@ impl PowdrTraceGeneratorCpu { .inventory }; - let arenas = original_arenas.arenas_mut(); + let (mut arenas, num_apc_calls) = (original_arenas.arenas, original_arenas.number_of_calls); let dummy_trace_by_air_name: HashMap> = chip_inventory .chips() diff --git a/openvm/src/powdr_extension/trace_generator/cuda/mod.rs b/openvm/src/powdr_extension/trace_generator/cuda/mod.rs index 5fd7751d50..7f32040fbc 100644 --- a/openvm/src/powdr_extension/trace_generator/cuda/mod.rs +++ b/openvm/src/powdr_extension/trace_generator/cuda/mod.rs @@ -196,14 +196,15 @@ impl PowdrTraceGeneratorGpu { fn try_generate_witness( &self, - mut original_arenas: OriginalArenas, + original_arenas: OriginalArenas, ) -> Option> { - let num_apc_calls = original_arenas.number_of_calls(); - - if num_apc_calls == 0 { - // If the APC isn't called, early return with an empty trace. - return None; - } + let original_arenas = match original_arenas { + OriginalArenas::Initialized(arenas) => arenas, + OriginalArenas::Uninitialized => { + // if the arenas are uninitialized, the apc was not called, so we return early + return None; + } + }; let chip_inventory = { let airs: AirInventory = @@ -219,7 +220,7 @@ impl PowdrTraceGeneratorGpu { .inventory }; - let arenas = original_arenas.arenas_mut(); + let (mut arenas, num_apc_calls) = (original_arenas.arenas, original_arenas.number_of_calls); let dummy_trace_by_air_name: HashMap> = chip_inventory .chips()