Skip to content

Commit a2f5876

Browse files
feat: update segmentation weights for hardware memory (#2341)
segment_ctx.rs:
- DEFAULT_MAX_CELLS → DEFAULT_MAX_MEMORY = 15 GiB
- max_cells → max_memory in SegmentationLimits
- set_max_cells → set_max_memory

ctx.rs:
- with_max_cells → with_max_memory

metered_cost.rs:
- Updated import to use DEFAULT_MAX_MEMORY

cli/src/commands/prove.rs:
- Updated import to DEFAULT_MAX_MEMORY
- segment_max_cells → segment_max_memory
- with_max_cells → with_max_memory

benchmarks/prove/src/bin/async_regex.rs:
- segment_max_cells → segment_max_memory
- set_max_cells → set_max_memory

benchmarks/prove/src/util.rs:
- segment_max_cells → segment_max_memory
- set_max_cells → set_max_memory

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 2530545 commit a2f5876

File tree

9 files changed

+123
-46
lines changed

9 files changed

+123
-46
lines changed

.github/workflows/base-tests.cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -74,4 +74,4 @@ jobs:
7474
# - name: Run async concurrency test
7575
# working-directory: benchmarks/prove
7676
# run: |
77-
# CUDA_OPT_LEVEL=3 MAX_CONCURRENCY=5 cargo run --bin async_regex --features cuda,async -- --max-segment-length $((1<<20)) --segment-max-cells 150000000
77+
# CUDA_OPT_LEVEL=3 MAX_CONCURRENCY=5 cargo run --bin async_regex --features cuda,async -- --max-segment-length $((1<<20)) --segment-max-memory 16106127360

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -233,6 +233,7 @@ libloading = "0.8"
233233
tracing-subscriber = { version = "0.3.20", features = ["std", "env-filter"] }
234234
tokio = "1" # >=1.0.0 to allow downstream flexibility
235235
abi_stable = "0.11.3"
236+
bytesize = "2.0"
236237

237238
# default-features = false for no_std for use in guest programs
238239
itertools = { version = "0.14.0", default-features = false }

benchmarks/prove/src/bin/async_regex.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,13 @@ async fn main() -> eyre::Result<()> {
2323
.limits
2424
.set_max_trace_height(max_height);
2525
}
26-
if let Some(max_cells) = args.segment_max_cells {
26+
if let Some(max_memory) = args.segment_max_memory {
2727
config
2828
.app_vm_config
2929
.as_mut()
3030
.segmentation_config
3131
.limits
32-
.set_max_cells(max_cells);
32+
.set_max_memory(max_memory);
3333
}
3434

3535
let sdk = Sdk::new(config)?;

benchmarks/prove/src/util.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -66,9 +66,9 @@ pub struct BenchmarkCli {
6666
#[arg(long, alias = "max_segment_length")]
6767
pub max_segment_length: Option<u32>,
6868

69-
/// Total cells used in all chips in segment for continuations
69+
/// Total memory in bytes used in all chips in segment for continuations
7070
#[arg(long)]
71-
pub segment_max_cells: Option<usize>,
71+
pub segment_max_memory: Option<usize>,
7272

7373
/// Controls the arity (num_children) of the aggregation tree
7474
#[command(flatten)]
@@ -96,12 +96,12 @@ impl BenchmarkCli {
9696
.limits
9797
.set_max_trace_height(max_height);
9898
}
99-
if let Some(max_cells) = self.segment_max_cells {
99+
if let Some(max_memory) = self.segment_max_memory {
100100
app_vm_config
101101
.as_mut()
102102
.segmentation_config
103103
.limits
104-
.set_max_cells(max_cells);
104+
.set_max_memory(max_memory);
105105
}
106106
AppConfig {
107107
app_fri_params: FriParameters::standard_with_100_bits_conjectured_security(

crates/cli/src/commands/prove.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ use clap::Parser;
44
use eyre::Result;
55
use openvm_circuit::arch::{
66
execution_mode::metered::segment_ctx::{
7-
SegmentationConfig, SegmentationLimits, DEFAULT_MAX_CELLS, DEFAULT_MAX_TRACE_HEIGHT_BITS,
7+
SegmentationConfig, SegmentationLimits, DEFAULT_MAX_MEMORY, DEFAULT_MAX_TRACE_HEIGHT_BITS,
88
},
99
instructions::exe::VmExe,
1010
};
@@ -133,14 +133,14 @@ pub struct SegmentationArgs {
133133
help_heading = "OpenVM Options"
134134
)]
135135
pub segment_max_height_bits: u8,
136-
/// Total cells used across all chips for triggering segmentation for continuations in the app
137-
/// proof. These thresholds are not exceeded except when they are too small.
136+
/// Total memory in bytes used across all chips for triggering segmentation for continuations
137+
/// in the app proof. These thresholds are not exceeded except when they are too small.
138138
#[arg(
139139
long,
140-
default_value_t = DEFAULT_MAX_CELLS,
140+
default_value_t = DEFAULT_MAX_MEMORY,
141141
help_heading = "OpenVM Options"
142142
)]
143-
pub segment_max_cells: usize,
143+
pub segment_max_memory: usize,
144144
}
145145

146146
impl ProveCmd {
@@ -318,7 +318,7 @@ impl From<SegmentationArgs> for SegmentationConfig {
318318
1u32.checked_shl(args.segment_max_height_bits as u32)
319319
.expect("segment_max_height_bits too large"),
320320
)
321-
.with_max_cells(args.segment_max_cells),
321+
.with_max_memory(args.segment_max_memory),
322322
..Default::default()
323323
}
324324
}

crates/vm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ cfg-if.workspace = true
4646
libloading = { workspace = true, optional = true }
4747
tempfile = { workspace = true, optional = true }
4848
abi_stable.workspace = true
49+
bytesize.workspace = true
4950

5051
[build-dependencies]
5152
openvm-cuda-builder = { workspace = true, optional = true }

crates/vm/src/arch/execution_mode/metered/ctx.rs

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -103,8 +103,8 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
103103
self
104104
}
105105

