Skip to content

Commit f7fd1d7

Browse files
committed
feat: tco for other extensions
1 parent 0c93530 commit f7fd1d7

File tree

31 files changed

+683
-27
lines changed

31 files changed

+683
-27
lines changed

extensions/algebra/circuit/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ edition.workspace = true
77
homepage.workspace = true
88
repository.workspace = true
99

10+
[features]
11+
default = ["jemalloc", "tco"]
12+
tco = ["openvm-circuit/tco"]
13+
jemalloc = ["openvm-circuit/jemalloc"]
14+
1015
[dependencies]
1116
openvm-circuit-primitives = { workspace = true }
1217
openvm-circuit-primitives-derive = { workspace = true }

extensions/algebra/circuit/src/execution.rs

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@ use std::{
55

66
use num_bigint::BigUint;
77
use openvm_algebra_transpiler::{Fp2Opcode, Rv32ModularArithmeticOpcode};
8-
use openvm_circuit::{
9-
arch::*,
10-
system::memory::{online::GuestMemory, POINTER_MAX_BITS},
11-
};
8+
use openvm_circuit::arch::*;
129
use openvm_circuit_primitives::AlignedBytesBorrow;
1310
use openvm_instructions::{
1411
instruction::Instruction,
@@ -177,6 +174,94 @@ impl<'a, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool>
177174
impl<F: PrimeField32, const BLOCKS: usize, const BLOCK_SIZE: usize, const IS_FP2: bool> Executor<F>
178175
for FieldExprVecHeapExecutor<BLOCKS, BLOCK_SIZE, IS_FP2>
179176
{
177+
/// Selects the execution handler used when the `tco` feature is enabled.
///
/// Borrows `data` as a `FieldExpressionPreCompute`, fills it through
/// `pre_compute_impl`, and then picks a handler function:
/// - if `pre_compute_impl` yields an operation and the modulus maps to a known
///   field type, a field/operation-specific `execute_e1_tco_handler` chosen by
///   the dispatch macro from the listed `(field, op)` pairs;
/// - if it yields an operation but the modulus is not recognized,
///   `execute_e1_generic_tco_handler`;
/// - if it yields `None`, `execute_e1_setup_tco_handler`.
///
/// # Errors
/// Forwards any `StaticProgramError` returned by `pre_compute_impl`.
#[cfg(feature = "tco")]
178+
fn handler<Ctx>(
179+
&self,
180+
pc: u32,
181+
inst: &Instruction<F>,
182+
data: &mut [u8],
183+
) -> Result<Handler<F, Ctx>, StaticProgramError>
184+
where
185+
Ctx: ExecutionCtxTrait,
186+
{
187+
let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut();
188+
189+
// Fill the pre-compute record; `?` forwards a failure to the caller.
let op = self.pre_compute_impl(pc, inst, pre_compute)?;
190+
191+
if let Some(op) = op {
192+
let modulus = &pre_compute.expr.prime;
193+
// Fp2 executors look the modulus up with `get_fp2_field_type`; all others use `get_field_type`.
if IS_FP2 {
194+
if let Some(field_type) = get_fp2_field_type(modulus) {
195+
generate_fp2_dispatch!(
196+
field_type,
197+
op,
198+
BLOCKS,
199+
BLOCK_SIZE,
200+
execute_e1_tco_handler,
201+
[
202+
(BN254Coordinate, Add),
203+
(BN254Coordinate, Sub),
204+
(BN254Coordinate, Mul),
205+
(BN254Coordinate, Div),
206+
(BLS12_381Coordinate, Add),
207+
(BLS12_381Coordinate, Sub),
208+
(BLS12_381Coordinate, Mul),
209+
(BLS12_381Coordinate, Div),
210+
]
211+
)
212+
} else {
213+
// `get_fp2_field_type` did not recognize the modulus: use the generic handler.
Ok(execute_e1_generic_tco_handler::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
214+
}
215+
} else if let Some(field_type) = get_field_type(modulus) {
216+
generate_field_dispatch!(
217+
field_type,
218+
op,
219+
BLOCKS,
220+
BLOCK_SIZE,
221+
execute_e1_tco_handler,
222+
[
223+
(K256Coordinate, Add),
224+
(K256Coordinate, Sub),
225+
(K256Coordinate, Mul),
226+
(K256Coordinate, Div),
227+
(K256Scalar, Add),
228+
(K256Scalar, Sub),
229+
(K256Scalar, Mul),
230+
(K256Scalar, Div),
231+
(P256Coordinate, Add),
232+
(P256Coordinate, Sub),
233+
(P256Coordinate, Mul),
234+
(P256Coordinate, Div),
235+
(P256Scalar, Add),
236+
(P256Scalar, Sub),
237+
(P256Scalar, Mul),
238+
(P256Scalar, Div),
239+
(BN254Coordinate, Add),
240+
(BN254Coordinate, Sub),
241+
(BN254Coordinate, Mul),
242+
(BN254Coordinate, Div),
243+
(BN254Scalar, Add),
244+
(BN254Scalar, Sub),
245+
(BN254Scalar, Mul),
246+
(BN254Scalar, Div),
247+
(BLS12_381Coordinate, Add),
248+
(BLS12_381Coordinate, Sub),
249+
(BLS12_381Coordinate, Mul),
250+
(BLS12_381Coordinate, Div),
251+
(BLS12_381Scalar, Add),
252+
(BLS12_381Scalar, Sub),
253+
(BLS12_381Scalar, Mul),
254+
(BLS12_381Scalar, Div),
255+
]
256+
)
257+
} else {
258+
// `get_field_type` did not recognize the modulus: use the generic handler.
Ok(execute_e1_generic_tco_handler::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
259+
}
260+
} else {
261+
// `pre_compute_impl` returned `None`: use the setup handler.
Ok(execute_e1_setup_tco_handler::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>)
262+
}
263+
}
264+
180265
#[inline(always)]
181266
fn pre_compute_size(&self) -> usize {
182267
std::mem::size_of::<FieldExpressionPreCompute>()
@@ -527,6 +612,7 @@ unsafe fn execute_e2_setup_impl<
527612
execute_e12_setup_impl::<_, _, BLOCKS, BLOCK_SIZE, IS_FP2>(&pre_compute.data, vm_state);
528613
}
529614

615+
#[create_tco_handler]
530616
unsafe fn execute_e1_impl<
531617
F: PrimeField32,
532618
CTX: ExecutionCtxTrait,

extensions/algebra/circuit/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#![cfg_attr(feature = "tco", allow(incomplete_features))]
2+
#![cfg_attr(feature = "tco", feature(likely_unlikely))]
3+
#![cfg_attr(feature = "tco", feature(explicit_tail_calls))]
4+
15
use derive_more::derive::{Deref, DerefMut};
26
use openvm_circuit_derive::PreflightExecutor;
37
use openvm_mod_circuit_builder::FieldExpressionExecutor;

extensions/algebra/circuit/src/modular_chip/is_eq.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@ use num_bigint::BigUint;
77
use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode;
88
use openvm_circuit::{
99
arch::*,
10-
system::memory::{
11-
online::{GuestMemory, TracingMemory},
12-
MemoryAuxColsFactory, POINTER_MAX_BITS,
13-
},
10+
system::memory::{online::TracingMemory, MemoryAuxColsFactory},
1411
};
1512
use openvm_circuit_primitives::{
1613
bigint::utils::big_uint_to_limbs,
@@ -550,6 +547,34 @@ where
550547

551548
Ok(fn_ptr)
552549
}
550+
551+
/// Selects the execution handler used when the `tco` feature is enabled.
///
/// Borrows `data` as a `ModularIsEqualPreCompute<TOTAL_READ_SIZE>`, fills it
/// through `pre_compute_impl`, and returns the `execute_e1_tco_handler`
/// instantiation whose final const argument equals the `is_setup` flag that
/// `pre_compute_impl` reported.
///
/// # Errors
/// Forwards any `StaticProgramError` returned by `pre_compute_impl`.
// NOTE(review): the other executors touched by this change return
// `Handler<F, Ctx>` here; presumably this is the same type written out in
// full - confirm before switching to the alias.
#[cfg(feature = "tco")]
552+
fn handler<Ctx>(
553+
&self,
554+
pc: u32,
555+
inst: &Instruction<F>,
556+
data: &mut [u8],
557+
) -> Result<
558+
for<'a, 'b, 'c> unsafe fn(
559+
&'a InterpretedInstance<'b, F, Ctx>,
560+
&'c mut VmExecState<F, GuestMemory, Ctx>,
561+
) -> Result<(), ExecutionError>,
562+
StaticProgramError,
563+
>
564+
where
565+
Ctx: ExecutionCtxTrait,
566+
{
567+
let pre_compute: &mut ModularIsEqualPreCompute<TOTAL_READ_SIZE> = data.borrow_mut();
568+
569+
// Fill the pre-compute record; `?` forwards a failure to the caller.
let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?;
570+
let fn_ptr = if is_setup {
571+
execute_e1_tco_handler::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true>
572+
} else {
573+
execute_e1_tco_handler::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false>
574+
};
575+
576+
Ok(fn_ptr)
577+
}
553578
}
554579

555580
impl<F, const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_READ_SIZE: usize>
@@ -584,6 +609,7 @@ where
584609
}
585610
}
586611

612+
#[create_tco_handler]
587613
unsafe fn execute_e1_impl<
588614
F: PrimeField32,
589615
CTX: ExecutionCtxTrait,

extensions/bigint/circuit/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ test-case.workspace = true
3333
alloy-primitives = { version = "1.2.1" }
3434

3535
[features]
36-
default = ["parallel", "jemalloc"]
36+
default = ["parallel", "jemalloc", "tco"]
3737
parallel = ["openvm-circuit/parallel"]
3838
test-utils = ["openvm-circuit/test-utils"]
39+
tco = ["openvm-circuit/tco"]
3940
# performance features:
4041
mimalloc = ["openvm-circuit/mimalloc"]
4142
jemalloc = ["openvm-circuit/jemalloc"]

extensions/bigint/circuit/src/base_alu.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,28 @@ impl<F: PrimeField32> Executor<F> for Rv32BaseAlu256Executor {
5959
};
6060
Ok(fn_ptr)
6161
}
62+
63+
/// Selects the execution handler used when the `tco` feature is enabled.
///
/// Borrows `data` as a `BaseAluPreCompute`, fills it through
/// `pre_compute_impl`, and returns the `execute_e1_tco_handler` instantiation
/// for the `BaseAluOpcode` that `pre_compute_impl` reported
/// (ADD, SUB, XOR, OR or AND).
///
/// # Errors
/// Forwards any `StaticProgramError` returned by `pre_compute_impl`.
#[cfg(feature = "tco")]
64+
fn handler<Ctx>(
65+
&self,
66+
pc: u32,
67+
inst: &Instruction<F>,
68+
data: &mut [u8],
69+
) -> Result<Handler<F, Ctx>, StaticProgramError>
70+
where
71+
Ctx: ExecutionCtxTrait,
72+
{
73+
// Shadows the raw byte buffer with its typed pre-compute view.
let data: &mut BaseAluPreCompute = data.borrow_mut();
74+
let local_opcode = self.pre_compute_impl(pc, inst, data)?;
75+
// One handler instantiation per opcode, selected by the `OP` type argument.
let fn_ptr = match local_opcode {
76+
BaseAluOpcode::ADD => execute_e1_tco_handler::<_, _, AddOp>,
77+
BaseAluOpcode::SUB => execute_e1_tco_handler::<_, _, SubOp>,
78+
BaseAluOpcode::XOR => execute_e1_tco_handler::<_, _, XorOp>,
79+
BaseAluOpcode::OR => execute_e1_tco_handler::<_, _, OrOp>,
80+
BaseAluOpcode::AND => execute_e1_tco_handler::<_, _, AndOp>,
81+
};
82+
Ok(fn_ptr)
83+
}
6284
}
6385

