Skip to content

Commit dc48faf

Browse files
authored
fix: Reduce size of FnBounds (#1592)
- Extract guest symbol data into a string table and store it in `guest.syms` to reduce the size of `metrics.json`
- Update `flamegraph.py` to load guest symbols from `guest.syms`
- Fix stack trace unwinding in `update_current_fn()`
- Update `openvm-prof` to use `mmap` for faster reading of large JSON files
1 parent d712b9d commit dc48faf

File tree

9 files changed

+98
-16
lines changed

9 files changed

+98
-16
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ rustc-*
2323
.bench_metrics/
2424
__pycache__/
2525
metrics.json
26+
guest.syms
2627

2728
# KZG trusted setup
2829
**/params

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ci/scripts/metric_unify/flamegraph.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,19 @@
66

77
from utils import FLAMEGRAPHS_DIR, get_git_root
88

9-
def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None):
9+
def get_function_symbol(string_table, offset_str):
    """
    Resolve one entry of a `cycle_tracker_span` label to a function name.

    Args:
        string_table: bytes object holding NUL-terminated names, addressed by
            byte offset (the raw contents of the guest symbols file).
        offset_str: one ';'-separated component of the label. Either the
            decimal byte offset of a name inside `string_table`, or an
            arbitrary span name that is not a number.

    Returns:
        The decoded name stored at that offset. If `offset_str` is not an
        integer it is returned unchanged. If the offset is negative, past the
        end of the table, or not followed by a NUL terminator, a warning is
        printed and `offset_str` is returned unchanged, so the result is
        always a string.
    """
    try:
        offset_int = int(offset_str)
    except ValueError:
        # Not a number: the entry is already a literal span name.
        return offset_str

    # Reject out-of-range offsets explicitly. bytes.find() accepts a negative
    # start and would silently search from the end of the table instead.
    if 0 <= offset_int < len(string_table):
        end = string_table.find(b'\0', offset_int)
    else:
        end = -1

    if end == -1:
        print(f"Invalid symbol offset: {offset_int}")
        # Fall back to the raw text rather than None: the caller extends its
        # list of frame names with these values, so each must be a string.
        return offset_str

    # errors='replace' keeps a corrupt table from aborting the whole run.
    return string_table[offset_int:end].decode(errors='replace')