106-
pub fn with_max_cells(mut self, max_cells: usize) -> Self {
107-
self.segmentation_ctx.set_max_cells(max_cells);
106+
pub fn with_max_memory(mut self, max_memory: usize) -> Self {
107+
self.segmentation_ctx.set_max_memory(max_memory);
108108
self
109109
}
110110

@@ -127,11 +127,21 @@ impl<const PAGE_BITS: usize> MeteredCtx<PAGE_BITS> {
127127
self
128128
}
129129

130+
pub fn with_main_cell_weight(mut self, weight: usize) -> Self {
131+
self.segmentation_ctx.set_main_cell_weight(weight);
132+
self
133+
}
134+
130135
pub fn with_interaction_cell_weight(mut self, weight: usize) -> Self {
131136
self.segmentation_ctx.set_interaction_cell_weight(weight);
132137
self
133138
}
134139

140+
pub fn with_base_field_size(mut self, base_field_size: usize) -> Self {
141+
self.segmentation_ctx.set_base_field_size(base_field_size);
142+
self
143+
}
144+
135145
pub fn segments(&self) -> &[Segment] {
136146
&self.segmentation_ctx.segments
137147
}

crates/vm/src/arch/execution_mode/metered/segment_ctx.rs

Lines changed: 94 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
use std::mem::size_of;
2+
3+
use bytesize::ByteSize;
14
use getset::{Setters, WithSetters};
25
use openvm_stark_backend::p3_field::PrimeField32;
36
use p3_baby_bear::BabyBear;
@@ -7,8 +10,10 @@ pub const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 1000;
710

811
pub const DEFAULT_MAX_TRACE_HEIGHT_BITS: u8 = 22;
912
pub const DEFAULT_MAX_TRACE_HEIGHT: u32 = 1 << DEFAULT_MAX_TRACE_HEIGHT_BITS;
10-
pub const DEFAULT_MAX_CELLS: usize = 1_200_000_000; // 1.2B
13+
pub const DEFAULT_MAX_MEMORY: usize = 15 << 30; // 15GiB
1114
const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize;
15+
const DEFAULT_MAIN_CELL_WEIGHT: usize = 3; // 1 + 2^{log_blowup=1}
16+
const DEFAULT_INTERACTION_CELL_WEIGHT: usize = 8; // 2 * D_EF
1217

1318
#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)]
1419
pub struct Segment {
@@ -17,19 +22,36 @@ pub struct Segment {
1722
pub trace_heights: Vec<u32>,
1823
}
1924

20-
#[derive(Clone, Debug, Default, WithSetters)]
25+
#[derive(Clone, Debug, WithSetters)]
2126
pub struct SegmentationConfig {
2227
pub limits: SegmentationLimits,
23-
/// Cells per row contributed by each interaction in cell count.
28+
/// Weight multiplier for main trace cells in memory calculation.
29+
#[getset(set_with = "pub")]
30+
pub main_cell_weight: usize,
31+
/// Weight multiplier for interaction cells in memory calculation.
2432
#[getset(set_with = "pub")]
2533
pub interaction_cell_weight: usize,
34+
/// Size of the base field in bytes. Used to convert cell count to memory bytes.
35+
#[getset(set_with = "pub")]
36+
pub base_field_size: usize,
37+
}
38+
39+
impl Default for SegmentationConfig {
40+
fn default() -> Self {
41+
Self {
42+
limits: SegmentationLimits::default(),
43+
main_cell_weight: DEFAULT_MAIN_CELL_WEIGHT,
44+
interaction_cell_weight: DEFAULT_INTERACTION_CELL_WEIGHT,
45+
base_field_size: size_of::<u32>(),
46+
}
47+
}
2648
}
2749

2850
#[derive(Clone, Debug, WithSetters, Setters)]
2951
pub struct SegmentationLimits {
3052
pub max_trace_height: u32,
3153
#[getset(set = "pub", set_with = "pub")]
32-
pub max_cells: usize,
54+
pub max_memory: usize,
3355
#[getset(set_with = "pub")]
3456
pub max_interactions: usize,
3557
}
@@ -38,21 +60,21 @@ impl Default for SegmentationLimits {
3860
fn default() -> Self {
3961
Self {
4062
max_trace_height: DEFAULT_MAX_TRACE_HEIGHT,
41-
max_cells: DEFAULT_MAX_CELLS,
63+
max_memory: DEFAULT_MAX_MEMORY,
4264
max_interactions: DEFAULT_MAX_INTERACTIONS,
4365
}
4466
}
4567
}
4668

4769
impl SegmentationLimits {
48-
pub fn new(max_trace_height: u32, max_cells: usize, max_interactions: usize) -> Self {
70+
pub fn new(max_trace_height: u32, max_memory: usize, max_interactions: usize) -> Self {
4971
debug_assert!(
5072
max_trace_height.is_power_of_two(),
5173
"max_trace_height should be a power of two"
5274
);
5375
Self {
5476
max_trace_height,
55-
max_cells,
77+
max_memory,
5678
max_interactions,
5779
}
5880
}
@@ -120,18 +142,26 @@ impl SegmentationCtx {
120142
self.config.limits.set_max_trace_height(max_trace_height);
121143
}
122144

123-
pub fn set_max_cells(&mut self, max_cells: usize) {
124-
self.config.limits.max_cells = max_cells;
145+
pub fn set_max_memory(&mut self, max_memory: usize) {
146+
self.config.limits.max_memory = max_memory;
125147
}
126148

127149
pub fn set_max_interactions(&mut self, max_interactions: usize) {
128150
self.config.limits.max_interactions = max_interactions;
129151
}
130152

153+
pub fn set_main_cell_weight(&mut self, weight: usize) {
154+
self.config.main_cell_weight = weight;
155+
}
156+
131157
pub fn set_interaction_cell_weight(&mut self, weight: usize) {
132158
self.config.interaction_cell_weight = weight;
133159
}
134160

161+
pub fn set_base_field_size(&mut self, base_field_size: usize) {
162+
self.config.base_field_size = base_field_size;
163+
}
164+
135165
/// Calculate the maximum trace height and corresponding air name
136166
#[inline(always)]
137167
fn calculate_max_trace_height_with_name(&self, trace_heights: &[u32]) -> (u32, &str) {
@@ -144,22 +174,44 @@ impl SegmentationCtx {
144174
.unwrap_or((0, "unknown"))
145175
}
146176

147-
/// Calculate the total cells used based on trace heights and widths,
148-
/// including weighted contribution from interactions if `interaction_cell_weight > 0`.
177+
/// Calculate total memory in bytes based on trace heights and widths.
178+
/// Formula: base_field_size * (main_cell_weight * main_cells + interaction_cell_weight *
179+
/// interaction_cells)
149180
#[inline(always)]
150-
fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize {
181+
fn calculate_total_memory(
182+
&self,
183+
trace_heights: &[u32],
184+
) -> (
185+
usize, /* memory */
186+
usize, /* main */
187+
usize, /* interaction */
188+
) {
151189
debug_assert_eq!(trace_heights.len(), self.widths.len());
152190

191+
let main_weight = self.config.main_cell_weight;
153192
let interaction_weight = self.config.interaction_cell_weight;
154-
trace_heights
193+
let base_field_size = self.config.base_field_size;
194+
195+
let mut main_cnt = 0;
196+
let mut interaction_cnt = 0;
197+
for ((&height, &width), &interactions) in trace_heights
155198
.iter()
156199
.zip(self.widths.iter())
157200
.zip(self.interactions.iter())
158-
.map(|((&height, &width), &interactions)| {
159-
let padded_height = height.next_power_of_two() as usize;
160-
padded_height * (width + interactions * interaction_weight)
161-
})
162-
.sum()
201+
{
202+
let padded_height = height.next_power_of_two() as usize;
203+
main_cnt += padded_height * width;
204+
interaction_cnt += padded_height * interactions;
205+
}
206+
207+
let main_memory = main_cnt * main_weight * base_field_size;
208+
let interaction_memory =
209+
(interaction_cnt + 1).next_power_of_two() * interaction_weight * base_field_size;
210+
(
211+
main_memory + interaction_memory,
212+
main_memory,
213+
interaction_memory,
214+
)
163215
}
164216

165217
/// Calculate the total interactions based on trace heights
@@ -199,8 +251,11 @@ impl SegmentationCtx {
199251
return false;
200252
}
201253

254+
let main_weight = self.config.main_cell_weight;
202255
let interaction_weight = self.config.interaction_cell_weight;
203-
let mut total_cells = 0;
256+
let base_field_size = self.config.base_field_size;
257+
let mut main_cnt = 0usize;
258+
let mut interaction_cnt = 0usize;
204259
for (i, (((padded_height, width), interactions), is_constant)) in trace_heights
205260
.iter()
206261
.map(|&height| height.next_power_of_two())
@@ -214,7 +269,7 @@ impl SegmentationCtx {
214269
if !is_constant && padded_height > self.config.limits.max_trace_height {
215270
let air_name = unsafe { self.air_names.get_unchecked(i) };
216271
tracing::info!(
217-
"instret {:10} | height ({:8}) > max ({:8}) | chip {:3} ({}) ",
272+
"overshoot: instret {:10} | height ({:8}) > max ({:8}) | chip {:3} ({}) ",
218273
instret,
219274
padded_height,
220275
self.config.limits.max_trace_height,
@@ -223,23 +278,30 @@ impl SegmentationCtx {
223278
);
224279
return true;
225280
}
226-
total_cells += padded_height as usize * (width + interactions * interaction_weight);
281+
main_cnt += padded_height as usize * width;
282+
interaction_cnt += padded_height as usize * interactions;
227283
}
284+
// interaction rounding to match n_logup calculation
285+
let total_memory = (main_cnt * main_weight
286+
+ (interaction_cnt + 1).next_power_of_two() * interaction_weight)
287+
* base_field_size;
228288

229-
if total_cells > self.config.limits.max_cells {
289+
if total_memory > self.config.limits.max_memory {
230290
tracing::info!(
231-
"instret {:10} | total cells ({:10}) > max ({:10})",
291+
"overshoot: instret {:10} | total memory ({:10}) > max ({:10}) | main ({:10}) | interaction ({:10})",
232292
instret,
233-
total_cells,
234-
self.config.limits.max_cells
293+
total_memory,
294+
self.config.limits.max_memory,
295+
main_cnt,
296+
interaction_cnt
235297
);
236298
return true;
237299
}
238300

239301
let total_interactions = self.calculate_total_interactions(trace_heights);
240302
if total_interactions > self.config.limits.max_interactions {
241303
tracing::info!(
242-
"instret {:10} | total interactions ({:10}) > max ({:10})",
304+
"overshoot: instret {:10} | total interactions ({:10}) > max ({:10})",
243305
instret,
244306
total_interactions,
245307
self.config.limits.max_interactions
@@ -403,18 +465,21 @@ impl SegmentationCtx {
403465
trace_heights: &[u32],
404466
) {
405467
let (max_trace_height, air_name) = self.calculate_max_trace_height_with_name(trace_heights);
406-
let total_cells = self.calculate_total_cells(trace_heights);
468+
let (total_memory, main_memory, interaction_memory) =
469+
self.calculate_total_memory(trace_heights);
407470
let total_interactions = self.calculate_total_interactions(trace_heights);
408471
let utilization = self.calculate_trace_utilization(trace_heights);
409472

410473
let final_marker = if IS_FINAL { " [TERMINATED]" } else { "" };
411474

412475
tracing::info!(
413-
"Segment {:3} | instret {:10} | {:8} instructions | {:10} cells | {:10} interactions | {:8} max height ({}) | {:.2}% utilization{}",
476+
"Segment {:3} | instret {:10} | {:8} instructions | {:5} memory ({:5}, {:5}) | {:10} interactions | {:8} max height ({}) | {:.2}% utilization{}",
414477
self.segments.len(),
415478
instret_start,
416479
num_insns,
417-
total_cells,
480+
ByteSize::b(total_memory as u64),
481+
ByteSize::b(main_memory as u64),
482+
ByteSize::b(interaction_memory as u64),
418483
total_interactions,
419484
max_trace_height,
420485
air_name,

0 commit comments

Comments
 (0)