6486
impl<F: PrimeField32> MeteredExecutor<F> for Rv32BaseAlu256Executor {
@@ -106,6 +128,7 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: AluOp>(
106128
vm_state.instret += 1;
107129
}
108130

131+
#[create_tco_handler]
109132
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: AluOp>(
110133
pre_compute: &[u8],
111134
vm_state: &mut VmExecState<F, GuestMemory, CTX>,

extensions/bigint/circuit/src/branch_eq.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,25 @@ impl<F: PrimeField32> Executor<F> for Rv32BranchEqual256Executor {
5454
};
5555
Ok(fn_ptr)
5656
}
57+
58+
/// Selects the execution handler used when the `tco` feature is enabled.
///
/// Borrows `data` as a `BranchEqPreCompute`, fills it through
/// `pre_compute_impl`, and returns the `execute_e1_tco_handler` instantiation
/// for the reported opcode: the const flag is `false` for BEQ and `true` for
/// BNE.
///
/// # Errors
/// Forwards any `StaticProgramError` returned by `pre_compute_impl`.
#[cfg(feature = "tco")]
59+
fn handler<Ctx>(
60+
&self,
61+
pc: u32,
62+
inst: &Instruction<F>,
63+
data: &mut [u8],
64+
) -> Result<Handler<F, Ctx>, StaticProgramError>
65+
where
66+
Ctx: ExecutionCtxTrait,
67+
{
68+
// Shadows the raw byte buffer with its typed pre-compute view.
let data: &mut BranchEqPreCompute = data.borrow_mut();
69+
let local_opcode = self.pre_compute_impl(pc, inst, data)?;
70+
let fn_ptr = match local_opcode {
71+
BranchEqualOpcode::BEQ => execute_e1_tco_handler::<_, _, false>,
72+
BranchEqualOpcode::BNE => execute_e1_tco_handler::<_, _, true>,
73+
};
74+
Ok(fn_ptr)
75+
}
5776
}
5877

