Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 36 additions & 36 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,40 +114,40 @@ iter_over_hash_type = "deny"

# Uncomment both patches below for local stark-backend and openvm.
# The local openvm also needs to have stark-backend patched so all types match.
# [patch."https://github.com/powdr-labs/stark-backend.git"]
# openvm-stark-sdk = { path = "../stark-backend/crates/stark-sdk", default-features = false }
# openvm-stark-backend = { path = "../stark-backend/crates/stark-backend", default-features = false }
# openvm-cuda-backend = { path = "../stark-backend/crates/cuda-backend", default-features = false }
# openvm-cuda-builder = { path = "../stark-backend/crates/cuda-builder", default-features = false }
# openvm-cuda-common = { path = "../stark-backend/crates/cuda-common", default-features = false }
[patch."https://github.com/powdr-labs/stark-backend.git"]
openvm-stark-sdk = { path = "../stark-backend/crates/stark-sdk", default-features = false }
openvm-stark-backend = { path = "../stark-backend/crates/stark-backend", default-features = false }
openvm-cuda-backend = { path = "../stark-backend/crates/cuda-backend", default-features = false }
openvm-cuda-builder = { path = "../stark-backend/crates/cuda-builder", default-features = false }
openvm-cuda-common = { path = "../stark-backend/crates/cuda-common", default-features = false }

# [patch."https://github.com/powdr-labs/openvm.git"]
# openvm = { path = "../openvm/crates/toolchain/openvm" }
# openvm-build = { path = "../openvm/crates/toolchain/build" }
# openvm-rv32im-circuit = { path = "../openvm/extensions/rv32im/circuit/" }
# openvm-rv32im-transpiler = { path = "../openvm/extensions/rv32im/transpiler" }
# openvm-rv32im-guest = { path = "../openvm/extensions/rv32im/guest" }
# openvm-transpiler = { path = "../openvm/crates/toolchain/transpiler" }
# openvm-circuit = { path = "../openvm/crates/vm" }
# openvm-circuit-derive = { path = "../openvm/crates/vm/derive" }
# openvm-circuit-primitives = { path = "../openvm/crates/circuits/primitives" }
# openvm-circuit-primitives-derive = { path = "../openvm/crates/circuits/primitives/derive" }
# openvm-instructions = { path = "../openvm/crates/toolchain/instructions" }
# openvm-instructions-derive = { path = "../openvm/crates/toolchain/instructions/derive" }
# openvm-sdk = { path = "../openvm/crates/sdk" }
# openvm-ecc-circuit = { path = "../openvm/extensions/ecc/circuit" }
# openvm-ecc-transpiler = { path = "../openvm/extensions/ecc/transpiler" }
# openvm-keccak256-circuit = { path = "../openvm/extensions/keccak256/circuit" }
# openvm-keccak256-transpiler = { path = "../openvm/extensions/keccak256/transpiler" }
# openvm-sha256-circuit = { path = "../openvm/extensions/sha256/circuit" }
# openvm-sha256-transpiler = { path = "../openvm/extensions/sha256/transpiler" }
# openvm-algebra-circuit = { path = "../openvm/extensions/algebra/circuit" }
# openvm-algebra-transpiler = { path = "../openvm/extensions/algebra/transpiler" }
# openvm-bigint-circuit = { path = "../openvm/extensions/bigint/circuit" }
# openvm-bigint-transpiler = { path = "../openvm/extensions/bigint/transpiler" }
# openvm-pairing-circuit = { path = "../openvm/extensions/pairing/circuit" }
# openvm-pairing-transpiler = { path = "../openvm/extensions/pairing/transpiler" }
# openvm-native-circuit = { path = "../openvm/extensions/native/circuit" }
# openvm-native-recursion = { path = "../openvm/extensions/native/recursion" }
# openvm-platform = { path = "../openvm/crates/toolchain/platform" }
# openvm-custom-insn = { path = "../openvm/crates/toolchain/custom_insn" }
[patch."https://github.com/powdr-labs/openvm.git"]
openvm = { path = "../openvm/crates/toolchain/openvm" }
openvm-build = { path = "../openvm/crates/toolchain/build" }
openvm-rv32im-circuit = { path = "../openvm/extensions/rv32im/circuit/" }
openvm-rv32im-transpiler = { path = "../openvm/extensions/rv32im/transpiler" }
openvm-rv32im-guest = { path = "../openvm/extensions/rv32im/guest" }
openvm-transpiler = { path = "../openvm/crates/toolchain/transpiler" }
openvm-circuit = { path = "../openvm/crates/vm" }
openvm-circuit-derive = { path = "../openvm/crates/vm/derive" }
openvm-circuit-primitives = { path = "../openvm/crates/circuits/primitives" }
openvm-circuit-primitives-derive = { path = "../openvm/crates/circuits/primitives/derive" }
openvm-instructions = { path = "../openvm/crates/toolchain/instructions" }
openvm-instructions-derive = { path = "../openvm/crates/toolchain/instructions/derive" }
openvm-sdk = { path = "../openvm/crates/sdk" }
openvm-ecc-circuit = { path = "../openvm/extensions/ecc/circuit" }
openvm-ecc-transpiler = { path = "../openvm/extensions/ecc/transpiler" }
openvm-keccak256-circuit = { path = "../openvm/extensions/keccak256/circuit" }
openvm-keccak256-transpiler = { path = "../openvm/extensions/keccak256/transpiler" }
openvm-sha256-circuit = { path = "../openvm/extensions/sha256/circuit" }
openvm-sha256-transpiler = { path = "../openvm/extensions/sha256/transpiler" }
openvm-algebra-circuit = { path = "../openvm/extensions/algebra/circuit" }
openvm-algebra-transpiler = { path = "../openvm/extensions/algebra/transpiler" }
openvm-bigint-circuit = { path = "../openvm/extensions/bigint/circuit" }
openvm-bigint-transpiler = { path = "../openvm/extensions/bigint/transpiler" }
openvm-pairing-circuit = { path = "../openvm/extensions/pairing/circuit" }
openvm-pairing-transpiler = { path = "../openvm/extensions/pairing/transpiler" }
openvm-native-circuit = { path = "../openvm/extensions/native/circuit" }
openvm-native-recursion = { path = "../openvm/extensions/native/recursion" }
openvm-platform = { path = "../openvm/crates/toolchain/platform" }
openvm-custom-insn = { path = "../openvm/crates/toolchain/custom_insn" }
24 changes: 16 additions & 8 deletions autoprecompiles/src/trace_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@ use std::{cmp::Eq, hash::Hash};
use crate::expression::{AlgebraicExpression, AlgebraicReference};
use crate::{Apc, InstructionHandler};

