Skip to content

Commit 21ebf2b

Browse files
authored
[feat] support custom segmentation (#1254)
* support custom segmentation * address comments
1 parent 40dc754 commit 21ebf2b

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

crates/vm/src/arch/config.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use derive_new::new;
24
use openvm_circuit::system::memory::MemoryTraceHeights;
35
use openvm_instructions::program::DEFAULT_MAX_NUM_PUBLIC_VALUES;
@@ -13,12 +15,12 @@ pub use super::testing::{
1315
POSEIDON2_DIRECT_BUS, RANGE_TUPLE_CHECKER_BUS, READ_INSTRUCTION_BUS,
1416
};
1517
use super::{
18+
segment::{DefaultSegmentationStrategy, SegmentationStrategy},
1619
AnyEnum, InstructionExecutor, SystemComplex, SystemExecutor, SystemPeriphery, VmChipComplex,
1720
VmInventoryError, PUBLIC_VALUES_AIR_ID,
1821
};
1922
use crate::system::memory::BOUNDARY_AIR_OFFSET;
2023

21-
const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100;
2224
// sbox is decomposed to have this max degree for Poseidon2. We set to 3 so quotient_degree = 2
2325
// allows log_blowup = 1
2426
const DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE: usize = 3;
@@ -86,11 +88,18 @@ pub struct SystemConfig {
8688
/// cannot read public values directly, but they can decommit the public values from the memory
8789
/// merkle root.
8890
pub num_public_values: usize,
89-
/// When continuations are enabled, a heuristic used to determine when to segment execution.
90-
pub max_segment_len: usize,
9191
/// Whether to collect detailed profiling metrics.
9292
/// **Warning**: this slows down the runtime.
9393
pub profiling: bool,
94+
/// Segmentation strategy
95+
/// This field is skipped in serde as it's only used in execution and
96+
/// not needed after any serialize/deserialize.
97+
#[serde(skip, default = "get_default_segmentation_strategy")]
98+
pub segmentation_strategy: Arc<dyn SegmentationStrategy>,
99+
}
100+
101+
pub fn get_default_segmentation_strategy() -> Arc<dyn SegmentationStrategy> {
102+
Arc::new(DefaultSegmentationStrategy::default())
94103
}
95104

96105
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
@@ -105,12 +114,13 @@ impl SystemConfig {
105114
memory_config: MemoryConfig,
106115
num_public_values: usize,
107116
) -> Self {
117+
let segmentation_strategy = get_default_segmentation_strategy();
108118
Self {
109119
max_constraint_degree,
110120
continuation_enabled: false,
111121
memory_config,
112122
num_public_values,
113-
max_segment_len: DEFAULT_MAX_SEGMENT_LEN,
123+
segmentation_strategy,
114124
profiling: false,
115125
}
116126
}
@@ -136,10 +146,16 @@ impl SystemConfig {
136146
}
137147

138148
pub fn with_max_segment_len(mut self, max_segment_len: usize) -> Self {
139-
self.max_segment_len = max_segment_len;
149+
self.segmentation_strategy = Arc::new(
150+
DefaultSegmentationStrategy::new_with_max_segment_len(max_segment_len),
151+
);
140152
self
141153
}
142154

155+
/// Installs a custom segmentation strategy on this config, replacing the current one.
///
/// The strategy is moved into a fresh `Arc` and stored as a trait object.
pub fn set_segmentation_strategy<S>(&mut self, strategy: S)
where
    S: SegmentationStrategy + 'static,
{
    let shared: Arc<dyn SegmentationStrategy> = Arc::new(strategy);
    self.segmentation_strategy = shared;
}
158+
143159
pub fn with_profiling(mut self) -> Self {
144160
self.profiling = true;
145161
self

crates/vm/src/arch/segment.rs

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,84 @@ use crate::{
2727
/// Check segment every 100 instructions.
2828
const SEGMENT_CHECK_INTERVAL: usize = 100;
2929

30+
const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100;
31+
// a heuristic number for the maximum number of cells per chip in a segment
32+
// a few reasons for this number:
33+
// 1. `VmAirWrapper<Rv32BaseAluAdapterAir, BaseAluCoreAir<4, 8>>` is
34+
// the chip with the most cells in a segment from the reth-benchmark.
35+
// 2. `VmAirWrapper<Rv32BaseAluAdapterAir, BaseAluCoreAir<4, 8>>`:
36+
// its trace width is 36 and its after challenge trace width is 80.
37+
const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = DEFAULT_MAX_SEGMENT_LEN * 120;
38+
39+
/// Policy object that decides whether the current execution segment should end.
///
/// Implementations must be `Debug`, shareable across threads (`Send + Sync`) and
/// unwind-safe, so a strategy can be stored behind a shared trait object.
pub trait SegmentationStrategy:
    std::fmt::Debug + std::panic::UnwindSafe + std::panic::RefUnwindSafe + Send + Sync
{
    /// Returns `true` when a new segment should be started.
    ///
    /// * `air_names` - names of the chips/AIRs.
    /// * `trace_heights` - current trace height of each chip.
    /// * `trace_cells` - current number of trace cells of each chip.
    ///
    /// NOTE(review): the three slices are presumably parallel (same index refers to
    /// the same chip) - confirm against the caller.
    fn should_segment(
        &self,
        air_names: &[String],
        trace_heights: &[usize],
        trace_cells: &[usize],
    ) -> bool;
}
49+
50+
/// Default segmentation strategy: segment if any chip's height or cells exceed the limits.
51+
#[derive(Debug)]
52+
pub struct DefaultSegmentationStrategy {
53+
max_segment_len: usize,
54+
max_cells_per_chip_in_segment: usize,
55+
}
56+
57+
impl Default for DefaultSegmentationStrategy {
58+
fn default() -> Self {
59+
Self {
60+
max_segment_len: DEFAULT_MAX_SEGMENT_LEN,
61+
max_cells_per_chip_in_segment: DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT,
62+
}
63+
}
64+
}
65+
66+
impl DefaultSegmentationStrategy {
67+
pub fn new_with_max_segment_len(max_segment_len: usize) -> Self {
68+
Self {
69+
max_segment_len,
70+
max_cells_per_chip_in_segment: max_segment_len * 120,
71+
}
72+
}
73+
}
74+
75+
impl SegmentationStrategy for DefaultSegmentationStrategy {
76+
fn should_segment(
77+
&self,
78+
air_names: &[String],
79+
trace_heights: &[usize],
80+
trace_cells: &[usize],
81+
) -> bool {
82+
for (i, &height) in trace_heights.iter().enumerate() {
83+
if height > self.max_segment_len {
84+
tracing::info!(
85+
"Should segment because chip {} (name: {}) has height {}",
86+
i,
87+
air_names[i],
88+
height
89+
);
90+
return true;
91+
}
92+
}
93+
for (i, &num_cells) in trace_cells.iter().enumerate() {
94+
if num_cells > self.max_cells_per_chip_in_segment {
95+
tracing::info!(
96+
"Should segment because chip {} (name: {}) has {} cells",
97+
i,
98+
air_names[i],
99+
num_cells
100+
);
101+
return true;
102+
}
103+
}
104+
false
105+
}
106+
}
107+
30108
pub struct ExecutionSegment<F, VC>
31109
where
32110
F: PrimeField32,
@@ -276,19 +354,15 @@ impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
276354
return false;
277355
}
278356
self.since_last_segment_check = 0;
279-
let heights = self.chip_complex.dynamic_trace_heights();
280-
for (i, height) in heights.enumerate() {
281-
if height > self.system_config().max_segment_len {
282-
tracing::info!(
283-
"Should segment because chip {} has height {}",
284-
self.air_names[i],
285-
height
286-
);
287-
return true;
288-
}
289-
}
290-
291-
false
357+
let segmentation_strategy = self.system_config().segmentation_strategy.clone();
358+
segmentation_strategy.should_segment(
359+
&self.air_names,
360+
&self
361+
.chip_complex
362+
.dynamic_trace_heights()
363+
.collect::<Vec<_>>(),
364+
&self.chip_complex.current_trace_cells(),
365+
)
292366
}
293367

294368
pub fn current_trace_cells(&self) -> Vec<usize> {

0 commit comments

Comments
 (0)