5978
impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchEqual256Executor {
@@ -101,6 +120,7 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE:
101120
vm_state.instret += 1;
102121
}
103122

123+
#[create_tco_handler]
104124
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_NE: bool>(
105125
pre_compute: &[u8],
106126
vm_state: &mut VmExecState<F, GuestMemory, CTX>,

extensions/bigint/circuit/src/branch_lt.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,27 @@ impl<F: PrimeField32> Executor<F> for Rv32BranchLessThan256Executor {
5959
};
6060
Ok(fn_ptr)
6161
}
62+
63+
/// Selects the execution handler used when the `tco` feature is enabled.
///
/// Borrows `data` as a `BranchLtPreCompute`, fills it through
/// `pre_compute_impl`, and returns the `execute_e1_tco_handler` instantiation
/// for the `BranchLessThanOpcode` that `pre_compute_impl` reported
/// (BLT, BLTU, BGE or BGEU).
///
/// # Errors
/// Forwards any `StaticProgramError` returned by `pre_compute_impl`.
#[cfg(feature = "tco")]
64+
fn handler<Ctx>(
65+
&self,
66+
pc: u32,
67+
inst: &Instruction<F>,
68+
data: &mut [u8],
69+
) -> Result<Handler<F, Ctx>, StaticProgramError>
70+
where
71+
Ctx: ExecutionCtxTrait,
72+
{
73+
// Shadows the raw byte buffer with its typed pre-compute view.
let data: &mut BranchLtPreCompute = data.borrow_mut();
74+
let local_opcode = self.pre_compute_impl(pc, inst, data)?;
75+
// One handler instantiation per opcode, selected by the `OP` type argument.
let fn_ptr = match local_opcode {
76+
BranchLessThanOpcode::BLT => execute_e1_tco_handler::<_, _, BltOp>,
77+
BranchLessThanOpcode::BLTU => execute_e1_tco_handler::<_, _, BltuOp>,
78+
BranchLessThanOpcode::BGE => execute_e1_tco_handler::<_, _, BgeOp>,
79+
BranchLessThanOpcode::BGEU => execute_e1_tco_handler::<_, _, BgeuOp>,
80+
};
81+
Ok(fn_ptr)
82+
}
6283
}
6384