pub struct OriginalRowReference<'a, D> {
pub struct OriginalRowReference<'a, D, I> {
pub air_id: &'a I,
pub row_index: usize,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this because when an apc row fails tracegen, we need to know which original rows to add to the software tables.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original rows aka rejected rows of the APC dummy traces?

Software tables aka non-APC traces?

pub data: &'a D,
pub start: usize,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this because it can be derived from the rest

pub length: usize,
}

pub struct TraceData<'a, F, D> {
impl<'a, D, I> OriginalRowReference<'a, D, I> {
pub fn start(&self) -> usize {
self.row_index * self.length
}
}

pub struct TraceData<'a, F, D, I> {
/// For each call of the apc, the values of each original instruction's dummy trace.
pub dummy_values: Vec<Vec<OriginalRowReference<'a, D>>>,
pub dummy_values: Vec<Vec<OriginalRowReference<'a, D, I>>>,
/// The mapping from dummy trace index to APC index for each instruction.
pub dummy_trace_index_to_apc_index_by_instruction: Vec<Vec<(usize, usize)>>,
/// The mapping from poly_id to the index in the list of apc columns.
Expand All @@ -43,7 +50,7 @@ pub fn generate_trace<'a, IH, M: TraceTrait<IH::Field>>(
instruction_handler: &'a IH,
apc_call_count: usize,
apc: &'a Apc<IH::Field, IH::Instruction>,
) -> TraceData<'a, IH::Field, M::Values>
) -> TraceData<'a, IH::Field, M::Values, IH::AirId>
where
IH: InstructionHandler,
IH::Field: Display + Clone + Send + Sync,
Expand Down Expand Up @@ -104,15 +111,16 @@ where
.iter()
.zip_eq(original_instruction_table_offsets.iter())
.map(|(air_id, dummy_table_offset)| {
let trace = air_id_to_dummy_trace.get(air_id).unwrap();
let (air_id, trace) = air_id_to_dummy_trace.get_key_value(air_id).unwrap();
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting case here, air_id has the same value, but a different lifetime! We need the one with the longer lifetime.

let values = trace.values();
let width = trace.width();
let occurrences_per_record = air_id_occurrences.get(air_id).unwrap();
let start = (trace_row * occurrences_per_record + dummy_table_offset) * width;
let row_index = trace_row * occurrences_per_record + dummy_table_offset;
OriginalRowReference {
data: values,
start,
length: width,
air_id,
row_index,
}
})
.collect_vec()
Expand Down
110 changes: 82 additions & 28 deletions openvm/src/powdr_extension/trace_generator/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ use openvm_circuit::{
use openvm_stark_backend::{
p3_field::FieldAlgebra,
p3_matrix::dense::{DenseMatrix, RowMajorMatrix},
prover::{hal::ProverBackend, types::AirProvingContext},
prover::{
hal::ProverBackend,
types::{AirProvingContext, AirProvingContexts},
},
Chip,
};
use openvm_stark_sdk::p3_baby_bear::BabyBear;
Expand Down Expand Up @@ -59,14 +62,26 @@ impl<F> From<Arc<RowMajorMatrix<F>>> for SharedCpuTrace<F> {
}

impl<R, PB: ProverBackend<Matrix = Arc<RowMajorMatrix<BabyBear>>>> Chip<R, PB> for PowdrChipCpu {
fn generate_proving_ctx(&self, _: R) -> AirProvingContext<PB> {
fn generate_proving_ctx(&self, records: R) -> AirProvingContext<PB> {
unreachable!()
}

fn generate_proving_ctxs(&self, _: R) -> AirProvingContexts<PB> {
tracing::trace!("Generating air proof input for PowdrChip {}", self.name);

let trace = self
let (trace, rejected) = self
.trace_generator
.generate_witness(self.record_arena_by_air_name.take());

AirProvingContext::simple(Arc::new(trace), vec![])
let rejected = rejected
.into_iter()
.map(|(key, (rows, values))| (key, (rows, AirProvingContext::simple(values, vec![]))))
.collect();

AirProvingContexts {
main: AirProvingContext::simple(Arc::new(trace), vec![]),
rejected,
}
}
}

Expand Down Expand Up @@ -95,14 +110,17 @@ impl PowdrTraceGeneratorCpu {
pub fn generate_witness(
&self,
mut original_arenas: OriginalArenas<MatrixRecordArena<BabyBear>>,
) -> DenseMatrix<BabyBear> {
) -> (
DenseMatrix<BabyBear>,
HashMap<String, (Vec<usize>, Arc<DenseMatrix<BabyBear>>)>,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For each air name, a trace and the rows of that trace which were rejected

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice :)

) {
use powdr_autoprecompiles::trace_handler::{generate_trace, TraceData};

let num_apc_calls = original_arenas.number_of_calls();
if num_apc_calls == 0 {
// If the APC isn't called, early return with an empty trace.
let width = self.apc.machine().main_columns().count();
return RowMajorMatrix::new(vec![], width);
return (RowMajorMatrix::new(vec![], width), HashMap::default());
}

let chip_inventory = {
Expand All @@ -119,7 +137,7 @@ impl PowdrTraceGeneratorCpu {
.inventory
};

let dummy_trace_by_air_name: HashMap<String, SharedCpuTrace<BabyBear>> = chip_inventory
let mut dummy_trace_by_air_name: HashMap<String, SharedCpuTrace<BabyBear>> = chip_inventory
.chips()
.iter()
.enumerate()
Expand Down Expand Up @@ -157,6 +175,12 @@ impl PowdrTraceGeneratorCpu {
let height = next_power_of_two_or_zero(num_apc_calls);
let mut values = <BabyBear as FieldAlgebra>::zero_vec(height * width);

let mut rejected: HashMap<String, Vec<usize>> = dummy_trace_by_air_name
.keys()
.cloned()
.map(|key| (key, vec![]))
.collect();

// go through the final table and fill in the values
values
// a record is `width` values
Expand All @@ -169,7 +193,7 @@ impl PowdrTraceGeneratorCpu {
use powdr_autoprecompiles::expression::MappingRowEvaluator;
for (dummy_row, dummy_trace_index_to_apc_index) in dummy_values
.iter()
.map(|r| &r.data[r.start..r.start + r.length])
.map(|r| &r.data[r.start()..r.start() + r.length])
Comment on lines -172 to +198
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Length is the width of the original AIR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

.zip_eq(&dummy_trace_index_to_apc_index_by_instruction)
{
for (dummy_trace_index, apc_index) in dummy_trace_index_to_apc_index {
Expand Down Expand Up @@ -200,27 +224,57 @@ impl PowdrTraceGeneratorCpu {

let evaluator = MappingRowEvaluator::new(row_slice, &apc_poly_id_to_index);

// replay the side effects of this row on the main periphery
self.apc
.machine()
.bus_interactions
.iter()
.for_each(|interaction| {
use powdr_autoprecompiles::expression::{
AlgebraicEvaluator, ConcreteBusInteraction,
};

let ConcreteBusInteraction { id, mult, args } =
evaluator.eval_bus_interaction(interaction);
self.periphery.real.apply(
id as u16,
mult.as_canonical_u32(),
args.map(|arg| arg.as_canonical_u32()),
&self.periphery.bus_ids,
);
});
// check the constraints and bus interactions
let row_is_valid = true;

if row_is_valid {
// replay the side effects of this row on the main periphery
self.apc
.machine()
.bus_interactions
.iter()
.for_each(|interaction| {
use powdr_autoprecompiles::expression::{
AlgebraicEvaluator, ConcreteBusInteraction,
};

let ConcreteBusInteraction { id, mult, args } =
evaluator.eval_bus_interaction(interaction);
self.periphery.real.apply(
id as u16,
mult.as_canonical_u32(),
args.map(|arg| arg.as_canonical_u32()),
&self.periphery.bus_ids,
);
});
} else {
// for each original row
for original_row_reference in dummy_values {
// TODO replay the side effects of this row on the real periphery

// add the row index to the rejected set
rejected
.get_mut(original_row_reference.air_id)
.unwrap()
.push(original_row_reference.row_index);
}
}
});

RowMajorMatrix::new(values, width)
// merge the rejected indices with the traces
let rejected = rejected
.into_iter()
.map(|(name, indices)| {
(
name.clone(),
(
indices,
dummy_trace_by_air_name.remove(&name).unwrap().matrix,
),
)
})
.collect();

(RowMajorMatrix::new(values, width), rejected)
}
}
Loading