19+
20+
21+
def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None, string_table=None):
1022
"""
1123
Filters a metrics_dict obtained from json for entries that look like:
1224
[ { labels: [["key1", "span1;span2"], ["key2", "span3"]], "metric": metric_name, "value": 2 } ]
@@ -40,7 +52,15 @@ def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_met
4052
if key not in labels:
4153
filter = True
4254
break
43-
stack_values.append(labels[key])
55+
if key == 'cycle_tracker_span':
56+
if labels[key] == '' or string_table is None:
57+
stack_values.append(labels[key])
58+
else:
59+
symbol_offsets = labels[key].split(';')
60+
function_symbols = [get_function_symbol(string_table, offset) for offset in symbol_offsets]
61+
stack_values.extend(function_symbols)
62+
else:
63+
stack_values.append(labels[key])
4464
if filter:
4565
continue
4666

@@ -57,8 +77,8 @@ def get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_met
5777
return lines if non_zero else []
5878

5979

60-
def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None, reverse=False):
61-
lines = get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics)
80+
def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics=None, reverse=False, string_table=None):
81+
lines = get_stack_lines(metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics, string_table)
6282
if not lines:
6383
return
6484

@@ -86,7 +106,7 @@ def create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name
86106
print(f"Created flamegraph at {flamegraph_path}")
87107

88108

89-
def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, sum_metrics=None, reverse=False):
109+
def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, sum_metrics=None, reverse=False, string_table=None):
90110
fname_prefix = os.path.splitext(os.path.basename(metrics_file))[0]
91111

92112
with open(metrics_file, 'r') as f:
@@ -104,18 +124,18 @@ def create_flamegraphs(metrics_file, group_by, stack_keys, metric_name, sum_metr
104124
for group_by_values in group_by_values_list:
105125
group_by_kvs = list(zip(group_by, group_by_values))
106126
fname = fname_prefix + '-' + '-'.join(group_by_values)
107-
create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics, reverse=reverse)
127+
create_flamegraph(fname, metrics_dict, group_by_kvs, stack_keys, metric_name, sum_metrics, reverse=reverse, string_table=string_table)
108128

109129

110-
def create_custom_flamegraphs(metrics_file, group_by=["group"]):
130+
def create_custom_flamegraphs(metrics_file, group_by=["group"], string_table=None):
111131
for reverse in [False, True]:
112132
create_flamegraphs(metrics_file, group_by, ["cycle_tracker_span", "dsl_ir", "opcode"], "frequency",
113-
reverse=reverse)
133+
reverse=reverse, string_table=string_table)
114134
create_flamegraphs(metrics_file, group_by, ["cycle_tracker_span", "dsl_ir", "opcode", "air_name"], "cells_used",
115-
reverse=reverse)
135+
reverse=reverse, string_table=string_table)
116136
create_flamegraphs(metrics_file, group_by, ["cell_tracker_span"], "cells_used",
117137
sum_metrics=["simple_advice_cells", "fixed_cells", "lookup_advice_cells"],
118-
reverse=reverse)
138+
reverse=reverse, string_table=string_table)
119139

120140

121141
def main():
@@ -127,9 +147,16 @@ def main():
127147

128148
argparser = argparse.ArgumentParser()
129149
argparser.add_argument('metrics_json', type=str, help="Path to the metrics JSON")
150+
argparser.add_argument('--guest-symbols', type=str, help="Path to the guest symbols file", default=None, required=False)
130151
args = argparser.parse_args()
131152

132-
create_custom_flamegraphs(args.metrics_json)
153+
if args.guest_symbols:
154+
with open(args.guest_symbols, 'rb') as f:
155+
string_table = f.read()
156+
else:
157+
string_table = None
158+
159+
create_custom_flamegraphs(args.metrics_json, string_table=string_table)
133160

134161

135162
if __name__ == '__main__':

crates/prof/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ clap = { workspace = true, features = ["derive"] }
1616
eyre.workspace = true
1717
itertools = { workspace = true, features = ["use_std"] }
1818
num-format = "0.4"
19+
memmap2 = "0.9"

crates/prof/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use aggregate::{
44
EXECUTE_TIME_LABEL, PROOF_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, TRACE_GEN_TIME_LABEL,
55
};
66
use eyre::Result;
7+
use memmap2::Mmap;
78

89
use crate::types::{Labels, Metric, MetricDb, MetricsFile};
910

@@ -14,7 +15,8 @@ pub mod types;
1415
impl MetricDb {
1516
pub fn new(metrics_file: impl AsRef<Path>) -> Result<Self> {
1617
let file = File::open(metrics_file)?;
17-
let metrics: MetricsFile = serde_json::from_reader(file)?;
18+
let mmap = unsafe { Mmap::map(&file)? };
19+
let metrics: MetricsFile = serde_json::from_slice(&mmap)?;
1820

1921
let mut db = MetricDb::default();
2022

crates/toolchain/transpiler/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ eyre.workspace = true
1515
thiserror.workspace = true
1616
elf = "0.7.4"
1717
rrs-lib.workspace = true
18+
rustc-demangle = { version = "0.1.24", optional = true }
1819

1920
[features]
20-
function-span = []
21+
function-span = ["dep:rustc-demangle"]

crates/toolchain/transpiler/src/elf.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
// Initial version taken from https://github.com/succinctlabs/sp1/blob/v2.0.0/crates/core/executor/src/disassembler/elf.rs under MIT License
22
// and https://github.com/risc0/risc0/blob/f61379bf69b24d56e49d6af96a3b284961dcc498/risc0/binfmt/src/elf.rs#L34 under Apache License
33
use std::{cmp::min, collections::BTreeMap, fmt::Debug};
4+
#[cfg(feature = "function-span")]
5+
use std::{
6+
collections::{hash_map::Entry, HashMap},
7+
io::Write,
8+
};
49

510
use elf::{
611
abi::{EM_RISCV, ET_EXEC, PF_X, PT_LOAD},
@@ -93,18 +98,48 @@ impl Elf {
9398
#[cfg(feature = "function-span")]
9499
{
95100
if let Some((symtab, stringtab)) = elf.symbol_table()? {
101+
let mut fn_names = Vec::new();
102+
for symbol in symtab.iter() {
103+
if symbol.st_symtype() == elf::abi::STT_FUNC {
104+
let raw_name = stringtab.get(symbol.st_name as usize).unwrap().to_string();
105+
let demangled_name = rustc_demangle::demangle(&raw_name).to_string();
106+
fn_names.push((demangled_name, symbol.st_name));
107+
}
108+
}
109+
110+
let mut buf = Vec::new();
111+
let mut offsets = HashMap::new();
112+
buf.push(0);
113+
for (name, st_name) in fn_names {
114+
if let Entry::Vacant(e) = offsets.entry(st_name) {
115+
let offset = buf.len();
116+
e.insert(offset);
117+
buf.extend_from_slice(name.as_bytes());
118+
buf.push(0);
119+
}
120+
}
121+
96122
for symbol in symtab.iter() {
97123
if symbol.st_symtype() == elf::abi::STT_FUNC {
98124
fn_bounds.insert(
99125
symbol.st_value as u32,
100126
FnBound {
101127
start: symbol.st_value as u32,
102128
end: (symbol.st_value + symbol.st_size - (WORD_SIZE as u64)) as u32,
103-
name: stringtab.get(symbol.st_name as usize).unwrap().to_string(),
129+
name: offsets[&symbol.st_name].to_string(),
104130
},
105131
);
106132
}
107133
}
134+
135+
let guest_symbols_path = std::env::var("GUEST_SYMBOLS_PATH")?;
136+
let mut guest_symbols_file =
137+
std::fs::File::create(&guest_symbols_path).map_err(|e| {
138+
eyre::eyre!(
139+
"Failed to create guest symbols file at {guest_symbols_path}: {e}"
140+
)
141+
})?;
142+
guest_symbols_file.write_all(buf.as_slice())?;
108143
} else {
109144
println!("No symbol table found");
110145
}

crates/vm/src/metrics/cycle_tracker/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ impl CycleTracker {
99
Self::default()
1010
}
1111

12+
/// Returns a reference to the last element of `self.stack`, or `None` when the stack is
/// empty.
///
/// NOTE(review): the stack appears to hold the names of the currently open spans, with
/// the most recently started one last — confirm against `start`/`force_end`.
pub fn top(&self) -> Option<&String> {
13+
self.stack.last()
14+
}
15+
1216
/// Starts a new cycle tracker span for the given name.
1317
/// If a span already exists for the given name, it ends the existing span and pushes a new one
1418
/// to the vec.

crates/vm/src/metrics/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,10 @@ impl VmMetrics {
106106

107107
#[cfg(feature = "function-span")]
108108
fn update_current_fn(&mut self, pc: u32) {
109-
if !self.fn_bounds.is_empty() && (pc < self.current_fn.start || pc > self.current_fn.end) {
109+
if self.fn_bounds.is_empty() {
110+
return;
111+
}
112+
if pc < self.current_fn.start || pc > self.current_fn.end {
110113
self.current_fn = self
111114
.fn_bounds
112115
.range(..=pc)
@@ -116,10 +119,16 @@ impl VmMetrics {
116119
if pc == self.current_fn.start {
117120
self.cycle_tracker.start(self.current_fn.name.clone());
118121
} else {
119-
self.cycle_tracker.force_end();
122+
while let Some(name) = self.cycle_tracker.top() {
123+
if name == &self.current_fn.name {
124+
break;
125+
}
126+
self.cycle_tracker.force_end();
127+
}
120128
}
121129
};
122130
}
131+
123132
pub fn emit(&self) {
124133
for (name, value) in self.chip_heights.iter() {
125134
let labels = [("chip_name", name.clone())];

0 commit comments

Comments
 (0)