6485
impl<F: PrimeField32> MeteredExecutor<F> for Rv32BranchLessThan256Executor {
@@ -107,6 +128,7 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLe
107128
vm_state.instret += 1;
108129
}
109130

131+
#[create_tco_handler]
110132
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, OP: BranchLessThanOp>(
111133
pre_compute: &[u8],
112134
vm_state: &mut VmExecState<F, GuestMemory, CTX>,

extensions/bigint/circuit/src/less_than.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,25 @@ impl<F: PrimeField32> Executor<F> for Rv32LessThan256Executor {
5454
};
5555
Ok(fn_ptr)
5656
}
57+
58+
/// Selects the execution handler used when the `tco` feature is enabled.
///
/// Borrows `data` as a `LessThanPreCompute`, fills it through
/// `pre_compute_impl`, and returns the `execute_e1_tco_handler` instantiation
/// for the reported opcode: the const flag is `false` for SLT and `true` for
/// SLTU.
///
/// # Errors
/// Forwards any `StaticProgramError` returned by `pre_compute_impl`.
#[cfg(feature = "tco")]
59+
fn handler<Ctx>(
60+
&self,
61+
pc: u32,
62+
inst: &Instruction<F>,
63+
data: &mut [u8],
64+
) -> Result<Handler<F, Ctx>, StaticProgramError>
65+
where
66+
Ctx: ExecutionCtxTrait,
67+
{
68+
// Shadows the raw byte buffer with its typed pre-compute view.
let data: &mut LessThanPreCompute = data.borrow_mut();
69+
let local_opcode = self.pre_compute_impl(pc, inst, data)?;
70+
let fn_ptr = match local_opcode {
71+
LessThanOpcode::SLT => execute_e1_tco_handler::<_, _, false>,
72+
LessThanOpcode::SLTU => execute_e1_tco_handler::<_, _, true>,
73+
};
74+
Ok(fn_ptr)
75+
}
5776
}
5877

5978
impl<F: PrimeField32> MeteredExecutor<F> for Rv32LessThan256Executor {
@@ -105,6 +124,7 @@ unsafe fn execute_e12_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_U25
105124
vm_state.instret += 1;
106125
}
107126

127+
#[create_tco_handler]
108128
unsafe fn execute_e1_impl<F: PrimeField32, CTX: ExecutionCtxTrait, const IS_U256: bool>(
109129
pre_compute: &[u8],
110130
vm_state: &mut VmExecState<F, GuestMemory, CTX>,

extensions/bigint/circuit/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#![cfg_attr(feature = "tco", allow(incomplete_features))]
2+
#![cfg_attr(feature = "tco", feature(likely_unlikely))]
3+
#![cfg_attr(feature = "tco", feature(explicit_tail_calls))]
14
use openvm_circuit::{
25
self,
36
arch::{

0 commit comments

Comments
 (0)