diff --git a/.codespellignore b/.codespellignore index c91d0f7707..4b1c229c68 100644 --- a/.codespellignore +++ b/.codespellignore @@ -2,4 +2,5 @@ InOut inout LoadE SelectE -ser \ No newline at end of file +ser +te \ No newline at end of file diff --git a/.github/workflows/extension-tests.yml b/.github/workflows/extension-tests.yml index 2d07bdf1f6..2ac189374a 100644 --- a/.github/workflows/extension-tests.yml +++ b/.github/workflows/extension-tests.yml @@ -29,7 +29,7 @@ jobs: - { name: "rv32im", path: "rv32im" } - { name: "native", path: "native" } - { name: "keccak256", path: "keccak256" } - - { name: "sha256", path: "sha256" } + - { name: "sha2", path: "sha2" } - { name: "bigint", path: "bigint" } - { name: "algebra", path: "algebra" } - { name: "ecc", path: "ecc" } diff --git a/.github/workflows/primitives.yml b/.github/workflows/primitives.yml index 2d86155ab2..4385b1ba5a 100644 --- a/.github/workflows/primitives.yml +++ b/.github/workflows/primitives.yml @@ -8,7 +8,7 @@ on: paths: - "crates/circuits/primitives/**" - "crates/circuits/poseidon2-air/**" - - "crates/circuits/sha256-air/**" + - "crates/circuits/sha2-air/**" - "crates/circuits/mod-builder/**" - "Cargo.toml" - ".github/workflows/primitives.yml" @@ -48,8 +48,8 @@ jobs: run: | cargo nextest run --cargo-profile fast --features parallel - - name: Run tests for sha256-air - working-directory: crates/circuits/sha256-air + - name: Run tests for sha2-air + working-directory: crates/circuits/sha2-air run: | cargo nextest run --cargo-profile fast --features parallel diff --git a/Cargo.lock b/Cargo.lock index a2674489a9..f62b172b8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4277,8 +4277,8 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -4552,6 +4552,16 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -4689,6 +4699,21 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -5125,8 +5150,8 @@ dependencies = [ "openvm-pairing-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-transpiler", "rand 0.8.5", @@ -5309,6 +5334,8 @@ name = "openvm-circuit-primitives-derive" version = "1.4.0-rc.0" dependencies = [ "itertools 0.14.0", + "ndarray", + "proc-macro2", "quote", "syn 2.0.104", ] @@ -5345,6 +5372,7 @@ dependencies = [ "hex-literal 0.4.1", "lazy_static", "num-bigint 0.4.6", + "num-integer", "num-traits", "once_cell", "openvm-algebra-circuit", @@ -5352,11 +5380,13 @@ dependencies = [ "openvm-circuit-derive", "openvm-circuit-primitives", "openvm-circuit-primitives-derive", + "openvm-ecc-guest", "openvm-ecc-transpiler", "openvm-instructions", "openvm-mod-circuit-builder", "openvm-rv32-adapters", "openvm-rv32im-circuit", + "openvm-sha2-circuit", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", @@ -5373,12 +5403,18 @@ dependencies = [ "elliptic-curve 0.13.8", "group 0.13.0", "halo2curves-axiom", + "hex-literal 0.4.1", + "lazy_static", + "num-bigint 0.4.6", "once_cell", "openvm", "openvm-algebra-guest", + "openvm-algebra-moduli-macros", "openvm-custom-insn", "openvm-ecc-sw-macros", + "openvm-ecc-te-macros", "openvm-rv32im-guest", + "openvm-sha2", "serde", "strum_macros", ] @@ -5397,6 +5433,7 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-rv32im-transpiler", "openvm-sdk", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", @@ -5414,6 +5451,15 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "openvm-ecc-te-macros" +version = "1.4.0-rc.0" +dependencies = [ + "openvm-macros-common", + "quote", + "syn 2.0.104", +] + [[package]] name = "openvm-ecc-transpiler" version = "1.4.0-rc.0" @@ -5924,8 +5970,8 @@ dependencies = [ "openvm-pairing-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", @@ -5950,9 +5996,9 @@ dependencies = [ "openvm-circuit", "openvm-instructions", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-guest", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-guest", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", @@ -5960,11 +6006,14 @@ dependencies = [ ] [[package]] -name = "openvm-sha256-air" +name = "openvm-sha2-air" version = "1.4.0-rc.0" dependencies = [ + "ndarray", + "num_enum", "openvm-circuit", "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", @@ -5972,41 +6021,41 @@ dependencies = [ ] [[package]] -name = "openvm-sha256-circuit" +name = "openvm-sha2-circuit" version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", + "ndarray", "openvm-circuit", "openvm-circuit-derive", "openvm-circuit-primitives", "openvm-circuit-primitives-derive", "openvm-instructions", "openvm-rv32im-circuit", - "openvm-sha256-air", - "openvm-sha256-transpiler", + "openvm-sha2-air", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", "serde", "sha2 0.10.9", - "strum", ] [[package]] -name = "openvm-sha256-guest" +name = "openvm-sha2-guest" version = "1.4.0-rc.0" dependencies = [ "openvm-platform", ] [[package]] -name = "openvm-sha256-transpiler" +name = "openvm-sha2-transpiler" version = "1.4.0-rc.0" dependencies = [ "openvm-instructions", "openvm-instructions-derive", - "openvm-sha256-guest", + "openvm-sha2-guest", "openvm-stark-backend", "openvm-transpiler", "rrs-lib", @@ -6193,8 +6242,8 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -6869,6 +6918,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "poseidon-primitives" version = "0.2.0" @@ -7247,6 +7305,12 @@ dependencies = [ "bitflags 2.9.1", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index 2fe6c8dc62..4e275d6766 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,13 +51,14 @@ members = [ "extensions/keccak256/circuit", "extensions/keccak256/transpiler", "extensions/keccak256/guest", - "extensions/sha256/circuit", - "extensions/sha256/transpiler", - "extensions/sha256/guest", + "extensions/sha2/circuit", + "extensions/sha2/transpiler", + "extensions/sha2/guest", "extensions/ecc/circuit", "extensions/ecc/transpiler", "extensions/ecc/guest", "extensions/ecc/sw-macros", + "extensions/ecc/te-macros", "extensions/ecc/tests", "extensions/pairing/circuit", "extensions/pairing/guest", @@ -119,7 +120,7 @@ openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", ta openvm-sdk = { path = "crates/sdk", default-features = false } openvm-mod-circuit-builder = { path = "crates/circuits/mod-builder", default-features = false } openvm-poseidon2-air = { path = "crates/circuits/poseidon2-air", default-features = false } -openvm-sha256-air = { path = "crates/circuits/sha256-air", default-features = false } +openvm-sha2-air = { path = "crates/circuits/sha2-air", default-features = false } openvm-circuit-primitives = { path = "crates/circuits/primitives", default-features = false } openvm-circuit-primitives-derive = { path = "crates/circuits/primitives/derive", default-features = false } openvm = { path = "crates/toolchain/openvm", default-features = false } @@ -149,9 +150,9 @@ openvm-native-transpiler = { path = "extensions/native/transpiler", default-feat openvm-keccak256-circuit = { path = "extensions/keccak256/circuit", default-features = false } openvm-keccak256-transpiler = { path = "extensions/keccak256/transpiler", default-features = false } openvm-keccak256-guest = { path = "extensions/keccak256/guest", default-features = false } -openvm-sha256-circuit = { path = "extensions/sha256/circuit", default-features = false } -openvm-sha256-transpiler = { path = "extensions/sha256/transpiler", default-features = false } -openvm-sha256-guest = { path = "extensions/sha256/guest", default-features = false } +openvm-sha2-circuit = { path = "extensions/sha2/circuit", default-features = false } +openvm-sha2-transpiler = { path = "extensions/sha2/transpiler", default-features = false } +openvm-sha2-guest = { path = "extensions/sha2/guest", default-features = false } openvm-bigint-circuit = { path = "extensions/bigint/circuit", default-features = false } openvm-bigint-transpiler = { path = "extensions/bigint/transpiler", default-features = false } openvm-bigint-guest = { path = "extensions/bigint/guest", default-features = false } @@ -164,11 +165,15 @@ openvm-ecc-circuit = { path = "extensions/ecc/circuit", default-features = false openvm-ecc-transpiler = { path = "extensions/ecc/transpiler", default-features = false } openvm-ecc-guest = { path = "extensions/ecc/guest", default-features = false } openvm-ecc-sw-macros = { path = "extensions/ecc/sw-macros", default-features = false } +openvm-ecc-te-macros = { path = "extensions/ecc/te-macros", default-features = false } openvm-pairing-circuit = { path = "extensions/pairing/circuit", default-features = false } openvm-pairing-transpiler = { path = "extensions/pairing/transpiler", default-features = false } openvm-pairing-guest = { path = "extensions/pairing/guest", default-features = false } openvm-verify-stark = { path = "guest-libs/verify_stark", default-features = false } +# Guest Libraries +openvm-sha2 = { path = "guest-libs/sha2", default-features = false } + # Benchmarking openvm-benchmarks-utils = { path = "benchmarks/utils", default-features = false } @@ -228,6 +233,8 @@ hex = { version = "0.4.3", default-features = false } serde-big-array = "0.5.1" dashmap = "6.1.0" memmap2 = "0.9.5" +ndarray = { version = "0.16.1", default-features = false } +num_enum = { version = "0.7.4", default-features = false } # default-features = false for no_std for use in guest programs itertools = { version = "0.14.0", default-features = false } diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index a6979c5e89..f943f73dcb 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -25,8 +25,8 @@ openvm-keccak256-circuit.workspace = true openvm-keccak256-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true eyre.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs index dca805d826..80f23a5290 100644 --- a/benchmarks/execute/benches/execute.rs +++ b/benchmarks/execute/benches/execute.rs @@ -22,9 +22,7 @@ use openvm_circuit::{ }, derive::VmConfig, }; -use openvm_ecc_circuit::{ - WeierstrassExtension, WeierstrassExtensionExecutor, WeierstrassExtensionPeriphery, -}; +use openvm_ecc_circuit::{EccExtension, EccExtensionExecutor, EccExtensionPeriphery}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_keccak256_circuit::{Keccak256, Keccak256Executor, Keccak256Periphery}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; @@ -40,8 +38,8 @@ use openvm_rv32im_circuit::{ use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; -use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; +use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::{ config::baby_bear_poseidon2::{ default_engine, BabyBearPoseidon2Config, BabyBearPoseidon2Engine, @@ -84,13 +82,13 @@ pub struct ExecuteConfig { #[extension] pub keccak: Keccak256, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, #[extension] pub modular: ModularExtension, #[extension] pub fp2: Fp2Extension, #[extension] - pub weierstrass: WeierstrassExtension, + pub ecc: EccExtension, #[extension] pub pairing: PairingExtension, } @@ -105,7 +103,7 @@ impl Default for ExecuteConfig { io: Rv32Io, bigint: Int256::default(), keccak: Keccak256, - sha256: Sha256, + sha2: Sha2, modular: ModularExtension::new(vec![ bn_config.modulus.clone(), bn_config.scalar.clone(), @@ -114,7 +112,7 @@ impl Default for ExecuteConfig { BN254_COMPLEX_STRUCT_NAME.to_string(), bn_config.modulus.clone(), )]), - weierstrass: WeierstrassExtension::new(vec![bn_config.clone()]), + ecc: EccExtension::new(vec![bn_config.clone()], vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bn254]), } } @@ -147,7 +145,7 @@ fn create_default_transpiler() -> Transpiler { .with_extension(Rv32MTranspilerExtension) .with_extension(Int256TranspilerExtension) .with_extension(Keccak256TranspilerExtension) - .with_extension(Sha256TranspilerExtension) + .with_extension(Sha2TranspilerExtension) .with_extension(ModularTranspilerExtension) .with_extension(Fp2TranspilerExtension) .with_extension(EccTranspilerExtension) diff --git a/benchmarks/guest/kitchen-sink/openvm.toml b/benchmarks/guest/kitchen-sink/openvm.toml index 2d1b307eef..e6cafcf57f 100644 --- a/benchmarks/guest/kitchen-sink/openvm.toml +++ b/benchmarks/guest/kitchen-sink/openvm.toml @@ -2,7 +2,7 @@ [app_vm_config.rv32m] [app_vm_config.io] [app_vm_config.keccak] -[app_vm_config.sha256] +[app_vm_config.sha2] [app_vm_config.bigint] [app_vm_config.modular] diff --git a/benchmarks/guest/sha256/openvm.toml b/benchmarks/guest/sha256/openvm.toml index 656bf52414..35f92b7195 100644 --- a/benchmarks/guest/sha256/openvm.toml +++ b/benchmarks/guest/sha256/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/benchmarks/guest/sha256_iter/openvm.toml b/benchmarks/guest/sha256_iter/openvm.toml index 656bf52414..35f92b7195 100644 --- a/benchmarks/guest/sha256_iter/openvm.toml +++ b/benchmarks/guest/sha256_iter/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/benchmarks/prove/src/bin/ecrecover.rs b/benchmarks/prove/src/bin/ecrecover.rs index 23fe2c82af..54fca3e3c2 100644 --- a/benchmarks/prove/src/bin/ecrecover.rs +++ b/benchmarks/prove/src/bin/ecrecover.rs @@ -12,8 +12,8 @@ use openvm_circuit::{ derive::VmConfig, }; use openvm_ecc_circuit::{ - CurveConfig, WeierstrassExtension, WeierstrassExtensionExecutor, WeierstrassExtensionPeriphery, - SECP256K1_CONFIG, + CurveConfig, EccExtension, EccExtensionExecutor, EccExtensionPeriphery, SwCurveCoeffs, + TeCurveCoeffs, SECP256K1_CONFIG, }; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_keccak256_circuit::{Keccak256, Keccak256Executor, Keccak256Periphery}; @@ -63,7 +63,7 @@ pub struct Rv32ImEcRecoverConfig { #[extension] pub keccak: Keccak256, #[extension] - pub weierstrass: WeierstrassExtension, + pub ecc: EccExtension, } impl InitFileGenerator for Rv32ImEcRecoverConfig { @@ -71,17 +71,25 @@ impl InitFileGenerator for Rv32ImEcRecoverConfig { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", self.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() + self.ecc.generate_ecc_init() )) } } impl Rv32ImEcRecoverConfig { - pub fn for_curves(curves: Vec) -> Self { - let primes: Vec = curves + pub fn for_curves( + sw_curves: Vec>, + te_curves: Vec>, + ) -> Self { + let sw_primes: Vec = sw_curves .iter() .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) .collect(); + let te_primes: Vec = te_curves + .iter() + .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) + .collect(); + let primes = [sw_primes, te_primes].concat(); Self { system: SystemConfig::default().with_continuations(), base: Default::default(), @@ -89,7 +97,7 @@ impl Rv32ImEcRecoverConfig { io: Default::default(), modular: ModularExtension::new(primes), keccak: Default::default(), - weierstrass: WeierstrassExtension::new(curves), + ecc: EccExtension::new(sw_curves, te_curves), } } } @@ -97,7 +105,7 @@ impl Rv32ImEcRecoverConfig { fn main() -> Result<()> { let args = BenchmarkCli::parse(); - let config = Rv32ImEcRecoverConfig::for_curves(vec![SECP256K1_CONFIG.clone()]); + let config = Rv32ImEcRecoverConfig::for_curves(vec![SECP256K1_CONFIG.clone()], vec![]); let elf = args.build_bench_program("ecrecover", &config, None)?; let exe = VmExe::from_elf( diff --git a/benchmarks/prove/src/bin/kitchen_sink.rs b/benchmarks/prove/src/bin/kitchen_sink.rs index 13fd7380d9..be11b096ce 100644 --- a/benchmarks/prove/src/bin/kitchen_sink.rs +++ b/benchmarks/prove/src/bin/kitchen_sink.rs @@ -7,7 +7,7 @@ use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::{instructions::exe::VmExe, SingleSegmentVmExecutor, SystemConfig}; use openvm_continuations::verifier::leaf::types::LeafVmVerifierInput; -use openvm_ecc_circuit::{WeierstrassExtension, P256_CONFIG, SECP256K1_CONFIG}; +use openvm_ecc_circuit::{EccExtension, P256_CONFIG, SECP256K1_CONFIG}; use openvm_native_circuit::{NativeConfig, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_recursion::halo2::utils::{CacheHalo2ParamsReader, DEFAULT_PARAMS_DIR}; use openvm_pairing_circuit::{PairingCurve, PairingExtension}; @@ -94,7 +94,7 @@ fn main() -> Result<()> { .rv32m(Default::default()) .io(Default::default()) .keccak(Default::default()) - .sha256(Default::default()) + .sha2(Default::default()) .bigint(Default::default()) .modular(ModularExtension::new(vec![ BigUint::from_str("1000000000000000003").unwrap(), @@ -119,12 +119,15 @@ fn main() -> Result<()> { bls_config.modulus.clone(), ), ])) - .ecc(WeierstrassExtension::new(vec![ - SECP256K1_CONFIG.clone(), - P256_CONFIG.clone(), - bn_config.clone(), - bls_config.clone(), - ])) + .ecc(EccExtension::new( + vec![ + SECP256K1_CONFIG.clone(), + P256_CONFIG.clone(), + bn_config.clone(), + bls_config.clone(), + ], + vec![], + )) .pairing(PairingExtension::new(vec![ PairingCurve::Bn254, PairingCurve::Bls12_381, diff --git a/benchmarks/prove/src/bin/pairing.rs b/benchmarks/prove/src/bin/pairing.rs index 1db6d1b491..80cc68809c 100644 --- a/benchmarks/prove/src/bin/pairing.rs +++ b/benchmarks/prove/src/bin/pairing.rs @@ -3,7 +3,7 @@ use eyre::Result; use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; use openvm_benchmarks_prove::util::BenchmarkCli; use openvm_circuit::arch::SystemConfig; -use openvm_ecc_circuit::WeierstrassExtension; +use openvm_ecc_circuit::EccExtension; use openvm_pairing_circuit::{PairingCurve, PairingExtension}; use openvm_pairing_guest::bn254::{BN254_COMPLEX_STRUCT_NAME, BN254_MODULUS, BN254_ORDER}; use openvm_sdk::{config::SdkVmConfig, Sdk, StdIn}; @@ -26,9 +26,10 @@ fn main() -> Result<()> { BN254_COMPLEX_STRUCT_NAME.to_string(), BN254_MODULUS.clone(), )])) - .ecc(WeierstrassExtension::new(vec![ - PairingCurve::Bn254.curve_config() - ])) + .ecc(EccExtension::new( + vec![PairingCurve::Bn254.curve_config()], + vec![], + )) .pairing(PairingExtension::new(vec![PairingCurve::Bn254])) .build(); let elf = args.build_bench_program("pairing", &vm_config, None)?; diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index 08c9faefb1..953d1fbe2d 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -21,7 +21,7 @@ - [Overview](./custom-extensions/overview.md) - [Keccak](./custom-extensions/keccak.md) -- [SHA-256](./custom-extensions/sha256.md) +- [SHA-2](./custom-extensions/sha2.md) - [Big Integer](./custom-extensions/bigint.md) - [Algebra (Modular Arithmetic)](./custom-extensions/algebra.md) - [Elliptic Curve Cryptography](./custom-extensions/ecc.md) diff --git a/book/src/custom-extensions/ecc.md b/book/src/custom-extensions/ecc.md index ba0fff90d6..ff3199dc73 100644 --- a/book/src/custom-extensions/ecc.md +++ b/book/src/custom-extensions/ecc.md @@ -17,12 +17,20 @@ Developers can enable arbitrary Weierstrass curves by configuring this extension - `WeierstrassPoint` trait: It represents an affine point on a Weierstrass elliptic curve and it extends `Group`. - - `Coordinate` type is the type of the coordinates of the point, and it implements `IntMod`. - - `x()`, `y()` are used to get the affine coordinates + - `Coordinate` type is the type of the coordinates of the point, and it implements `Field`. + - `x()`, `y()` are used to get the affine coordinates. - `from_xy` is a constructor for the point, which checks if the point is either identity or on the affine curve. - The point supports elliptic curve operations through intrinsic functions `add_ne_nonidentity` and `double_nonidentity`. - `decompress`: Sometimes an elliptic curve point is compressed and represented by its `x` coordinate and the odd/even parity of the `y` coordinate. `decompress` is used to decompress the point back to `(x, y)`. +- `TwistedEdwardsPoint` trait: + It represents an affine point on a twisted Edwards elliptic curve and it extends `Group`. + + - `Coordinate` type is the type of the coordinates of the point, and it implements `Field`. + - `x()`, `y()` are used to get the affine coordinates. + - `from_xy` is a constructor for the point, which checks if the point is on the affine curve. + - The point supports elliptic curve addition through the `add_impl` method. + - `msm`: for multi-scalar multiplication. - `ecdsa`: for doing ECDSA signature verification and public key recovery from signature. @@ -31,17 +39,20 @@ Developers can enable arbitrary Weierstrass curves by configuring this extension For elliptic curve cryptography, the `openvm-ecc-guest` crate provides macros similar to those in [`openvm-algebra-guest`](./algebra.md): -1. **Declare**: Use `sw_declare!` to define elliptic curves over the previously declared moduli. For example: +1. **Declare**: Use `sw_declare!` or `te_declare!` to define short Weierstrass or twisted Edwards elliptic curves, respectively, over the previously declared moduli. For example: ```rust sw_declare! { Bls12_381G1Affine { mod_type = Bls12_381Fp, b = BLS12_381_B }, P256Affine { mod_type = P256Coord, a = P256_A, b = P256_B }, } +te_declare! { + Edwards25519 { mod_type = Edwards25519Coord, a = CURVE_A, d = CURVE_D }, +} ``` +This creates `Bls12_381G1Affine` and `P256Affine` structs which implement the `Group` and `WeierstrassPoint` traits, and the `Edwards25519` struct which implements the `Group` and `TwistedEdwardsPoint` traits. The underlying memory layout of the structs uses the memory layout of the `Bls12_381Fp`, `P256Coord`, and `Edwards25519Coord` structs, respectively. -Each declared curve must specify the `mod_type` (implementing `IntMod`) and a constant `b` for the Weierstrass curve equation \\(y^2 = x^3 + ax + b\\). `a` is optional and defaults to 0 for short Weierstrass curves. -This creates `Bls12_381G1Affine` and `P256Affine` structs which implement the `Group` and `WeierstrassPoint` traits. The underlying memory layout of the structs uses the memory layout of the `Bls12_381Fp` and `P256Coord` structs, respectively. +Each declared curve must specify the `mod_type` (implementing `Field`) and a constant `b` for the Weierstrass curve equation \\(y^2 = x^3 + ax + b\\) or `a` and `d` for the twisted Edwards curve equation \\(ax^2 + y^2 = 1 + dx^2y^2\\). For short Weierstrass curves, `a` is optional and defaults to 0. 2. **Init**: Called once, the [`openvm::init!` macro](./overview.md#automating-the-init-step) produces a call to `sw_init!` that enumerates these curves and allows the compiler to produce optimized instructions: @@ -51,17 +62,21 @@ openvm::init!(); sw_init! { Bls12_381G1Affine, P256Affine, } +te_init! { + Edwards25519, +} */ ``` **Summary**: -- `sw_declare!`: Declares elliptic curve structures. +- `sw_declare!`: Declares short Weierstrass elliptic curve structures. +- `te_declare!`: Declares twisted Edwards elliptic curve structures. - `init!`: Initializes them once, linking them to the underlying moduli. -To use elliptic curve operations on a struct defined with `sw_declare!`, it is expected that the struct for the curve's coordinate field was defined using `moduli_declare!`. In particular, the coordinate field needs to be initialized and set up as described in the [algebra extension](./algebra.md) chapter. +To use elliptic curve operations on a struct defined with `sw_declare!` or `te_declare!`, it is expected that the struct for the curve's coordinate field was defined using `moduli_declare!`. In particular, the coordinate field needs to be initialized and set up as described in the [algebra extension](./algebra.md) chapter. -For the basic operations provided by the `WeierstrassPoint` trait, the scalar field is not needed. For the ECDSA functions in the `ecdsa` module, the scalar field must also be declared, initialized, and set up. +For the basic operations provided by the `WeierstrassPoint` or `TwistedEdwardsPoint` traits, the scalar field is not needed. For the ECDSA functions in the `ecdsa` module, the scalar field must also be declared, initialized, and set up. ## ECDSA diff --git a/book/src/custom-extensions/overview.md b/book/src/custom-extensions/overview.md index 2b07a73ec4..9ccfe35f3f 100644 --- a/book/src/custom-extensions/overview.md +++ b/book/src/custom-extensions/overview.md @@ -3,7 +3,7 @@ OpenVM ships with a set of pre-built extensions maintained by the OpenVM team. Below, we highlight six of these extensions designed to accelerate common arithmetic and cryptographic operations that are notoriously expensive to execute. Some of these extensions have corresponding guest libraries which provide convenient, high-level interfaces for your guest program to interact with the extension. - [`openvm-keccak-guest`](./keccak.md) - Keccak256 hash function. See the [Keccak256 guest library](../guest-libs/keccak256.md) for usage details. -- [`openvm-sha256-guest`](./sha256.md) - SHA-256 hash function. See the [SHA-2 guest library](../guest-libs/sha2.md) for usage details. +- [`openvm-sha2-guest`](./sha2.md) - SHA-2 family of hash functions. See the [SHA-2 guest library](../guest-libs/sha2.md) for usage details. - [`openvm-bigint-guest`](./bigint.md) - Big integer arithmetic for 256-bit signed and unsigned integers. See the [ruint guest library](../guest-libs/ruint.md) for using accelerated 256-bit integer ops in rust. - [`openvm-algebra-guest`](./algebra.md) - Modular arithmetic and complex field extensions. - [`openvm-ecc-guest`](./ecc.md) - Elliptic curve cryptography. See the [k256](../guest-libs/k256.md) and [p256](../guest-libs/p256.md) guest libraries for using this extension over the respective curves. @@ -43,9 +43,7 @@ range_tuple_checker_sizes = [256, 8192] [app_vm_config.io] [app_vm_config.keccak] - -[app_vm_config.sha256] - +[app_vm_config.sha2] [app_vm_config.native] [app_vm_config.bigint] diff --git a/book/src/custom-extensions/sha256.md b/book/src/custom-extensions/sha2.md similarity index 52% rename from book/src/custom-extensions/sha256.md rename to book/src/custom-extensions/sha2.md index a4a7f46261..de845fe25f 100644 --- a/book/src/custom-extensions/sha256.md +++ b/book/src/custom-extensions/sha2.md @@ -1,8 +1,8 @@ -# SHA-256 +# SHA-2 -The SHA-256 extension guest provides a function that is meant to be linked to other external libraries. The external libraries can use this function as a hook for the SHA-256 intrinsic. This is enabled only when the target is `zkvm`. +The SHA-2 extension guest provides functions that are meant to be linked to other external libraries. The external libraries can use these functions as a hook for SHA-2 intrinsics. This is enabled only when the target is `zkvm`. We support the SHA-256, SHA-512, and SHA-384 hash functions. -- `zkvm_sha256_impl(input: *const u8, len: usize, output: *mut u8)`: This function has `C` ABI. It takes in a pointer to the input, the length of the input, and a pointer to the output buffer. +- `zkvm_shaXXX_impl(input: *const u8, len: usize, output: *mut u8)` where XXX is one of `256`, `512`, or `384`. These functions have `C` ABI. They take in a pointer to the input, the length of the input, and a pointer to the output buffer. In the external library, you can do the following: @@ -31,5 +31,5 @@ fn sha256(input: &[u8]) -> [u8; 32] { For the guest program to build successfully add the following to your `.toml` file: ```toml -[app_vm_config.sha256] +[app_vm_config.sha2] ``` diff --git a/book/src/guest-libs/sha2.md b/book/src/guest-libs/sha2.md index cd35cf2e02..1ce0f46a89 100644 --- a/book/src/guest-libs/sha2.md +++ b/book/src/guest-libs/sha2.md @@ -3,26 +3,33 @@ The OpenVM SHA-2 guest library provides access to a set of accelerated SHA-2 family hash functions. Currently, it supports the following: - SHA-256 +- SHA-512 +- SHA-384 -## SHA-256 - -Refer [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for more details on SHA-256. +Refer [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for more details on the SHA-2 family of hash functions. For SHA-256, the SHA2 guest library provides two functions for use in your guest code: - - `sha256(input: &[u8]) -> [u8; 32]`: Computes the SHA-256 hash of the input data and returns it as an array of 32 bytes. - `set_sha256(input: &[u8], output: &mut [u8; 32])`: Sets the output to the SHA-256 hash of the input data into the provided output buffer. -See the full example [here](https://github.com/openvm-org/openvm/blob/main/examples/sha256/src/main.rs). +For SHA-512, we provide: +- `sha512(input: &[u8]) -> [u8; 46]`: Computes the SHA-512 hash of the input data and returns it as an array of 64 bytes. +- `set_sha512(input: &[u8], output: &mut [u8; 64])`: Sets the output to the SHA-512 hash of the input data into the provided output buffer. + +For SHA-384, we provide: +- `sha384(input: &[u8]) -> [u8; 48]`: Computes the SHA-384 hash of the input data and returns it as an array of 48 bytes. +- `set_sha384(input: &[u8], output: &mut [u8; 48])`: Sets the output to the SHA-384 hash of the input data into the provided output buffer. + +See the full example [here](https://github.com/openvm-org/openvm/blob/feat/sha-512-new-execution/examples/sha2/src/main.rs). ### Example ```rust,no_run,noplayground -{{ #include ../../../examples/sha256/src/main.rs:imports }} -{{ #include ../../../examples/sha256/src/main.rs:main }} +{{ #include ../../../examples/sha2/src/main.rs:imports }} +{{ #include ../../../examples/sha2/src/main.rs:main }} ``` -To be able to import the `sha256` function, add the following to your `Cargo.toml` file: +To be able to import the `shaXXX` functions and run the example, add the following to your `Cargo.toml` file: ```toml openvm-sha2 = { git = "https://github.com/openvm-org/openvm.git" } @@ -34,4 +41,4 @@ hex = { version = "0.4.3" } For the guest program to build successfully add the following to your `.toml` file: ```toml -[app_vm_config.sha256] \ No newline at end of file +[app_vm_config.sha2] \ No newline at end of file diff --git a/book/src/introduction.md b/book/src/introduction.md index ed39cbe33a..833151d66e 100644 --- a/book/src/introduction.md +++ b/book/src/introduction.md @@ -12,7 +12,7 @@ OpenVM is an open-source zero-knowledge virtual machine (zkVM) framework focused - RISC-V support via RV32IM - A native field arithmetic extension for proof recursion and aggregation - - The Keccak-256 and SHA2-256 hash functions + - The Keccak-256, SHA-256, SHA-512, and SHA-384 hash functions - Int256 arithmetic - Modular arithmetic over arbitrary fields - Elliptic curve operations, including multi-scalar multiplication and ECDSA signature verification, including for the secp256k1 and secp256r1 curves diff --git a/crates/circuits/poseidon2-air/src/babybear.rs b/crates/circuits/poseidon2-air/src/babybear.rs index e12b60bfb4..6989f992c7 100644 --- a/crates/circuits/poseidon2-air/src/babybear.rs +++ b/crates/circuits/poseidon2-air/src/babybear.rs @@ -18,7 +18,7 @@ pub(crate) fn horizen_to_p3_babybear(horizen_babybear: HorizenBabyBear) -> BabyB } pub(crate) fn horizen_round_consts() -> Poseidon2Constants { - let p3_rc16: Vec> = RC16 + let p3_rc16: Vec> = RC16 .iter() .map(|round| { round @@ -29,18 +29,10 @@ pub(crate) fn horizen_round_consts() -> Poseidon2Constants { .collect(); let p_end = BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS + BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS; - let beginning_full_round_constants: [[BabyBear; POSEIDON2_WIDTH]; - BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = from_fn(|i| p3_rc16[i].clone().try_into().unwrap()); - let partial_round_constants: [BabyBear; BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS] = - from_fn(|i| p3_rc16[i + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS][0]); - let ending_full_round_constants: [[BabyBear; POSEIDON2_WIDTH]; - BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = - from_fn(|i| p3_rc16[i + p_end].clone().try_into().unwrap()); - Poseidon2Constants { - beginning_full_round_constants, - partial_round_constants, - ending_full_round_constants, + beginning_full_round_constants: from_fn(|i| p3_rc16[i].clone().try_into().unwrap()), + partial_round_constants: from_fn(|i| p3_rc16[i + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS][0]), + ending_full_round_constants: from_fn(|i| p3_rc16[i + p_end].clone().try_into().unwrap()), } } diff --git a/crates/circuits/primitives/derive/Cargo.toml b/crates/circuits/primitives/derive/Cargo.toml index 06d4c00aed..2e91772fd5 100644 --- a/crates/circuits/primitives/derive/Cargo.toml +++ b/crates/circuits/primitives/derive/Cargo.toml @@ -12,6 +12,13 @@ license.workspace = true proc-macro = true [dependencies] -syn = { version = "2.0", features = ["parsing"] } +syn = { version = "2.0", features = ["parsing", "extra-traits"] } quote = "1.0" -itertools = { workspace = true } +itertools = { workspace = true, default-features = true } +proc-macro2 = "1.0" + +[dev-dependencies] +ndarray.workspace = true + +[package.metadata.cargo-shear] +ignored = ["ndarray"] diff --git a/crates/circuits/primitives/derive/src/cols_ref/README.md b/crates/circuits/primitives/derive/src/cols_ref/README.md new file mode 100644 index 0000000000..82812f7b90 --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/README.md @@ -0,0 +1,113 @@ +# ColsRef macro + +The `ColsRef` procedural macro is used in constraint generation to create column structs that have dynamic sizes. + +Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384). +See the [SHA-2 VM extension](../../../../../../extensions/sha2/circuit/src/sha2_chip/air.rs) for an example of how to use the `ColsRef` macro to reuse constraint generation code over multiple circuits. + +## Overview + +As an illustrative example, consider the following columns struct: +```rust +struct ExampleCols { + arr: [T; N], + sum: T, +} +``` +Let's say we want to constrain `sum` to be the sum of the elements of `arr`, and `N` can be either 5 or 10. +We can define a trait that stores the config parameters. +```rust +pub trait ExampleConfig { + const N: usize; +} +``` +and then implement it for the two different configs. +```rust +pub struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} +pub struct ExampleConfigImplB; +impl ExampleConfig for ExampleConfigImplB { + const N: usize = 10; +} +``` +Then we can use the `ColsRef` macro like this +```rust +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + sum: T, +} +``` +which will generate a columns struct that uses references to the fields. +```rust +struct ExampleColsRef<'a, T, const N: usize> { + arr: ndarray::ArrayView1<'a, T>, // an n-dimensional view into the input slice (ArrayView2 for 2D arrays, etc.) + sum: &'a T, +} +``` +The `ColsRef` macro will also generate a `from` method that takes a slice of the correct length and returns an instance of the columns struct. +The `from` method is parameterized by a struct that implements the `ExampleConfig` trait, and it uses the associated constants to determine how to split the input slice into the fields of the columns struct. + +So, the constraint generation code can be written as +```rust +impl Air for ExampleAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, _) = (main.row_slice(0), main.row_slice(1)); + let local_cols = ExampleColsRef::::from::(&local[..C::N + 1]); + let sum = local_cols.arr.iter().sum(); + builder.assert_eq(local_cols.sum, sum); + } +} +``` +Notes: +- the `arr` and `sum` fields of `ExampleColsRef` are references to the elements of the `local` slice. +- the name, `N`, of the const generic parameter must match the name of the associated constant `N` in the `ExampleConfig` trait. + +The `ColsRef` macro also generates a `ExampleColsRefMut` struct that stores mutable references to the fields, for use in trace generation. + +The `ColsRef` macro supports more than just variable-length array fields. +The field types can also be: +- any type that derives `AlignedBorrow` via `#[derive(AlignedBorrow)]` +- any type that derives `ColsRef` via `#[derive(ColsRef)]` +- (possibly nested) arrays of `T` or (possibly nested) arrays of a type that derives `AlignedBorrow` + +Note that we currently do not support arrays of types that derive `ColsRef`. + +## Specification + +Annotating a struct named `ExampleCols` with `#[derive(ColsRef)]` and `#[config(ExampleConfig)]` produces two structs, `ExampleColsRef` and `ExampleColsRefMut`. +- we assume `ExampleCols` has exactly one generic type parameter, typically named `T`, and any number of const generic parameters. Each const generic parameter must have a name that matches an associated constant in the `ExampleConfig` trait + +The fields of `ExampleColsRef` have the same names as the fields of `ExampleCols`, but their types are transformed as follows: +- type `T` becomes `&T` +- type `[T; LEN]` becomes `&ArrayView1` (see [ndarray](https://docs.rs/ndarray/latest/ndarray/index.html)) where `LEN` is an associated constant in `ExampleConfig` + - the `ExampleColsRef::from` method will correctly infer the length of the array from the config +- fields with names that end in `Cols` are assumed to be a columns struct that derives `ColsRef` and are transformed into the appropriate `ColsRef` type recursively + - one restriction is that any nested `ColsRef` type must have the same config as the outer `ColsRef` type +- fields that are annotated with `#[aligned_borrow]` are assumed to derive `AlignedBorrow` and are borrowed from the input slice. The new type is a reference to the `AlignedBorrow` type + - if a field whose name ends in `Cols` is annotated with `#[aligned_borrow]`, then the aligned borrow takes precedence, and the field is not transformed into an `ArrayView` +- nested arrays of `U` become `&ArrayViewX` where `X` is the number of dimensions in the nested array type + - `U` can be either the generic type `T` or a type that derives `AlignedBorrow`. In the latter case, the field must be annotated with `#[aligned_borrow]` + - the `ArrayViewX` type provides a `X`-dimensional view into the row slice + +The fields of `ExampleColsRefMut` are almost the same as the fields of `ExampleColsRef`, but they are mutable references. +- the `ArrayViewMutX` type is used instead of `ArrayViewX` for the array fields. +- fields that derive `ColsRef` are transformed into the appropriate `ColsRefMut` type recursively. + +Each of the `ExampleColsRef` and `ExampleColsRefMut` types has the following methods implemented: +```rust +// Takes a slice of the correct length and returns an instance of the columns struct. +pub const fn from(slice: &[T]) -> Self; +// Returns the number of cells in the struct +pub const fn width() -> usize; +``` +Note that the `width` method on both structs returns the same value. + +Additionally, the `ExampleColsRef` struct has a `from_mut` method that takes a `ExampleColsRefMut` and returns a `ExampleColsRef`. +This may be useful in trace generation to pass a `ExampleColsRefMut` to a function that expects a `ExampleColsRef`. + +See the [tests](../../tests/test_cols_ref.rs) for concrete examples of how the `ColsRef` macro handles each of the supported field types. \ No newline at end of file diff --git a/crates/circuits/primitives/derive/src/cols_ref/mod.rs b/crates/circuits/primitives/derive/src/cols_ref/mod.rs new file mode 100644 index 0000000000..63289ec5eb --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/mod.rs @@ -0,0 +1,697 @@ +extern crate proc_macro; + +use itertools::Itertools; +use quote::{format_ident, quote}; +use syn::{parse_quote, DeriveInput}; + +pub fn cols_ref_impl( + derive_input: DeriveInput, + config: proc_macro2::Ident, +) -> proc_macro2::TokenStream { + let DeriveInput { + ident, + generics, + data, + vis, + .. + } = derive_input; + + let generic_types = generics + .params + .iter() + .filter_map(|p| { + if let syn::GenericParam::Type(type_param) = p { + Some(type_param) + } else { + None + } + }) + .collect::>(); + + if generic_types.len() != 1 { + panic!("Struct must have exactly one generic type parameter"); + } + + let generic_type = generic_types[0]; + + let const_generics = generics.const_params().map(|p| &p.ident).collect_vec(); + + match data { + syn::Data::Struct(data_struct) => { + // Process the fields of the struct, transforming the types for use in ColsRef struct + let const_field_infos: Vec = data_struct + .fields + .iter() + .map(|f| get_const_cols_ref_fields(f, generic_type, &const_generics)) + .collect_vec(); + + // The ColsRef struct is named by appending `Ref` to the struct name + let const_cols_ref_name = syn::Ident::new(&format!("{}Ref", ident), ident.span()); + + // the args to the `from` method will be different for the ColsRef and ColsRefMut + // structs + let from_args = quote! { slice: &'a [#generic_type] }; + + // Package all the necessary information to generate the ColsRef struct + let struct_info = StructInfo { + name: const_cols_ref_name, + vis: vis.clone(), + generic_type: generic_type.clone(), + field_infos: const_field_infos, + fields: data_struct.fields.clone(), + from_args, + derive_clone: true, + }; + + // Generate the ColsRef struct + let const_cols_ref_struct = make_struct(struct_info.clone(), &config); + + // Generate the `from_mut` method for the ColsRef struct + let from_mut_impl = make_from_mut(struct_info, &config); + + // Process the fields of the struct, transforming the types for use in ColsRefMut struct + let mut_field_infos: Vec = data_struct + .fields + .iter() + .map(|f| get_mut_cols_ref_fields(f, generic_type, &const_generics)) + .collect_vec(); + + // The ColsRefMut struct is named by appending `RefMut` to the struct name + let mut_cols_ref_name = syn::Ident::new(&format!("{}RefMut", ident), ident.span()); + + // the args to the `from` method will be different for the ColsRef and ColsRefMut + // structs + let from_args = quote! { slice: &'a mut [#generic_type] }; + + // Package all the necessary information to generate the ColsRefMut struct + let struct_info = StructInfo { + name: mut_cols_ref_name, + vis, + generic_type: generic_type.clone(), + field_infos: mut_field_infos, + fields: data_struct.fields, + from_args, + derive_clone: false, + }; + + // Generate the ColsRefMut struct + let mut_cols_ref_struct = make_struct(struct_info, &config); + + quote! { + #const_cols_ref_struct + #from_mut_impl + #mut_cols_ref_struct + } + } + _ => panic!("ColsRef can only be derived for structs"), + } +} + +#[derive(Debug, Clone)] +struct StructInfo { + name: syn::Ident, + vis: syn::Visibility, + generic_type: syn::TypeParam, + field_infos: Vec, + fields: syn::Fields, + from_args: proc_macro2::TokenStream, + derive_clone: bool, +} + +// Generate the ColsRef and ColsRefMut structs, depending on the value of `struct_info` +// This function is meant to reduce code duplication between the code needed to generate the two +// structs Notable differences between the two structs are: +// - the types of the fields +// - ColsRef derives Clone, but ColsRefMut cannot (since it stores mutable references) +// - the `from` method parameter is a reference to a slice for ColsRef and a mutable reference to +// a slice for ColsRefMut +fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream { + let StructInfo { + name, + vis, + generic_type, + field_infos, + fields, + from_args, + derive_clone, + } = struct_info; + + let field_types = field_infos.iter().map(|f| &f.ty).collect_vec(); + let length_exprs = field_infos.iter().map(|f| &f.length_expr).collect_vec(); + let prepare_subslices = field_infos + .iter() + .map(|f| &f.prepare_subslice) + .collect_vec(); + let initializers = field_infos.iter().map(|f| &f.initializer).collect_vec(); + + let idents = fields.iter().map(|f| &f.ident).collect_vec(); + + let clone_impl = if derive_clone { + quote! { + #[derive(Clone)] + } + } else { + quote! {} + }; + + quote! { + #clone_impl + #[derive(Debug)] + #vis struct #name <'a, #generic_type> { + #( pub #idents: #field_types ),* + } + + impl<'a, #generic_type> #name<'a, #generic_type> { + pub fn from(#from_args) -> Self { + #( #prepare_subslices )* + Self { + #( #idents: #initializers ),* + } + } + + // returns number of cells in the struct (where each cell has type T) + pub const fn width() -> usize { + 0 #( + #length_exprs )* + } + } + } +} + +// Generate the `from_mut` method for the ColsRef struct +fn make_from_mut(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream { + let StructInfo { + name, + vis: _, + generic_type, + field_infos: _, + fields, + from_args: _, + derive_clone: _, + } = struct_info; + + let from_mut_impl = fields + .iter() + .map(|f| { + let ident = f.ident.clone().unwrap(); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + // calling view() on ArrayViewMut returns an ArrayView + quote! { + other.#ident.view() + } + } else if derives_aligned_borrow { + // implicitly converts a mutable reference to an immutable reference, so leave the + // field value unchanged + quote! { + other.#ident + } + } else if is_columns_struct(&f.ty) { + // lifetime 'b is used in from_mut to allow more flexible lifetime of return value + let cols_ref_type = + get_const_cols_ref_type(&f.ty, &generic_type, parse_quote! { 'b }); + // Recursively call `from_mut` on the ColsRef field + quote! { + <#cols_ref_type>::from_mut::(&other.#ident) + } + } else if is_generic_type(&f.ty, &generic_type) { + // implicitly converts a mutable reference to an immutable reference, so leave the + // field value unchanged + quote! { + &other.#ident + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } + }) + .collect_vec(); + + let field_idents = fields + .iter() + .map(|f| f.ident.clone().unwrap()) + .collect_vec(); + + let mut_struct_ident = format_ident!("{}Mut", name.to_string()); + let mut_struct_type: syn::Type = parse_quote! { + #mut_struct_ident<'a, #generic_type> + }; + + parse_quote! { + // lifetime 'b is used in from_mut to allow more flexible lifetime of return value + impl<'b, #generic_type> #name<'b, #generic_type> { + pub fn from_mut<'a, C: #config>(other: &'b #mut_struct_type) -> Self + { + Self { + #( #field_idents: #from_mut_impl ),* + } + } + } + } +} + +// Information about a field that is used to generate the ColsRef and ColsRefMut structs +// See the `make_struct` function to see how this information is used +#[derive(Debug, Clone)] +struct FieldInfo { + // type for struct definition + ty: syn::Type, + // an expr calculating the length of the field + length_expr: proc_macro2::TokenStream, + // prepare a subslice of the slice to be used in the 'from' method + prepare_subslice: proc_macro2::TokenStream, + // an expr used in the Self initializer in the 'from' method + // may refer to the subslice declared in prepare_subslice + initializer: proc_macro2::TokenStream, +} + +// Prepare the fields for the const ColsRef struct +fn get_const_cols_ref_fields( + f: &syn::Field, + generic_type: &syn::TypeParam, + const_generics: &[&syn::Ident], +) -> FieldInfo { + let length_var = format_ident!("{}_length", f.ident.clone().unwrap()); + let slice_var = format_ident!("{}_slice", f.ident.clone().unwrap()); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics); + debug_assert!( + !dims.is_empty(), + "Array field must have at least one dimension" + ); + + let ndarray_ident: syn::Ident = format_ident!("ArrayView{}", dims.len()); + let ndarray_type: syn::Type = parse_quote! { + ndarray::#ndarray_ident<'a, #elem_type> + }; + + // dimensions of the array in terms of number of cells + let dim_exprs = dims + .iter() + .map(|d| match d { + // need to prepend C:: for const generic array dimensions + Dimension::ConstGeneric(expr) => quote! { C::#expr }, + Dimension::Other(expr) => quote! { #expr }, + }) + .collect_vec(); + + if derives_aligned_borrow { + let length_expr = quote! { + <#elem_type>::width() #(* #dim_exprs)* + }; + + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(#length_expr); + let #slice_var: &[#elem_type] = unsafe { &*(#slice_var as *const [T] as *const [#elem_type]) }; + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_columns_struct(&elem_type) { + panic!("Arrays of columns structs are currently not supported"); + } else if is_generic_type(&elem_type, generic_type) { + let length_expr = quote! { + 1 #(* #dim_exprs)* + }; + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(#length_expr); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } + } else if derives_aligned_borrow { + // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) + let f_ty = &f.ty; + FieldInfo { + ty: parse_quote! { + &'a #f_ty + }, + length_expr: quote! { + <#f_ty>::width() + }, + prepare_subslice: quote! { + let #length_var = <#f_ty>::width(); + let (#slice_var, slice) = slice.split_at(#length_var); + }, + initializer: quote! { + { + use core::borrow::Borrow; + #slice_var.borrow() + } + }, + } + } else if is_columns_struct(&f.ty) { + let const_cols_ref_type = get_const_cols_ref_type(&f.ty, generic_type, parse_quote! { 'a }); + FieldInfo { + ty: parse_quote! { + #const_cols_ref_type + }, + length_expr: quote! { + <#const_cols_ref_type>::width::() + }, + prepare_subslice: quote! { + let #length_var = <#const_cols_ref_type>::width::(); + let (#slice_var, slice) = slice.split_at(#length_var); + let #slice_var = <#const_cols_ref_type>::from::(#slice_var); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_generic_type(&f.ty, generic_type) { + FieldInfo { + ty: parse_quote! { + &'a #generic_type + }, + length_expr: quote! { + 1 + }, + prepare_subslice: quote! { + let #length_var = 1; + let (#slice_var, slice) = slice.split_at(#length_var); + }, + initializer: quote! { + &#slice_var[0] + }, + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } +} + +// Prepare the fields for the mut ColsRef struct +fn get_mut_cols_ref_fields( + f: &syn::Field, + generic_type: &syn::TypeParam, + const_generics: &[&syn::Ident], +) -> FieldInfo { + let length_var = format_ident!("{}_length", f.ident.clone().unwrap()); + let slice_var = format_ident!("{}_slice", f.ident.clone().unwrap()); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics); + debug_assert!( + !dims.is_empty(), + "Array field must have at least one dimension" + ); + + let ndarray_ident: syn::Ident = format_ident!("ArrayViewMut{}", dims.len()); + let ndarray_type: syn::Type = parse_quote! { + ndarray::#ndarray_ident<'a, #elem_type> + }; + + // dimensions of the array in terms of number of cells + let dim_exprs = dims + .iter() + .map(|d| match d { + // need to prepend C:: for const generic array dimensions + Dimension::ConstGeneric(expr) => quote! { C::#expr }, + Dimension::Other(expr) => quote! { #expr }, + }) + .collect_vec(); + + if derives_aligned_borrow { + let length_expr = quote! { + <#elem_type>::width() #(* #dim_exprs)* + }; + + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut (#length_expr); + let #slice_var: &mut [#elem_type] = unsafe { &mut *(#slice_var as *mut [T] as *mut [#elem_type]) }; + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_columns_struct(&elem_type) { + panic!("Arrays of columns structs are currently not supported"); + } else if is_generic_type(&elem_type, generic_type) { + let length_expr = quote! { + 1 #(* #dim_exprs)* + }; + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut(#length_expr); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } + } else if derives_aligned_borrow { + // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) + let f_ty = &f.ty; + FieldInfo { + ty: parse_quote! { + &'a mut #f_ty + }, + length_expr: quote! { + <#f_ty>::width() + }, + prepare_subslice: quote! { + let #length_var = <#f_ty>::width(); + let (#slice_var, slice) = slice.split_at_mut(#length_var); + }, + initializer: quote! { + { + use core::borrow::BorrowMut; + #slice_var.borrow_mut() + } + }, + } + } else if is_columns_struct(&f.ty) { + let mut_cols_ref_type = get_mut_cols_ref_type(&f.ty, generic_type); + FieldInfo { + ty: parse_quote! { + #mut_cols_ref_type + }, + length_expr: quote! { + <#mut_cols_ref_type>::width::() + }, + prepare_subslice: quote! { + let #length_var = <#mut_cols_ref_type>::width::(); + let (#slice_var, slice) = slice.split_at_mut(#length_var); + let #slice_var = <#mut_cols_ref_type>::from::(#slice_var); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_generic_type(&f.ty, generic_type) { + FieldInfo { + ty: parse_quote! { + &'a mut #generic_type + }, + length_expr: quote! { + 1 + }, + prepare_subslice: quote! { + let #length_var = 1; + let (#slice_var, slice) = slice.split_at_mut(#length_var); + }, + initializer: quote! { + &mut #slice_var[0] + }, + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } +} + +// Helper functions + +fn is_columns_struct(ty: &syn::Type) -> bool { + if let syn::Type::Path(type_path) = ty { + type_path + .path + .segments + .iter() + .last() + .map(|s| s.ident.to_string().ends_with("Cols")) + .unwrap_or(false) + } else { + false + } +} + +// If 'ty' is a struct that derives ColsRef, return the ColsRef struct type +// Otherwise, return None +fn get_const_cols_ref_type( + ty: &syn::Type, + generic_type: &syn::TypeParam, + lifetime: syn::Lifetime, +) -> syn::TypePath { + if !is_columns_struct(ty) { + panic!("Expected a columns struct, got {:?}", ty); + } + + if let syn::Type::Path(type_path) = ty { + let s = type_path.path.segments.iter().last().unwrap(); + if s.ident.to_string().ends_with("Cols") { + let const_cols_ref_ident = format_ident!("{}Ref", s.ident); + let const_cols_ref_type = parse_quote! { + #const_cols_ref_ident<#lifetime, #generic_type> + }; + const_cols_ref_type + } else { + panic!("is_columns_struct returned true for type {:?} but the last segment is not a columns struct", ty); + } + } else { + panic!( + "is_columns_struct returned true but the type {:?} is not a path", + ty + ); + } +} + +// If 'ty' is a struct that derives ColsRef, return the ColsRefMut struct type +// Otherwise, return None +fn get_mut_cols_ref_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> syn::TypePath { + if !is_columns_struct(ty) { + panic!("Expected a columns struct, got {:?}", ty); + } + + if let syn::Type::Path(type_path) = ty { + let s = type_path.path.segments.iter().last().unwrap(); + if s.ident.to_string().ends_with("Cols") { + let mut_cols_ref_ident = format_ident!("{}RefMut", s.ident); + let mut_cols_ref_type = parse_quote! { + #mut_cols_ref_ident<'a, #generic_type> + }; + mut_cols_ref_type + } else { + panic!("is_columns_struct returned true for type {:?} but the last segment is not a columns struct", ty); + } + } else { + panic!( + "is_columns_struct returned true but the type {:?} is not a path", + ty + ); + } +} + +fn is_generic_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> bool { + if let syn::Type::Path(type_path) = ty { + if type_path.path.segments.len() == 1 { + type_path + .path + .segments + .iter() + .last() + .map(|s| s.ident == generic_type.ident) + .unwrap_or(false) + } else { + false + } + } else { + false + } +} + +// Type of array dimension +enum Dimension { + ConstGeneric(syn::Expr), + Other(syn::Expr), +} + +// Describes a nested array +struct ArrayInfo { + dims: Vec, + elem_type: syn::Type, +} + +fn get_array_info(ty: &syn::Type, const_generics: &[&syn::Ident]) -> ArrayInfo { + let dims = get_dims(ty, const_generics); + let elem_type = get_elem_type(ty); + ArrayInfo { dims, elem_type } +} + +fn get_elem_type(ty: &syn::Type) -> syn::Type { + match ty { + syn::Type::Array(array) => get_elem_type(array.elem.as_ref()), + syn::Type::Path(_) => ty.clone(), + _ => panic!("Unsupported type: {:?}", ty), + } +} + +// Get a vector of the dimensions of the array +// Each dimension is either a constant generic or a literal integer value +fn get_dims(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec { + get_dims_impl(ty, const_generics) + .into_iter() + .rev() + .collect() +} + +fn get_dims_impl(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec { + match ty { + syn::Type::Array(array) => { + let mut dims = get_dims_impl(array.elem.as_ref(), const_generics); + match &array.len { + syn::Expr::Path(syn::ExprPath { path, .. }) => { + let len_ident = path.get_ident(); + if len_ident.is_some() && const_generics.contains(&len_ident.unwrap()) { + dims.push(Dimension::ConstGeneric(array.len.clone())); + } else { + dims.push(Dimension::Other(array.len.clone())); + } + } + syn::Expr::Lit(expr_lit) => dims.push(Dimension::Other(expr_lit.clone().into())), + _ => panic!("Unsupported array length type"), + } + dims + } + syn::Type::Path(_) => Vec::new(), + _ => panic!("Unsupported field type"), + } +} diff --git a/crates/circuits/primitives/derive/src/lib.rs b/crates/circuits/primitives/derive/src/lib.rs index f9e384290f..2f5dab0a4c 100644 --- a/crates/circuits/primitives/derive/src/lib.rs +++ b/crates/circuits/primitives/derive/src/lib.rs @@ -7,6 +7,9 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta}; +mod cols_ref; +use cols_ref::cols_ref_impl; + #[proc_macro_derive(AlignedBorrow)] pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); @@ -443,3 +446,25 @@ pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream { _ => unimplemented!(), } } + +#[proc_macro_derive(ColsRef, attributes(aligned_borrow, config))] +pub fn cols_ref_derive(input: TokenStream) -> TokenStream { + let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput); + + let config = derive_input + .attrs + .iter() + .find(|attr| attr.path().is_ident("config")); + if config.is_none() { + return syn::Error::new(derive_input.ident.span(), "Config attribute is required") + .to_compile_error() + .into(); + } + let config: proc_macro2::Ident = config + .unwrap() + .parse_args() + .expect("Failed to parse config"); + + let res = cols_ref_impl(derive_input, config); + res.into() +} diff --git a/crates/circuits/primitives/derive/tests/example.rs b/crates/circuits/primitives/derive/tests/example.rs new file mode 100644 index 0000000000..58bac9e26c --- /dev/null +++ b/crates/circuits/primitives/derive/tests/example.rs @@ -0,0 +1,87 @@ +use openvm_circuit_primitives_derive::ColsRef; + +pub trait ExampleConfig { + const N: usize; +} +pub struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} +pub struct ExampleConfigImplB; +impl ExampleConfig for ExampleConfigImplB { + const N: usize = 10; +} + +#[allow(dead_code)] +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + sum: T, +} + +#[test] +fn example() { + let input = [1, 2, 3, 4, 5, 15]; + let test: ExampleColsRef = ExampleColsRef::from::(&input); + println!("{}, {}", test.arr, test.sum); +} + +/* + * For reference, this is what the ColsRef macro expands to. + * The `cargo expand` tool is helpful for understanding how the ColsRef macro works. + * See https://github.com/dtolnay/cargo-expand + +#[derive(Debug, Clone)] +struct ExampleColsRef<'a, T> { + pub arr: ndarray::ArrayView1<'a, T>, + pub sum: &'a T, +} + +impl<'a, T> ExampleColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let (arr_slice, slice) = slice.split_at(1 * C::N); + let arr_slice = ndarray::ArrayView1::from_shape((C::N), arr_slice).unwrap(); + let sum_length = 1; + let (sum_slice, slice) = slice.split_at(sum_length); + Self { + arr: arr_slice, + sum: &sum_slice[0], + } + } + pub const fn width() -> usize { + 0 + 1 * C::N + 1 + } +} + +impl<'b, T> ExampleColsRef<'b, T> { + pub fn from_mut<'a, C: ExampleConfig>(other: &'b ExampleColsRefMut<'a, T>) -> Self { + Self { + arr: other.arr.view(), + sum: &other.sum, + } + } +} + +#[derive(Debug)] +struct ExampleColsRefMut<'a, T> { + pub arr: ndarray::ArrayViewMut1<'a, T>, + pub sum: &'a mut T, +} + +impl<'a, T> ExampleColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let (arr_slice, slice) = slice.split_at_mut(1 * C::N); + let arr_slice = ndarray::ArrayViewMut1::from_shape((C::N), arr_slice).unwrap(); + let sum_length = 1; + let (sum_slice, slice) = slice.split_at_mut(sum_length); + Self { + arr: arr_slice, + sum: &mut sum_slice[0], + } + } + pub const fn width() -> usize { + 0 + 1 * C::N + 1 + } +} +*/ diff --git a/crates/circuits/primitives/derive/tests/test_cols_ref.rs b/crates/circuits/primitives/derive/tests/test_cols_ref.rs new file mode 100644 index 0000000000..6bad0c4f9f --- /dev/null +++ b/crates/circuits/primitives/derive/tests/test_cols_ref.rs @@ -0,0 +1,299 @@ +use openvm_circuit_primitives_derive::{AlignedBorrow, ColsRef}; + +pub trait TestConfig { + const N: usize; + const M: usize; +} +pub struct TestConfigImpl; +impl TestConfig for TestConfigImpl { + const N: usize = 5; + const M: usize = 2; +} + +#[allow(dead_code)] // TestCols isn't actually used in the code. silence clippy warning +#[derive(ColsRef)] +#[config(TestConfig)] +struct TestCols { + single_field_element: T, + array_of_t: [T; N], + nested_array_of_t: [[T; N]; N], + cols_struct: TestSubCols, + #[aligned_borrow] + array_of_aligned_borrow: [TestAlignedBorrow; N], + #[aligned_borrow] + nested_array_of_aligned_borrow: [[TestAlignedBorrow; N]; N], +} + +#[allow(dead_code)] // TestSubCols isn't actually used in the code. silence clippy warning +#[derive(ColsRef, Debug)] +#[config(TestConfig)] +struct TestSubCols { + // TestSubCols can have fields of any type that TestCols can have + a: T, + b: [T; M], + #[aligned_borrow] + c: TestAlignedBorrow, +} + +#[derive(AlignedBorrow, Debug)] +struct TestAlignedBorrow { + a: T, + b: [T; 5], +} + +#[test] +fn test_cols_ref() { + assert_eq!( + TestColsRef::::width::(), + TestColsRefMut::::width::() + ); + const WIDTH: usize = TestColsRef::::width::(); + let mut input = vec![0; WIDTH]; + let mut cols: TestColsRefMut = TestColsRefMut::from::(&mut input); + + *cols.single_field_element = 1; + cols.array_of_t[0] = 2; + cols.nested_array_of_t[[0, 0]] = 3; + *cols.cols_struct.a = 4; + cols.cols_struct.b[0] = 5; + cols.cols_struct.c.a = 6; + cols.cols_struct.c.b[0] = 7; + cols.array_of_aligned_borrow[0].a = 8; + cols.array_of_aligned_borrow[0].b[0] = 9; + cols.nested_array_of_aligned_borrow[[0, 0]].a = 10; + cols.nested_array_of_aligned_borrow[[0, 0]].b[0] = 11; + + let cols: TestColsRef = TestColsRef::from::(&input); + println!("{:?}", cols); + assert_eq!(*cols.single_field_element, 1); + assert_eq!(cols.array_of_t[0], 2); + assert_eq!(cols.nested_array_of_t[[0, 0]], 3); + assert_eq!(*cols.cols_struct.a, 4); + assert_eq!(cols.cols_struct.b[0], 5); + assert_eq!(cols.cols_struct.c.a, 6); + assert_eq!(cols.cols_struct.c.b[0], 7); + assert_eq!(cols.array_of_aligned_borrow[0].a, 8); + assert_eq!(cols.array_of_aligned_borrow[0].b[0], 9); + assert_eq!(cols.nested_array_of_aligned_borrow[[0, 0]].a, 10); + assert_eq!(cols.nested_array_of_aligned_borrow[[0, 0]].b[0], 11); +} + +/* + * For reference, this is what the ColsRef macro expands to. + * The `cargo expand` tool is helpful for understanding how the ColsRef macro works. + * See https://github.com/dtolnay/cargo-expand + +#[derive(Debug, Clone)] +struct TestColsRef<'a, T> { + pub single_field_element: &'a T, + pub array_of_t: ndarray::ArrayView1<'a, T>, + pub nested_array_of_t: ndarray::ArrayView2<'a, T>, + pub cols_struct: TestSubColsRef<'a, T>, + pub array_of_aligned_borrow: ndarray::ArrayView1<'a, TestAlignedBorrow>, + pub nested_array_of_aligned_borrow: ndarray::ArrayView2<'a, TestAlignedBorrow>, +} + +impl<'a, T> TestColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let single_field_element_length = 1; + let (single_field_element_slice, slice) = slice + .split_at(single_field_element_length); + let (array_of_t_slice, slice) = slice.split_at(1 * C::N); + let array_of_t_slice = ndarray::ArrayView1::from_shape((C::N), array_of_t_slice) + .unwrap(); + let (nested_array_of_t_slice, slice) = slice.split_at(1 * C::N * C::N); + let nested_array_of_t_slice = ndarray::ArrayView2::from_shape( + (C::N, C::N), + nested_array_of_t_slice, + ) + .unwrap(); + let cols_struct_length = >::width::(); + let (cols_struct_slice, slice) = slice.split_at(cols_struct_length); + let cols_struct_slice = >::from::(cols_struct_slice); + let (array_of_aligned_borrow_slice, slice) = slice + .split_at(>::width() * C::N); + let array_of_aligned_borrow_slice: &[TestAlignedBorrow] = unsafe { + &*(array_of_aligned_borrow_slice as *const [T] + as *const [TestAlignedBorrow]) + }; + let array_of_aligned_borrow_slice = ndarray::ArrayView1::from_shape( + (C::N), + array_of_aligned_borrow_slice, + ) + .unwrap(); + let (nested_array_of_aligned_borrow_slice, slice) = slice + .split_at(>::width() * C::N * C::N); + let nested_array_of_aligned_borrow_slice: &[TestAlignedBorrow] = unsafe { + &*(nested_array_of_aligned_borrow_slice as *const [T] + as *const [TestAlignedBorrow]) + }; + let nested_array_of_aligned_borrow_slice = ndarray::ArrayView2::from_shape( + (C::N, C::N), + nested_array_of_aligned_borrow_slice, + ) + .unwrap(); + Self { + single_field_element: &single_field_element_slice[0], + array_of_t: array_of_t_slice, + nested_array_of_t: nested_array_of_t_slice, + cols_struct: cols_struct_slice, + array_of_aligned_borrow: array_of_aligned_borrow_slice, + nested_array_of_aligned_borrow: nested_array_of_aligned_borrow_slice, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::N + 1 * C::N * C::N + >::width::() + + >::width() * C::N + + >::width() * C::N * C::N + } +} + +impl<'b, T> TestColsRef<'b, T> { + pub fn from_mut<'a, C: TestConfig>(other: &'b TestColsRefMut<'a, T>) -> Self { + Self { + single_field_element: &other.single_field_element, + array_of_t: other.array_of_t.view(), + nested_array_of_t: other.nested_array_of_t.view(), + cols_struct: >::from_mut::(&other.cols_struct), + array_of_aligned_borrow: other.array_of_aligned_borrow.view(), + nested_array_of_aligned_borrow: other.nested_array_of_aligned_borrow.view(), + } + } +} + +#[derive(Debug)] +struct TestColsRefMut<'a, T> { + pub single_field_element: &'a mut T, + pub array_of_t: ndarray::ArrayViewMut1<'a, T>, + pub nested_array_of_t: ndarray::ArrayViewMut2<'a, T>, + pub cols_struct: TestSubColsRefMut<'a, T>, + pub array_of_aligned_borrow: ndarray::ArrayViewMut1<'a, TestAlignedBorrow>, + pub nested_array_of_aligned_borrow: ndarray::ArrayViewMut2<'a, TestAlignedBorrow>, +} + +impl<'a, T> TestColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let single_field_element_length = 1; + let (single_field_element_slice, slice) = slice + .split_at_mut(single_field_element_length); + let (array_of_t_slice, slice) = slice.split_at_mut(1 * C::N); + let array_of_t_slice = ndarray::ArrayViewMut1::from_shape( + (C::N), + array_of_t_slice, + ) + .unwrap(); + let (nested_array_of_t_slice, slice) = slice.split_at_mut(1 * C::N * C::N); + let nested_array_of_t_slice = ndarray::ArrayViewMut2::from_shape( + (C::N, C::N), + nested_array_of_t_slice, + ) + .unwrap(); + let cols_struct_length = >::width::(); + let (cols_struct_slice, slice) = slice.split_at_mut(cols_struct_length); + let cols_struct_slice = >::from::(cols_struct_slice); + let (array_of_aligned_borrow_slice, slice) = slice + .split_at_mut(>::width() * C::N); + let array_of_aligned_borrow_slice: &mut [TestAlignedBorrow] = unsafe { + &mut *(array_of_aligned_borrow_slice as *mut [T] + as *mut [TestAlignedBorrow]) + }; + let array_of_aligned_borrow_slice = ndarray::ArrayViewMut1::from_shape( + (C::N), + array_of_aligned_borrow_slice, + ) + .unwrap(); + let (nested_array_of_aligned_borrow_slice, slice) = slice + .split_at_mut(>::width() * C::N * C::N); + let nested_array_of_aligned_borrow_slice: &mut [TestAlignedBorrow] = unsafe { + &mut *(nested_array_of_aligned_borrow_slice as *mut [T] + as *mut [TestAlignedBorrow]) + }; + let nested_array_of_aligned_borrow_slice = ndarray::ArrayViewMut2::from_shape( + (C::N, C::N), + nested_array_of_aligned_borrow_slice, + ) + .unwrap(); + Self { + single_field_element: &mut single_field_element_slice[0], + array_of_t: array_of_t_slice, + nested_array_of_t: nested_array_of_t_slice, + cols_struct: cols_struct_slice, + array_of_aligned_borrow: array_of_aligned_borrow_slice, + nested_array_of_aligned_borrow: nested_array_of_aligned_borrow_slice, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::N + 1 * C::N * C::N + >::width::() + + >::width() * C::N + + >::width() * C::N * C::N + } +} + +#[derive(Debug, Clone)] +struct TestSubColsRef<'a, T> { + pub a: &'a T, + pub b: ndarray::ArrayView1<'a, T>, + pub c: &'a TestAlignedBorrow, +} + +impl<'a, T> TestSubColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let a_length = 1; + let (a_slice, slice) = slice.split_at(a_length); + let (b_slice, slice) = slice.split_at(1 * C::M); + let b_slice = ndarray::ArrayView1::from_shape((C::M), b_slice).unwrap(); + let c_length = >::width(); + let (c_slice, slice) = slice.split_at(c_length); + Self { + a: &a_slice[0], + b: b_slice, + c: { + use core::borrow::Borrow; + c_slice.borrow() + }, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::M + >::width() + } +} + +impl<'b, T> TestSubColsRef<'b, T> { + pub fn from_mut<'a, C: TestConfig>(other: &'b TestSubColsRefMut<'a, T>) -> Self { + Self { + a: &other.a, + b: other.b.view(), + c: other.c, + } + } +} + +#[derive(Debug)] +struct TestSubColsRefMut<'a, T> { + pub a: &'a mut T, + pub b: ndarray::ArrayViewMut1<'a, T>, + pub c: &'a mut TestAlignedBorrow, +} + +impl<'a, T> TestSubColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let a_length = 1; + let (a_slice, slice) = slice.split_at_mut(a_length); + let (b_slice, slice) = slice.split_at_mut(1 * C::M); + let b_slice = ndarray::ArrayViewMut1::from_shape((C::M), b_slice).unwrap(); + let c_length = >::width(); + let (c_slice, slice) = slice.split_at_mut(c_length); + Self { + a: &mut a_slice[0], + b: b_slice, + c: { + use core::borrow::BorrowMut; + c_slice.borrow_mut() + }, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::M + >::width() + } +} +*/ diff --git a/crates/circuits/sha256-air/Cargo.toml b/crates/circuits/sha2-air/Cargo.toml similarity index 77% rename from crates/circuits/sha256-air/Cargo.toml rename to crates/circuits/sha2-air/Cargo.toml index c376a1ffdd..9758a10e6e 100644 --- a/crates/circuits/sha256-air/Cargo.toml +++ b/crates/circuits/sha2-air/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "openvm-sha256-air" +name = "openvm-sha2-air" version.workspace = true authors.workspace = true edition.workspace = true @@ -7,8 +7,11 @@ edition.workspace = true [dependencies] openvm-circuit-primitives = { workspace = true } openvm-stark-backend = { workspace = true } +openvm-circuit-primitives-derive = { workspace = true } sha2 = { version = "0.10", features = ["compress"] } rand.workspace = true +ndarray.workspace = true +num_enum = { workspace = true } [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/crates/circuits/sha2-air/src/air.rs b/crates/circuits/sha2-air/src/air.rs new file mode 100644 index 0000000000..9f110480fd --- /dev/null +++ b/crates/circuits/sha2-air/src/air.rs @@ -0,0 +1,694 @@ +use std::{cmp::max, iter::once, marker::PhantomData}; + +use ndarray::s; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupBus, + encoder::Encoder, + utils::{not, select}, + SubAir, +}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, +}; + +use super::{ + big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, + small_sig1_field, +}; +use crate::{ + constraint_word_addition, word_into_u16_limbs, Sha2Config, ShaDigestColsRef, ShaRoundColsRef, +}; + +/// Expects the message to be padded to a multiple of C::BLOCK_WORDS * C::WORD_BITS bits +#[derive(Clone, Debug)] +pub struct Sha2Air { + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + pub row_idx_encoder: Encoder, + /// Internal bus for self-interactions in this AIR. + bus: PermutationCheckBus, + _phantom: PhantomData, +} + +impl Sha2Air { + pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self { + Self { + bitwise_lookup_bus, + row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), /* + 1 for dummy + * (padding) rows */ + bus: PermutationCheckBus::new(self_bus_idx), + _phantom: PhantomData, + } + } +} + +impl BaseAir for Sha2Air { + fn width(&self) -> usize { + max(C::ROUND_WIDTH, C::DIGEST_WIDTH) + } +} + +impl SubAir for Sha2Air { + /// The start column for the sub-air to use + type AirContext<'a> + = usize + where + Self: 'a, + AB: 'a, + ::Var: 'a, + ::Expr: 'a; + + fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>) + where + ::Var: 'a, + ::Expr: 'a, + { + self.eval_row(builder, start_col); + self.eval_transitions(builder, start_col); + } +} + +impl Sha2Air { + /// Implements the single row constraints (i.e. imposes constraints only on local) + /// Implements some sanity constraints on the row index, flags, and work variables + fn eval_row(&self, builder: &mut AB, start_col: usize) { + let main = builder.main(); + let local = main.row_slice(0); + + // Doesn't matter which column struct we use here as we are only interested in the common + // columns + let local_cols: ShaDigestColsRef = + ShaDigestColsRef::from::(&local[start_col..start_col + C::DIGEST_WIDTH]); + let flags = &local_cols.flags; + builder.assert_bool(*flags.is_round_row); + builder.assert_bool(*flags.is_first_4_rows); + builder.assert_bool(*flags.is_digest_row); + builder.assert_bool(*flags.is_round_row + *flags.is_digest_row); + builder.assert_bool(*flags.is_last_block); + + self.row_idx_encoder + .eval(builder, local_cols.flags.row_idx.to_slice().unwrap()); + builder.assert_one(self.row_idx_encoder.contains_flag_range::( + local_cols.flags.row_idx.to_slice().unwrap(), + 0..=C::ROWS_PER_BLOCK, + )); + builder.assert_eq( + self.row_idx_encoder + .contains_flag_range::(local_cols.flags.row_idx.to_slice().unwrap(), 0..=3), + *flags.is_first_4_rows, + ); + builder.assert_eq( + self.row_idx_encoder.contains_flag_range::( + local_cols.flags.row_idx.to_slice().unwrap(), + 0..=C::ROUND_ROWS - 1, + ), + *flags.is_round_row, + ); + builder.assert_eq( + self.row_idx_encoder.contains_flag::( + local_cols.flags.row_idx.to_slice().unwrap(), + &[C::ROUND_ROWS], + ), + *flags.is_digest_row, + ); + // If padding row we want the row_idx to be C::ROWS_PER_BLOCK + builder.assert_eq( + self.row_idx_encoder.contains_flag::( + local_cols.flags.row_idx.to_slice().unwrap(), + &[C::ROWS_PER_BLOCK], + ), + flags.is_padding_row(), + ); + + // Constrain a, e, being composed of bits: we make sure a and e are always in the same place + // in the trace matrix Note: this has to be true for every row, even padding rows + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_BITS { + builder.assert_bool(local_cols.hash.a[[i, j]]); + builder.assert_bool(local_cols.hash.e[[i, j]]); + } + } + } + + /// Implements constraints for a digest row that ensure proper state transitions between blocks + /// This validates that: + /// The work variables are correctly initialized for the next message block + /// For the last message block, the initial state matches SHA_H constants + fn eval_digest_row( + &self, + builder: &mut AB, + local: ShaRoundColsRef, + next: ShaDigestColsRef, + ) { + // Check that if this is the last row of a message or an inpadding row, the hash should be + // the [SHA_H] + for i in 0..C::ROUNDS_PER_ROW { + let a = next.hash.a.row(i).mapv(|x| x.into()).to_vec(); + let e = next.hash.e.row(i).mapv(|x| x.into()).to_vec(); + + for j in 0..C::WORD_U16S { + let a_limb = compose::(&a[j * 16..(j + 1) * 16], 1); + let e_limb = compose::(&e[j * 16..(j + 1) * 16], 1); + + // If it is a padding row or the last row of a message, the `hash` should be the + // [SHA_H] + builder + .when( + next.flags.is_padding_row() + + *next.flags.is_last_block * *next.flags.is_digest_row, + ) + .assert_eq( + a_limb, + AB::Expr::from_canonical_u32( + word_into_u16_limbs::(C::get_h()[C::ROUNDS_PER_ROW - i - 1])[j], + ), + ); + + builder + .when( + next.flags.is_padding_row() + + *next.flags.is_last_block * *next.flags.is_digest_row, + ) + .assert_eq( + e_limb, + AB::Expr::from_canonical_u32( + word_into_u16_limbs::(C::get_h()[C::ROUNDS_PER_ROW - i + 3])[j], + ), + ); + } + } + + // Check if last row of a non-last block, the `hash` should be equal to the final hash of + // the current block + for i in 0..C::ROUNDS_PER_ROW { + let prev_a = next.hash.a.row(i).mapv(|x| x.into()).to_vec(); + let prev_e = next.hash.e.row(i).mapv(|x| x.into()).to_vec(); + let cur_a = next + .final_hash + .row(C::ROUNDS_PER_ROW - i - 1) + .mapv(|x| x.into()); + + let cur_e = next + .final_hash + .row(C::ROUNDS_PER_ROW - i + 3) + .mapv(|x| x.into()); + for j in 0..C::WORD_U8S { + let prev_a_limb = compose::(&prev_a[j * 8..(j + 1) * 8], 1); + let prev_e_limb = compose::(&prev_e[j * 8..(j + 1) * 8], 1); + + builder + .when(not(*next.flags.is_last_block) * *next.flags.is_digest_row) + .assert_eq(prev_a_limb, cur_a[j].clone()); + + builder + .when(not(*next.flags.is_last_block) * *next.flags.is_digest_row) + .assert_eq(prev_e_limb, cur_e[j].clone()); + } + } + + // Assert that the previous hash + work vars == final hash. + // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]` + // where addition is done modulo 2^32 + for i in 0..C::HASH_WORDS { + let mut carry = AB::Expr::ZERO; + for j in 0..C::WORD_U16S { + let work_var_limb = if i < C::ROUNDS_PER_ROW { + compose::( + local + .work_vars + .a + .slice(s![C::ROUNDS_PER_ROW - 1 - i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) + } else { + compose::( + local + .work_vars + .e + .slice(s![C::ROUNDS_PER_ROW + 3 - i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) + }; + let final_hash_limb = compose::( + next.final_hash + .slice(s![i, j * 2..(j + 1) * 2]) + .as_slice() + .unwrap(), + 8, + ); + + carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse()) + * (next.prev_hash[[i, j]] + work_var_limb + carry - final_hash_limb); + builder + .when(*next.flags.is_digest_row) + .assert_bool(carry.clone()); + } + // constrain the final hash limbs two at a time since we can do two checks per + // interaction + for chunk in next.final_hash.row(i).as_slice().unwrap().chunks(2) { + self.bitwise_lookup_bus + .send_range(chunk[0], chunk[1]) + .eval(builder, *next.flags.is_digest_row); + } + } + } + + fn eval_transitions(&self, builder: &mut AB, start_col: usize) { + let main = builder.main(); + let local = main.row_slice(0); + let next = main.row_slice(1); + + // Doesn't matter what column structs we use here + let local_cols: ShaRoundColsRef = + ShaRoundColsRef::from::(&local[start_col..start_col + C::ROUND_WIDTH]); + let next_cols: ShaRoundColsRef = + ShaRoundColsRef::from::(&next[start_col..start_col + C::ROUND_WIDTH]); + + let local_is_padding_row = local_cols.flags.is_padding_row(); + // Note that there will always be a padding row in the trace since the unpadded height is a + // multiple of 17 (SHA-256) or 21 (SHA-512, SHA-384). So the next row is padding iff the + // current block is the last block in the trace. + let next_is_padding_row = next_cols.flags.is_padding_row(); + + // We check that the very last block has `is_last_block` set to true, which guarantees that + // there is at least one complete message. If other digest rows have `is_last_block` set to + // true, then the trace will be interpreted as containing multiple messages. + builder + .when(next_is_padding_row.clone()) + .when(*next_cols.flags.is_digest_row) + .assert_one(*next_cols.flags.is_last_block); + // If we are in a round row, the next row cannot be a padding row + builder + .when(*local_cols.flags.is_round_row) + .assert_zero(next_is_padding_row.clone()); + // The first row must be a round row + builder + .when_first_row() + .assert_one(*local_cols.flags.is_round_row); + // If we are in a padding row, the next row must also be a padding row + builder + .when_transition() + .when(local_is_padding_row.clone()) + .assert_one(next_is_padding_row.clone()); + // If we are in a digest row, the next row cannot be a digest row + builder + .when(*local_cols.flags.is_digest_row) + .assert_zero(*next_cols.flags.is_digest_row); + // Constrain how much the row index changes by + // round->round: 1 + // round->digest: 1 + // digest->round: -C::ROUND_ROWS + // digest->padding: 1 + // padding->padding: 0 + // Other transitions are not allowed by the above constraints + let delta = *local_cols.flags.is_round_row * AB::Expr::ONE + + *local_cols.flags.is_digest_row + * *next_cols.flags.is_round_row + * AB::Expr::from_canonical_usize(C::ROUND_ROWS) + * AB::Expr::NEG_ONE + + *local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE; + + let local_row_idx = self.row_idx_encoder.flag_with_val::( + local_cols.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + let next_row_idx = self.row_idx_encoder.flag_with_val::( + next_cols.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + + builder + .when_transition() + .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone()); + builder.when_first_row().assert_zero(local_row_idx); + + // Constrain the global block index + // We set the global block index to 0 for padding rows + // Starting with 1 so it is not the same as the padding rows + + // Global block index is 1 on first row + builder + .when_first_row() + .assert_one(*local_cols.flags.global_block_idx); + + // Global block index is constant on all rows in a block + builder.when(*local_cols.flags.is_round_row).assert_eq( + *local_cols.flags.global_block_idx, + *next_cols.flags.global_block_idx, + ); + // Global block index increases by 1 between blocks + builder + .when_transition() + .when(*local_cols.flags.is_digest_row) + .when(*next_cols.flags.is_round_row) + .assert_eq( + *local_cols.flags.global_block_idx + AB::Expr::ONE, + *next_cols.flags.global_block_idx, + ); + // Global block index is 0 on padding rows + builder + .when(local_is_padding_row.clone()) + .assert_zero(*local_cols.flags.global_block_idx); + + // Constrain the local block index + // We set the local block index to 0 for padding rows + + // Local block index is constant on all rows in a block + // and its value on padding rows is equal to its value on the first block + builder + .when(not(*local_cols.flags.is_digest_row)) + .assert_eq( + *local_cols.flags.local_block_idx, + *next_cols.flags.local_block_idx, + ); + // Local block index increases by 1 between blocks in the same message + builder + .when(*local_cols.flags.is_digest_row) + .when(not(*local_cols.flags.is_last_block)) + .assert_eq( + *local_cols.flags.local_block_idx + AB::Expr::ONE, + *next_cols.flags.local_block_idx, + ); + // Local block index is 0 on padding rows + // Combined with the above, this means that the local block index is 0 in the first block + builder + .when(*local_cols.flags.is_digest_row) + .when(*local_cols.flags.is_last_block) + .assert_zero(*next_cols.flags.local_block_idx); + + self.eval_message_schedule(builder, local_cols.clone(), next_cols.clone()); + self.eval_work_vars(builder, local_cols.clone(), next_cols); + let next_cols: ShaDigestColsRef = + ShaDigestColsRef::from::(&next[start_col..start_col + C::DIGEST_WIDTH]); + self.eval_digest_row(builder, local_cols, next_cols); + let local_cols: ShaDigestColsRef = + ShaDigestColsRef::from::(&local[start_col..start_col + C::DIGEST_WIDTH]); + self.eval_prev_hash(builder, local_cols, next_is_padding_row); + } + + /// Constrains that the next block's `prev_hash` is equal to the current block's `hash` + /// Note: the constraining is done by interactions with the chip itself on every digest row + fn eval_prev_hash( + &self, + builder: &mut AB, + local: ShaDigestColsRef, + is_last_block_of_trace: AB::Expr, /* note this indicates the last block of the trace, + * not the last block of the message */ + ) { + // Constrain that next block's `prev_hash` is equal to the current block's `hash` + let composed_hash = (0..C::HASH_WORDS) + .map(|i| { + let hash_bits = if i < C::ROUNDS_PER_ROW { + local + .hash + .a + .row(C::ROUNDS_PER_ROW - 1 - i) + .mapv(|x| x.into()) + .to_vec() + } else { + local + .hash + .e + .row(C::ROUNDS_PER_ROW + 3 - i) + .mapv(|x| x.into()) + .to_vec() + }; + (0..C::WORD_U16S) + .map(|j| compose::(&hash_bits[j * 16..(j + 1) * 16], 1)) + .collect::>() + }) + .collect::>(); + // Need to handle the case if this is the very last block of the trace matrix + let next_global_block_idx = select( + is_last_block_of_trace, + AB::Expr::ONE, + *local.flags.global_block_idx + AB::Expr::ONE, + ); + // The following interactions constrain certain values from block to block + self.bus.send( + builder, + composed_hash + .into_iter() + .flatten() + .chain(once(next_global_block_idx)), + *local.flags.is_digest_row, + ); + + self.bus.receive( + builder, + local + .prev_hash + .flatten() + .mapv(|x| x.into()) + .into_iter() + .chain(once((*local.flags.global_block_idx).into())), + *local.flags.is_digest_row, + ); + } + + /// Constrain the message schedule additions for `next` row + /// Note: For every addition we need to constrain the following for each of [WORD_U16S] limbs + /// sig_1(w_{t-2})[i] + w_{t-7}[i] + sig_0(w_{t-15})[i] + w_{t-16}[i] + carry_w[t][i-1] - + /// carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] + fn eval_message_schedule<'a, AB: InteractionBuilder>( + &self, + builder: &mut AB, + local: ShaRoundColsRef<'a, AB::Var>, + next: ShaRoundColsRef<'a, AB::Var>, + ) { + // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx + let w = ndarray::concatenate( + ndarray::Axis(0), + &[local.message_schedule.w, next.message_schedule.w], + ) + .unwrap(); + + // Constrain `w_3` for `next` row + for i in 0..C::ROUNDS_PER_ROW - 1 { + // here we constrain the w_3 of the i_th word of the next row + // w_3 of next is w[i+4-3] = w[i+1] + let w_3 = w.row(i + 1).mapv(|x| x.into()).to_vec(); + let expected_w_3 = next.schedule_helper.w_3.row(i); + for j in 0..C::WORD_U16S { + let w_3_limb = compose::(&w_3[j * 16..(j + 1) * 16], 1); + builder + .when(*local.flags.is_round_row) + .assert_eq(w_3_limb, expected_w_3[j].into()); + } + } + + // Constrain intermed for `next` row + // We will only constrain intermed_12 for rows [3, C::ROUND_ROWS - 2], and let it + // unconstrained for other rows Other rows should put the needed value in + // intermed_12 to make the below summation constraint hold + let is_row_intermed_12 = self.row_idx_encoder.contains_flag_range::( + next.flags.row_idx.to_slice().unwrap(), + 3..=C::ROUND_ROWS - 2, + ); + // We will only constrain intermed_8 for rows [2, C::ROUND_ROWS - 3], and let it + // unconstrained for other rows + let is_row_intermed_8 = self.row_idx_encoder.contains_flag_range::( + next.flags.row_idx.to_slice().unwrap(), + 2..=C::ROUND_ROWS - 3, + ); + for i in 0..C::ROUNDS_PER_ROW { + // w_idx + let w_idx = w.row(i).mapv(|x| x.into()).to_vec(); + // sig_0(w_{idx+1}) + let sig_w = small_sig0_field::(w.row(i + 1).as_slice().unwrap()); + for j in 0..C::WORD_U16S { + let w_idx_limb = compose::(&w_idx[j * 16..(j + 1) * 16], 1); + let sig_w_limb = compose::(&sig_w[j * 16..(j + 1) * 16], 1); + + // We would like to constrain this only on rows 0..16, but we can't do a conditional + // check because the degree is already 3. So we must fill in + // `intermed_4` with dummy values on rows 0 and 16 to ensure the constraint holds on + // these rows. + builder.when_transition().assert_eq( + next.schedule_helper.intermed_4[[i, j]], + w_idx_limb + sig_w_limb, + ); + + builder.when(is_row_intermed_8.clone()).assert_eq( + next.schedule_helper.intermed_8[[i, j]], + local.schedule_helper.intermed_4[[i, j]], + ); + + builder.when(is_row_intermed_12.clone()).assert_eq( + next.schedule_helper.intermed_12[[i, j]], + local.schedule_helper.intermed_8[[i, j]], + ); + } + } + + // Constrain the message schedule additions for `next` row + for i in 0..C::ROUNDS_PER_ROW { + // Note, here by w_{t} we mean the i_th word of the `next` row + // w_{t-7} + let w_7 = if i < 3 { + local.schedule_helper.w_3.row(i).mapv(|x| x.into()).to_vec() + } else { + let w_3 = w.row(i - 3).mapv(|x| x.into()).to_vec(); + (0..C::WORD_U16S) + .map(|j| compose::(&w_3[j * 16..(j + 1) * 16], 1)) + .collect::>() + }; + // sig_0(w_{t-15}) + w_{t-16} + let intermed_16 = local.schedule_helper.intermed_12.row(i).mapv(|x| x.into()); + + let carries = (0..C::WORD_U16S) + .map(|j| { + next.message_schedule.carry_or_buffer[[i, j * 2]] + + AB::Expr::TWO * next.message_schedule.carry_or_buffer[[i, j * 2 + 1]] + }) + .collect::>(); + + // Constrain `W_{idx} = sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}` + // We would like to constrain this only on rows 4..C::ROUND_ROWS, but we can't do a + // conditional check because the degree of sum is already 3 So we must fill + // in `intermed_12` with dummy values on rows 0..3 and C::ROUND_ROWS-1 and C::ROUND_ROWS + // to ensure the constraint holds on rows 0..4 and C::ROUND_ROWS. Note that + // the dummy value goes in the previous row to make the current row's constraint hold. + constraint_word_addition::<_, C>( + // Note: here we can't do a conditional check because the degree of sum is already + // 3 + &mut builder.when_transition(), + &[&small_sig1_field::( + w.row(i + 2).as_slice().unwrap(), + )], + &[&w_7, intermed_16.as_slice().unwrap()], + w.row(i + 4).as_slice().unwrap(), + &carries, + ); + + for j in 0..C::WORD_U16S { + // When on rows 4..C::ROUND_ROWS message schedule carries should be 0 or 1 + let is_row_4_or_more = *next.flags.is_round_row - *next.flags.is_first_4_rows; + builder + .when(is_row_4_or_more.clone()) + .assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2]]); + builder + .when(is_row_4_or_more) + .assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2 + 1]]); + } + // Constrain w being composed of bits + for j in 0..C::WORD_BITS { + builder + .when(*next.flags.is_round_row) + .assert_bool(next.message_schedule.w[[i, j]]); + } + } + } + + /// Constrain the work vars on `next` row according to the sha documentation + /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] + fn eval_work_vars<'a, AB: InteractionBuilder>( + &self, + builder: &mut AB, + local: ShaRoundColsRef<'a, AB::Var>, + next: ShaRoundColsRef<'a, AB::Var>, + ) { + let a = + ndarray::concatenate(ndarray::Axis(0), &[local.work_vars.a, next.work_vars.a]).unwrap(); + let e = + ndarray::concatenate(ndarray::Axis(0), &[local.work_vars.e, next.work_vars.e]).unwrap(); + + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_U16S { + // Although we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in + // [0, 2^8) is enough to prevent overflow and ensure the soundness + // of the addition we want to check + self.bitwise_lookup_bus + .send_range( + local.work_vars.carry_a[[i, j]], + local.work_vars.carry_e[[i, j]], + ) + .eval(builder, *local.flags.is_round_row); + } + + let w_limbs = (0..C::WORD_U16S) + .map(|j| { + compose::( + next.message_schedule + .w + .slice(s![i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) * *next.flags.is_round_row + }) + .collect::>(); + + let k_limbs = (0..C::WORD_U16S) + .map(|j| { + self.row_idx_encoder.flag_with_val::( + next.flags.row_idx.to_slice().unwrap(), + &(0..C::ROUND_ROWS) + .map(|rw_idx| { + ( + rw_idx, + word_into_u16_limbs::( + C::get_k()[rw_idx * C::ROUNDS_PER_ROW + i], + )[j] as usize, + ) + }) + .collect::>(), + ) + }) + .collect::>(); + + // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)` + // We have to enforce this constraint on all rows since the degree of the constraint is + // already 3. So, we must fill in `carry_a` with dummy values on digest rows + // to ensure the constraint holds. + constraint_word_addition::<_, C>( + builder, + &[ + e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h` + &big_sig1_field::(e.row(i + 3).as_slice().unwrap()), /* sig_1 of previous `e` */ + &ch_field::( + e.row(i + 3).as_slice().unwrap(), + e.row(i + 2).as_slice().unwrap(), + e.row(i + 1).as_slice().unwrap(), + ), /* Ch of previous `e`, `f`, `g` */ + &big_sig0_field::(a.row(i + 3).as_slice().unwrap()), /* sig_0 of previous `a` */ + &maj_field::( + a.row(i + 3).as_slice().unwrap(), + a.row(i + 2).as_slice().unwrap(), + a.row(i + 1).as_slice().unwrap(), + ), /* Maj of previous a, b, c */ + ], + &[&w_limbs, &k_limbs], // K and W + a.row(i + 4).as_slice().unwrap(), // new `a` + next.work_vars.carry_a.row(i).as_slice().unwrap(), // carries of addition + ); + + // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W` + // We have to enforce this constraint on all rows since the degree of the constraint is + // already 3. So, we must fill in `carry_e` with dummy values on digest rows + // to ensure the constraint holds. + constraint_word_addition::<_, C>( + builder, + &[ + a.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `d` + e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h` + &big_sig1_field::(e.row(i + 3).as_slice().unwrap()), /* sig_1 of previous `e` */ + &ch_field::( + e.row(i + 3).as_slice().unwrap(), + e.row(i + 2).as_slice().unwrap(), + e.row(i + 1).as_slice().unwrap(), + ), /* Ch of previous `e`, `f`, `g` */ + ], + &[&w_limbs, &k_limbs], // K and W + e.row(i + 4).as_slice().unwrap(), // new `e` + next.work_vars.carry_e.row(i).as_slice().unwrap(), // carries of addition + ); + } + } +} diff --git a/crates/circuits/sha2-air/src/columns.rs b/crates/circuits/sha2-air/src/columns.rs new file mode 100644 index 0000000000..da1e334e97 --- /dev/null +++ b/crates/circuits/sha2-air/src/columns.rs @@ -0,0 +1,187 @@ +//! WARNING: the order of fields in the structs is important, do not change it + +use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives_derive::ColsRef; +use openvm_stark_backend::p3_field::FieldAlgebra; + +use crate::Sha2Config; + +/// In each SHA block: +/// - First C::ROUND_ROWS rows use ShaRoundCols +/// - Final row uses ShaDigestCols +/// +/// Note that for soundness, we require that there is always a padding row after the last digest row +/// in the trace. Right now, this is true because the unpadded height is a multiple of 17 (SHA-256) +/// or 21 (SHA-512), and thus not a power of 2. +/// +/// ShaRoundCols and ShaDigestCols share the same first 3 fields: +/// - flags +/// - work_vars/hash (same type, different name) +/// - schedule_helper +/// +/// This design allows for: +/// 1. Common constraints to work on either struct type by accessing these shared fields +/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional +/// constraints +/// +/// Note that the `ShaWorkVarsCols` field is used for different purposes in the two structs. +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct ShaRoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub flags: Sha2FlagsCols, + pub work_vars: ShaWorkVarsCols, + pub schedule_helper: + Sha2MessageHelperCols, + pub message_schedule: ShaMessageScheduleCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct ShaDigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub flags: Sha2FlagsCols, + /// Will serve as previous hash values for the next block + pub hash: ShaWorkVarsCols, + pub schedule_helper: + Sha2MessageHelperCols, + /// The actual final hash values of the given block + /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block + pub final_hash: [[T; WORD_U8S]; HASH_WORDS], + /// The final hash of the previous block + /// Note: will be constrained using interactions with the chip itself + pub prev_hash: [[T; WORD_U16S]; HASH_WORDS], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct ShaMessageScheduleCols< + T, + const WORD_BITS: usize, + const ROUNDS_PER_ROW: usize, + const WORD_U8S: usize, +> { + /// The message schedule words as bits + /// The first 16 words will be the message data + pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW], + /// Will be message schedule carries for rows 4..C::ROUND_ROWS and a buffer for rows 0..4 to be + /// used freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells + /// as individual bits + pub carry_or_buffer: [[T; WORD_U8S]; ROUNDS_PER_ROW], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct ShaWorkVarsCols< + T, + const WORD_BITS: usize, + const ROUNDS_PER_ROW: usize, + const WORD_U16S: usize, +> { + /// `a` and `e` after each iteration as 32-bits + pub a: [[T; WORD_BITS]; ROUNDS_PER_ROW], + pub e: [[T; WORD_BITS]; ROUNDS_PER_ROW], + /// The carry's used for addition during each iteration when computing `a` and `e` + pub carry_a: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub carry_e: [[T; WORD_U16S]; ROUNDS_PER_ROW], +} + +/// These are the columns that are used to help with the message schedule additions +/// Note: these need to be correctly assigned for every row even on padding rows +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct Sha2MessageHelperCols< + T, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, +> { + /// The following are used to move data forward to constrain the message schedule additions + /// The value of `w` from 3 rounds ago + pub w_3: [[T; WORD_U16S]; ROUNDS_PER_ROW_MINUS_ONE], + /// Here intermediate(i) = w_i + sig_0(w_{i+1}) + /// Intermed_t represents the intermediate t rounds ago + /// This is needed to constrain the message schedule, since we can only constrain on two rows + /// at a time + pub intermed_4: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub intermed_8: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub intermed_12: [[T; WORD_U16S]; ROUNDS_PER_ROW], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct Sha2FlagsCols { + pub is_round_row: T, + /// A flag that indicates if the current row is among the first 4 rows of a block (the message + /// rows) + pub is_first_4_rows: T, + pub is_digest_row: T, + pub is_last_block: T, + /// We will encode the row index [0..17) using 5 cells + pub row_idx: [T; ROW_VAR_CNT], + /// The global index of the current block + pub global_block_idx: T, + /// Will store the index of the current block in the current message starting from 0 + pub local_block_idx: T, +} + +impl, const ROW_VAR_CNT: usize> + Sha2FlagsCols +{ + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_not_padding_row(&self) -> O { + self.is_round_row + self.is_digest_row + } + + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_padding_row(&self) -> O + where + O: FieldAlgebra, + { + not(self.is_not_padding_row()) + } +} + +impl> Sha2FlagsColsRef<'_, T> { + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_not_padding_row(&self) -> O { + *self.is_round_row + *self.is_digest_row + } + + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_padding_row(&self) -> O + where + O: FieldAlgebra, + { + not(self.is_not_padding_row()) + } +} diff --git a/crates/circuits/sha2-air/src/config.rs b/crates/circuits/sha2-air/src/config.rs new file mode 100644 index 0000000000..e6e6b54202 --- /dev/null +++ b/crates/circuits/sha2-air/src/config.rs @@ -0,0 +1,388 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not, Shl, Shr}; + +use crate::{ShaDigestColsRef, ShaRoundColsRef}; + +#[repr(u32)] +#[derive(num_enum::TryFromPrimitive, num_enum::IntoPrimitive)] +pub enum Sha2Variant { + Sha256, + Sha512, + Sha384, +} + +pub trait Sha2Config: Send + Sync + Clone { + type Word: 'static + + Shr + + Shl + + BitAnd + + Not + + BitXor + + BitOr + + RotateRight + + WrappingAdd + + PartialEq + + From + + TryInto + + Copy + + Send + + Sync; + // Differentiate between the SHA-2 variants + const VARIANT: Sha2Variant; + /// Number of bits in a SHA word + const WORD_BITS: usize; + /// Number of 16-bit limbs in a SHA word + const WORD_U16S: usize = Self::WORD_BITS / 16; + /// Number of 8-bit limbs in a SHA word + const WORD_U8S: usize = Self::WORD_BITS / 8; + /// Number of words in a SHA block + const BLOCK_WORDS: usize; + /// Number of cells in a SHA block + const BLOCK_U8S: usize = Self::BLOCK_WORDS * Self::WORD_U8S; + /// Number of bits in a SHA block + const BLOCK_BITS: usize = Self::BLOCK_WORDS * Self::WORD_BITS; + /// Number of rows per block + const ROWS_PER_BLOCK: usize; + /// Number of rounds per row. Must divide Self::ROUNDS_PER_BLOCK + const ROUNDS_PER_ROW: usize; + /// Number of rows used for the sha rounds + const ROUND_ROWS: usize = Self::ROUNDS_PER_BLOCK / Self::ROUNDS_PER_ROW; + /// Number of rows used for the message + const MESSAGE_ROWS: usize = Self::BLOCK_WORDS / Self::ROUNDS_PER_ROW; + /// Number of rounds per row minus one (needed for one of the column structs) + const ROUNDS_PER_ROW_MINUS_ONE: usize = Self::ROUNDS_PER_ROW - 1; + /// Number of rounds per block. Must be a multiple of Self::ROUNDS_PER_ROW + const ROUNDS_PER_BLOCK: usize; + /// Number of words in a SHA hash + const HASH_WORDS: usize; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize; + /// Width of the ShaRoundCols + const ROUND_WIDTH: usize = ShaRoundColsRef::::width::(); + /// Width of the ShaDigestCols + const DIGEST_WIDTH: usize = ShaDigestColsRef::::width::(); + /// Width of the ShaCols + const WIDTH: usize = if Self::ROUND_WIDTH > Self::DIGEST_WIDTH { + Self::ROUND_WIDTH + } else { + Self::DIGEST_WIDTH + }; + /// Number of cells used in each message row to store the message + const CELLS_PER_ROW: usize = Self::ROUNDS_PER_ROW * Self::WORD_U8S; + + /// To optimize the trace generation of invalid rows, we precompute those values. + // these should be appropriately sized for the config + fn get_invalid_carry_a(round_num: usize) -> &'static [u32]; + fn get_invalid_carry_e(round_num: usize) -> &'static [u32]; + + /// We also store the SHA constants K and H + fn get_k() -> &'static [Self::Word]; + fn get_h() -> &'static [Self::Word]; +} + +#[derive(Clone)] +pub struct Sha256Config; + +impl Sha2Config for Sha256Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha256; + type Word = u32; + /// Number of bits in a SHA256 word + const WORD_BITS: usize = 32; + /// Number of words in a SHA256 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 17; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 64; + /// Number of words in a SHA256 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 5; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA256_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA256_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u32] { + &SHA256_K + } + fn get_h() -> &'static [u32] { + &SHA256_H + } +} + +pub const SHA256_INVALID_CARRY_A: [[u32; Sha256Config::WORD_U16S]; Sha256Config::ROUNDS_PER_ROW] = [ + [1230919683, 1162494304], + [266373122, 1282901987], + [1519718403, 1008990871], + [923381762, 330807052], +]; +pub const SHA256_INVALID_CARRY_E: [[u32; Sha256Config::WORD_U16S]; Sha256Config::ROUNDS_PER_ROW] = [ + [204933122, 1994683449], + [443873282, 1544639095], + [719953922, 1888246508], + [194580482, 1075725211], +]; + +/// SHA256 constant K's +pub const SHA256_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; +/// SHA256 initial hash values +pub const SHA256_H: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +#[derive(Clone)] +pub struct Sha512Config; + +impl Sha2Config for Sha512Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha512; + type Word = u64; + /// Number of bits in a SHA512 word + const WORD_BITS: usize = 64; + /// Number of words in a SHA512 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 21; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 80; + /// Number of words in a SHA512 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 6; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA512_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA512_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u64] { + &SHA512_K + } + fn get_h() -> &'static [u64] { + &SHA512_H + } +} + +pub(crate) const SHA512_INVALID_CARRY_A: [[u32; Sha512Config::WORD_U16S]; + Sha512Config::ROUNDS_PER_ROW] = [ + [55971842, 827997017, 993005918, 512731953], + [227512322, 1697529235, 1936430385, 940122990], + [1939875843, 1173318562, 826201586, 1513494849], + [891955202, 1732283693, 1736658755, 223514501], +]; + +pub(crate) const SHA512_INVALID_CARRY_E: [[u32; Sha512Config::WORD_U16S]; + Sha512Config::ROUNDS_PER_ROW] = [ + [1384427522, 1509509767, 153131516, 102514978], + [1527552003, 1041677071, 837289497, 843522538], + [775188482, 1620184630, 744892564, 892058728], + [1801267202, 1393118048, 1846108940, 830635531], +]; + +/// SHA512 constant K's +pub const SHA512_K: [u64; 80] = [ + 0x428a2f98d728ae22, + 0x7137449123ef65cd, + 0xb5c0fbcfec4d3b2f, + 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, + 0x59f111f1b605d019, + 0x923f82a4af194f9b, + 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, + 0x12835b0145706fbe, + 0x243185be4ee4b28c, + 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, + 0x80deb1fe3b1696b1, + 0x9bdc06a725c71235, + 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, + 0xefbe4786384f25e3, + 0x0fc19dc68b8cd5b5, + 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, + 0x4a7484aa6ea6e483, + 0x5cb0a9dcbd41fbd4, + 0x76f988da831153b5, + 0x983e5152ee66dfab, + 0xa831c66d2db43210, + 0xb00327c898fb213f, + 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, + 0xd5a79147930aa725, + 0x06ca6351e003826f, + 0x142929670a0e6e70, + 0x27b70a8546d22ffc, + 0x2e1b21385c26c926, + 0x4d2c6dfc5ac42aed, + 0x53380d139d95b3df, + 0x650a73548baf63de, + 0x766a0abb3c77b2a8, + 0x81c2c92e47edaee6, + 0x92722c851482353b, + 0xa2bfe8a14cf10364, + 0xa81a664bbc423001, + 0xc24b8b70d0f89791, + 0xc76c51a30654be30, + 0xd192e819d6ef5218, + 0xd69906245565a910, + 0xf40e35855771202a, + 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, + 0x1e376c085141ab53, + 0x2748774cdf8eeb99, + 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, + 0x4ed8aa4ae3418acb, + 0x5b9cca4f7763e373, + 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, + 0x78a5636f43172f60, + 0x84c87814a1f0ab72, + 0x8cc702081a6439ec, + 0x90befffa23631e28, + 0xa4506cebde82bde9, + 0xbef9a3f7b2c67915, + 0xc67178f2e372532b, + 0xca273eceea26619c, + 0xd186b8c721c0c207, + 0xeada7dd6cde0eb1e, + 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, + 0x0a637dc5a2c898a6, + 0x113f9804bef90dae, + 0x1b710b35131c471b, + 0x28db77f523047d84, + 0x32caab7b40c72493, + 0x3c9ebe0a15c9bebc, + 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, + 0x597f299cfc657e2a, + 0x5fcb6fab3ad6faec, + 0x6c44198c4a475817, +]; +/// SHA512 initial hash values +pub const SHA512_H: [u64; 8] = [ + 0x6a09e667f3bcc908, + 0xbb67ae8584caa73b, + 0x3c6ef372fe94f82b, + 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, + 0x9b05688c2b3e6c1f, + 0x1f83d9abfb41bd6b, + 0x5be0cd19137e2179, +]; + +#[derive(Clone)] +pub struct Sha384Config; + +impl Sha2Config for Sha384Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha384; + type Word = u64; + /// Number of bits in a SHA384 word + const WORD_BITS: usize = 64; + /// Number of words in a SHA384 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 21; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 80; + /// Number of words in a SHA384 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 6; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA384_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA384_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u64] { + &SHA384_K + } + fn get_h() -> &'static [u64] { + &SHA384_H + } +} + +pub(crate) const SHA384_INVALID_CARRY_A: [[u32; Sha384Config::WORD_U16S]; + Sha384Config::ROUNDS_PER_ROW] = [ + [1571481603, 1428841901, 1050676523, 793575075], + [1233315842, 1822329223, 112923808, 1874228927], + [1245603842, 927240770, 1579759431, 70557227], + [195532801, 594312107, 1429379950, 220407092], +]; + +pub(crate) const SHA384_INVALID_CARRY_E: [[u32; Sha384Config::WORD_U16S]; + Sha384Config::ROUNDS_PER_ROW] = [ + [1067980802, 1508061099, 1418826213, 1232569491], + [1453086722, 1702524575, 152427899, 238512408], + [1623674882, 701393097, 1002035664, 4776891], + [1888911362, 184963225, 1151849224, 1034237098], +]; + +/// SHA384 constant K's +pub const SHA384_K: [u64; 80] = SHA512_K; + +/// SHA384 initial hash values +pub const SHA384_H: [u64; 8] = [ + 0xcbbb9d5dc1059ed8, + 0x629a292a367cd507, + 0x9159015a3070dd17, + 0x152fecd8f70e5939, + 0x67332667ffc00b31, + 0x8eb44a8768581511, + 0xdb0c2e0d64f98fa7, + 0x47b5481dbefa4fa4, +]; + +// Needed to avoid compile errors in utils.rs +// not sure why this doesn't inf loop +pub trait RotateRight { + fn rotate_right(self, n: u32) -> Self; +} +impl RotateRight for u32 { + fn rotate_right(self, n: u32) -> Self { + self.rotate_right(n) + } +} +impl RotateRight for u64 { + fn rotate_right(self, n: u32) -> Self { + self.rotate_right(n) + } +} +pub trait WrappingAdd { + fn wrapping_add(self, n: Self) -> Self; +} +impl WrappingAdd for u32 { + fn wrapping_add(self, n: u32) -> Self { + self.wrapping_add(n) + } +} +impl WrappingAdd for u64 { + fn wrapping_add(self, n: u64) -> Self { + self.wrapping_add(n) + } +} diff --git a/crates/circuits/sha256-air/src/lib.rs b/crates/circuits/sha2-air/src/lib.rs similarity index 65% rename from crates/circuits/sha256-air/src/lib.rs rename to crates/circuits/sha2-air/src/lib.rs index 48bdaee5f9..7c7d095938 100644 --- a/crates/circuits/sha256-air/src/lib.rs +++ b/crates/circuits/sha2-air/src/lib.rs @@ -1,13 +1,15 @@ -//! Implementation of the SHA256 compression function without padding +//! Implementation of the SHA256/SHA512 compression function without padding //! This this AIR doesn't constrain any of the message padding mod air; mod columns; +mod config; mod trace; mod utils; pub use air::*; pub use columns::*; +pub use config::*; pub use trace::*; pub use utils::*; diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha2-air/src/tests.rs similarity index 55% rename from crates/circuits/sha256-air/src/tests.rs rename to crates/circuits/sha2-air/src/tests.rs index 5822bfe235..f376b0b246 100644 --- a/crates/circuits/sha256-air/src/tests.rs +++ b/crates/circuits/sha2-air/src/tests.rs @@ -1,4 +1,4 @@ -use std::{array, borrow::BorrowMut, cmp::max, sync::Arc}; +use std::{cmp::max, sync::Arc}; use openvm_circuit::arch::{ instructions::riscv::RV32_CELL_BITS, @@ -24,39 +24,39 @@ use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::Rng; use crate::{ - Sha256Air, Sha256DigestCols, Sha256StepHelper, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S, + Sha256Config, Sha2Air, Sha2Config, Sha2StepHelper, Sha384Config, Sha512Config, + ShaDigestColsRefMut, }; // A wrapper AIR purely for testing purposes #[derive(Clone, Debug)] -pub struct Sha256TestAir { - pub sub_air: Sha256Air, +pub struct Sha2TestAir { + pub sub_air: Sha2Air, } -impl BaseAirWithPublicValues for Sha256TestAir {} -impl PartitionedBaseAir for Sha256TestAir {} -impl BaseAir for Sha256TestAir { +impl BaseAirWithPublicValues for Sha2TestAir {} +impl PartitionedBaseAir for Sha2TestAir {} +impl BaseAir for Sha2TestAir { fn width(&self) -> usize { - >::width(&self.sub_air) + as BaseAir>::width(&self.sub_air) } } -impl Air for Sha256TestAir { +impl Air for Sha2TestAir { fn eval(&self, builder: &mut AB) { self.sub_air.eval(builder, 0); } } // A wrapper Chip purely for testing purposes -pub struct Sha256TestChip { - pub air: Sha256TestAir, - pub step: Sha256StepHelper, +pub struct Sha2TestChip { + pub air: Sha2TestAir, + pub step: Sha2StepHelper, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, + pub records: Vec<(Vec, bool)>, // length of inner vec is C::BLOCK_U8S } -impl Chip for Sha256TestChip +impl Chip for Sha2TestChip where Val: PrimeField32, { @@ -65,33 +65,34 @@ where } fn generate_air_proof_input(self) -> AirProofInput { - let trace = crate::generate_trace::>( + let trace = crate::generate_trace::, C>( &self.step, - self.bitwise_lookup_chip.as_ref(), - >>::width(&self.air.sub_air), + self.bitwise_lookup_chip.clone(), + as BaseAir>>::width(&self.air.sub_air), self.records, ); AirProofInput::simple_no_pis(trace) } } -impl ChipUsageGetter for Sha256TestChip { +impl ChipUsageGetter for Sha2TestChip { fn air_name(&self) -> String { get_air_name(&self.air) } fn current_trace_height(&self) -> usize { - self.records.len() * SHA256_ROWS_PER_BLOCK + self.records.len() * C::ROWS_PER_BLOCK } fn trace_width(&self) -> usize { - max(SHA256_ROUND_WIDTH, SHA256_DIGEST_WIDTH) + max(C::ROUND_WIDTH, C::DIGEST_WIDTH) } } const SELF_BUS_IDX: BusIndex = 28; type F = BabyBear; -fn create_chip_with_rand_records() -> (Sha256TestChip, SharedBitwiseOperationLookupChip<8>) { +fn create_chip_with_rand_records( +) -> (Sha2TestChip, SharedBitwiseOperationLookupChip<8>) { let mut rng = create_seeded_rng(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); @@ -99,44 +100,61 @@ fn create_chip_with_rand_records() -> (Sha256TestChip, SharedBitwiseOperationLoo let random_records: Vec<_> = (0..len) .map(|i| { ( - array::from_fn(|_| rng.gen::()), + (0..C::BLOCK_U8S) + .map(|_| rng.gen::()) + .collect::>(), rng.gen::() || i == len - 1, ) }) .collect(); - let chip = Sha256TestChip { - air: Sha256TestAir { - sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), + let chip = Sha2TestChip { + air: Sha2TestAir { + sub_air: Sha2Air::::new(bitwise_bus, SELF_BUS_IDX), }, - step: Sha256StepHelper::new(), + step: Sha2StepHelper::::new(), bitwise_lookup_chip: bitwise_chip.clone(), records: random_records, }; + (chip, bitwise_chip) } -#[test] -fn rand_sha256_test() { +fn rand_sha2_test() { let tester = VmChipTestBuilder::default(); - let (chip, bitwise_chip) = create_chip_with_rand_records(); + let (chip, bitwise_chip) = create_chip_with_rand_records::(); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } #[test] -fn negative_sha256_test_bad_final_hash() { +fn rand_sha256_test() { + rand_sha2_test::(); +} + +#[test] +fn rand_sha512_test() { + rand_sha2_test::(); +} + +#[test] +fn rand_sha384_test() { + rand_sha2_test::(); +} + +fn negative_sha2_test_bad_final_hash() { let tester = VmChipTestBuilder::default(); - let (chip, bitwise_chip) = create_chip_with_rand_records(); + let (chip, bitwise_chip) = create_chip_with_rand_records::(); // Set the final_hash to all zeros let modify_trace = |trace: &mut RowMajorMatrix| { trace.row_chunks_exact_mut(1).for_each(|row| { let mut row_slice = row.row_slice(0).to_vec(); - let cols: &mut Sha256DigestCols = row_slice[..SHA256_DIGEST_WIDTH].borrow_mut(); + let mut cols: ShaDigestColsRefMut = + ShaDigestColsRefMut::from::(&mut row_slice[..C::DIGEST_WIDTH]); if cols.flags.is_last_block.is_one() && cols.flags.is_digest_row.is_one() { - for i in 0..SHA256_HASH_WORDS { - for j in 0..SHA256_WORD_U8S { - cols.final_hash[i][j] = F::ZERO; + for i in 0..C::HASH_WORDS { + for j in 0..C::WORD_U8S { + cols.final_hash[[i, j]] = F::ZERO; } } row.values.copy_from_slice(&row_slice); @@ -152,3 +170,18 @@ fn negative_sha256_test_bad_final_hash() { .finalize(); tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); } + +#[test] +fn negative_sha256_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +#[test] +fn negative_sha512_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +#[test] +fn negative_sha384_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} diff --git a/crates/circuits/sha2-air/src/trace.rs b/crates/circuits/sha2-air/src/trace.rs new file mode 100644 index 0000000000..d2c8e8f8d8 --- /dev/null +++ b/crates/circuits/sha2-air/src/trace.rs @@ -0,0 +1,864 @@ +use std::{marker::PhantomData, ops::Range}; + +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, + utils::next_power_of_two_or_zero, +}; +use openvm_stark_backend::{ + p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, +}; +use sha2::{compress256, compress512, digest::generic_array::GenericArray}; + +use super::{ + big_sig0_field, big_sig1_field, ch_field, compose, get_flag_pt_array, maj_field, + small_sig0_field, small_sig1_field, ShaRoundColsRefMut, +}; +use crate::{ + big_sig0, big_sig1, ch, le_limbs_into_word, maj, small_sig0, small_sig1, word_into_bits, + word_into_u16_limbs, word_into_u8_limbs, Sha2Config, Sha2Variant, ShaDigestColsRefMut, + ShaRoundColsRef, WrappingAdd, +}; + +/// A helper struct for the SHA256 trace generation. +/// Also, separates the inner AIR from the trace generation. +pub struct Sha2StepHelper { + pub row_idx_encoder: Encoder, + _phantom: PhantomData, +} + +impl Default for Sha2StepHelper { + fn default() -> Self { + Self::new() + } +} + +/// The trace generation of SHA should be done in two passes. +/// The first pass should do `get_block_trace` for every block and generate the invalid rows through +/// `get_default_row` The second pass should go through all the blocks and call +/// `generate_missing_cells` +impl Sha2StepHelper { + pub fn new() -> Self { + Self { + // +1 for dummy (padding) rows + row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), + _phantom: PhantomData, + } + } + + /// This function takes the input_message (padding not handled), the previous hash, + /// and returns the new hash after processing the block input + pub fn get_block_hash(prev_hash: &[C::Word], input: Vec) -> Vec { + debug_assert!(prev_hash.len() == C::HASH_WORDS); + debug_assert!(input.len() == C::BLOCK_U8S); + let mut new_hash: [C::Word; 8] = prev_hash.try_into().unwrap(); + match C::VARIANT { + Sha2Variant::Sha256 => { + let input_array = [*GenericArray::::from_slice( + &input, + )]; + let hash_ptr: &mut [u32; 8] = unsafe { std::mem::transmute(&mut new_hash) }; + compress256(hash_ptr, &input_array); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let hash_ptr: &mut [u64; 8] = unsafe { std::mem::transmute(&mut new_hash) }; + let input_array = [*GenericArray::::from_slice( + &input, + )]; + compress512(hash_ptr, &input_array); + } + } + new_hash.to_vec() + } + + /// This function takes a C::BLOCK_BITS-bit chunk of the input message (padding not handled), + /// the previous hash, a flag indicating if it's the last block, the global block index, the + /// local block index, and the buffer values that will be put in rows 0..4. + /// Will populate the given `trace` with the trace of the block, where the width of the trace is + /// `trace_width` and the starting column for the `Sha2Air` is `trace_start_col`. + /// **Note**: this function only generates some of the required trace. Another pass is required, + /// refer to [`Self::generate_missing_cells`] for details. + #[allow(clippy::too_many_arguments)] + pub fn generate_block_trace( + &self, + trace: &mut [F], + trace_width: usize, + trace_start_col: usize, + input: &[C::Word], + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + prev_hash: &[C::Word], + is_last_block: bool, + global_block_idx: u32, + local_block_idx: u32, + ) { + #[cfg(debug_assertions)] + { + assert!(input.len() == C::BLOCK_WORDS); + assert!(prev_hash.len() == C::HASH_WORDS); + assert!(trace.len() == trace_width * C::ROWS_PER_BLOCK); + assert!(trace_start_col + C::WIDTH <= trace_width); + if local_block_idx == 0 { + assert!(*prev_hash == *C::get_h()); + } + } + let get_range = |start: usize, len: usize| -> Range { start..start + len }; + let mut message_schedule = vec![C::Word::from(0); C::ROUNDS_PER_BLOCK]; + message_schedule[..input.len()].copy_from_slice(input); + let mut work_vars = prev_hash.to_vec(); + for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() { + // do the rounds + if i < C::ROUND_ROWS { + let mut cols: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut row[get_range(trace_start_col, C::ROUND_WIDTH)], + ); + *cols.flags.is_round_row = F::ONE; + *cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO }; + *cols.flags.is_digest_row = F::ZERO; + *cols.flags.is_last_block = F::from_bool(is_last_block); + cols.flags + .row_idx + .iter_mut() + .zip( + get_flag_pt_array(&self.row_idx_encoder, i) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + *cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); + *cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); + + // W_idx = M_idx + if i < C::MESSAGE_ROWS { + for j in 0..C::ROUNDS_PER_ROW { + cols.message_schedule + .w + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(input[i * C::ROUNDS_PER_ROW + j]) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } + // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} + else { + for j in 0..C::ROUNDS_PER_ROW { + let idx = i * C::ROUNDS_PER_ROW + j; + let nums: [C::Word; 4] = [ + small_sig1::(message_schedule[idx - 2]), + message_schedule[idx - 7], + small_sig0::(message_schedule[idx - 15]), + message_schedule[idx - 16], + ]; + let w: C::Word = nums + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + cols.message_schedule + .w + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(w) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + let nums_limbs = nums + .iter() + .map(|x| word_into_u16_limbs::(*x)) + .collect::>(); + let w_limbs = word_into_u16_limbs::(w); + + // fill in the carrys + for k in 0..C::WORD_U16S { + let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]); + if k > 0 { + sum += (cols.message_schedule.carry_or_buffer[[j, k * 2 - 2]] + + F::TWO + * cols.message_schedule.carry_or_buffer[[j, k * 2 - 1]]) + .as_canonical_u32(); + } + let carry = (sum - w_limbs[k]) >> 16; + cols.message_schedule.carry_or_buffer[[j, k * 2]] = + F::from_canonical_u32(carry & 1); + cols.message_schedule.carry_or_buffer[[j, k * 2 + 1]] = + F::from_canonical_u32(carry >> 1); + } + // update the message schedule + message_schedule[idx] = w; + } + } + // fill in the work variables + for j in 0..C::ROUNDS_PER_ROW { + // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx + let t1 = [ + work_vars[7], + big_sig1::(work_vars[4]), + ch::(work_vars[4], work_vars[5], work_vars[6]), + C::get_k()[i * C::ROUNDS_PER_ROW + j], + le_limbs_into_word::( + cols.message_schedule + .w + .row(j) + .map(|f| f.as_canonical_u32()) + .as_slice() + .unwrap(), + ), + ]; + let t1_sum: C::Word = t1 + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + + // t2 = SIG0(a) + maj(a, b, c) + let t2 = [ + big_sig0::(work_vars[0]), + maj::(work_vars[0], work_vars[1], work_vars[2]), + ]; + + let t2_sum: C::Word = t2 + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + + // e = d + t1 + let e = work_vars[3].wrapping_add(t1_sum); + cols.work_vars + .e + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(e) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + let e_limbs = word_into_u16_limbs::(e); + // a = t1 + t2 + let a = t1_sum.wrapping_add(t2_sum); + cols.work_vars + .a + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(a) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + let a_limbs = word_into_u16_limbs::(a); + // fill in the carrys + for k in 0..C::WORD_U16S { + let t1_limb = t1 + .iter() + .fold(0, |acc, &num| acc + word_into_u16_limbs::(num)[k]); + let t2_limb = t2 + .iter() + .fold(0, |acc, &num| acc + word_into_u16_limbs::(num)[k]); + + let mut e_limb = t1_limb + word_into_u16_limbs::(work_vars[3])[k]; + let mut a_limb = t1_limb + t2_limb; + if k > 0 { + a_limb += cols.work_vars.carry_a[[j, k - 1]].as_canonical_u32(); + e_limb += cols.work_vars.carry_e[[j, k - 1]].as_canonical_u32(); + } + let carry_a = (a_limb - a_limbs[k]) >> 16; + let carry_e = (e_limb - e_limbs[k]) >> 16; + cols.work_vars.carry_a[[j, k]] = F::from_canonical_u32(carry_a); + cols.work_vars.carry_e[[j, k]] = F::from_canonical_u32(carry_e); + bitwise_lookup_chip.request_range(carry_a, carry_e); + } + + // update working variables + work_vars[7] = work_vars[6]; + work_vars[6] = work_vars[5]; + work_vars[5] = work_vars[4]; + work_vars[4] = e; + work_vars[3] = work_vars[2]; + work_vars[2] = work_vars[1]; + work_vars[1] = work_vars[0]; + work_vars[0] = a; + } + + // filling w_3 and intermed_4 here and the rest later + if i > 0 { + for j in 0..C::ROUNDS_PER_ROW { + let idx = i * C::ROUNDS_PER_ROW + j; + let w_4 = word_into_u16_limbs::(message_schedule[idx - 4]); + let sig_0_w_3 = + word_into_u16_limbs::(small_sig0::(message_schedule[idx - 3])); + cols.schedule_helper + .intermed_4 + .row_mut(j) + .iter_mut() + .zip( + (0..C::WORD_U16S) + .map(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])) + .collect::>(), + ) + .for_each(|(x, y)| *x = y); + if j < C::ROUNDS_PER_ROW - 1 { + let w_3 = message_schedule[idx - 3]; + cols.schedule_helper + .w_3 + .row_mut(j) + .iter_mut() + .zip( + word_into_u16_limbs::(w_3) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } + } + } + // generate the digest row + else { + let mut cols: ShaDigestColsRefMut = ShaDigestColsRefMut::from::( + &mut row[get_range(trace_start_col, C::DIGEST_WIDTH)], + ); + for j in 0..C::ROUNDS_PER_ROW - 1 { + let w_3 = message_schedule[i * C::ROUNDS_PER_ROW + j - 3]; + cols.schedule_helper + .w_3 + .row_mut(j) + .iter_mut() + .zip( + word_into_u16_limbs::(w_3) + .into_iter() + .map(F::from_canonical_u32) + .collect::>(), + ) + .for_each(|(x, y)| *x = y); + } + *cols.flags.is_round_row = F::ZERO; + *cols.flags.is_first_4_rows = F::ZERO; + *cols.flags.is_digest_row = F::ONE; + *cols.flags.is_last_block = F::from_bool(is_last_block); + cols.flags + .row_idx + .iter_mut() + .zip( + get_flag_pt_array(&self.row_idx_encoder, C::ROUND_ROWS) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + *cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); + + *cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); + let final_hash: Vec = (0..C::HASH_WORDS) + .map(|i| work_vars[i].wrapping_add(prev_hash[i])) + .collect(); + let final_hash_limbs: Vec> = final_hash + .iter() + .map(|word| word_into_u8_limbs::(*word)) + .collect(); + // need to ensure final hash limbs are bytes, in order for + // prev_hash[i] + work_vars[i] == final_hash[i] + // to be constrained correctly + for word in final_hash_limbs.iter() { + for chunk in word.chunks(2) { + bitwise_lookup_chip.request_range(chunk[0], chunk[1]); + } + } + cols.final_hash + .iter_mut() + .zip((0..C::HASH_WORDS).flat_map(|i| { + word_into_u8_limbs::(final_hash[i]) + .into_iter() + .map(F::from_canonical_u32) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + cols.prev_hash + .iter_mut() + .zip(prev_hash.iter().flat_map(|f| { + word_into_u16_limbs::(*f) + .into_iter() + .map(F::from_canonical_u32) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + + let hash = if is_last_block { + C::get_h() + .iter() + .map(|x| word_into_bits::(*x)) + .collect::>() + } else { + cols.final_hash + .rows_mut() + .into_iter() + .map(|f| { + le_limbs_into_word::( + f.map(|x| x.as_canonical_u32()).as_slice().unwrap(), + ) + }) + .map(word_into_bits::) + .collect() + } + .into_iter() + .map(|x| x.into_iter().map(F::from_canonical_u32)) + .collect::>(); + + for i in 0..C::ROUNDS_PER_ROW { + cols.hash + .a + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i - 1].clone()) + .for_each(|(x, y)| *x = y); + cols.hash + .e + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i + 3].clone()) + .for_each(|(x, y)| *x = y); + } + } + } + + for i in 0..C::ROWS_PER_BLOCK - 1 { + let rows = &mut trace[i * trace_width..(i + 2) * trace_width]; + let (local, next) = rows.split_at_mut(trace_width); + let mut local_cols: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut local[get_range(trace_start_col, C::ROUND_WIDTH)], + ); + let mut next_cols: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut next[get_range(trace_start_col, C::ROUND_WIDTH)], + ); + if i > 0 { + for j in 0..C::ROUNDS_PER_ROW { + next_cols + .schedule_helper + .intermed_8 + .row_mut(j) + .assign(&local_cols.schedule_helper.intermed_4.row(j)); + if (2..C::ROWS_PER_BLOCK - 3).contains(&i) { + next_cols + .schedule_helper + .intermed_12 + .row_mut(j) + .assign(&local_cols.schedule_helper.intermed_8.row(j)); + } + } + } + if i == C::ROWS_PER_BLOCK - 2 { + // `next` is a digest row. + // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and + // `e` hold. + let const_local_cols = ShaRoundColsRef::::from_mut::(&local_cols); + Self::generate_carry_ae(const_local_cols.clone(), &mut next_cols); + // Fill in row 16's `intermed_4` with dummy values so the message schedule + // constraints holds on that row + Self::generate_intermed_4(const_local_cols, &mut next_cols); + } + if i <= 2 { + // i is in 0..3. + // Fill in `local.intermed_12` with dummy values so the message schedule constraints + // hold on rows 1..4. + Self::generate_intermed_12( + &mut local_cols, + ShaRoundColsRef::::from_mut::(&next_cols), + ); + } + } + } + + /// This function will fill in the cells that we couldn't do during the first pass. + /// This function should be called only after `generate_block_trace` was called for all blocks + /// And [`Self::generate_default_row`] is called for all invalid rows + /// Will populate the missing values of `trace`, where the width of the trace is `trace_width` + /// and the starting column for the `ShaAir` is `trace_start_col`. + /// Note: `trace` needs to be the rows 1..C::ROWS_PER_BLOCK of a block and the first row of the + /// next block + pub fn generate_missing_cells( + &self, + trace: &mut [F], + trace_width: usize, + trace_start_col: usize, + ) { + let rows = &mut trace[(C::ROUND_ROWS - 2) * trace_width..(C::ROUND_ROWS + 1) * trace_width]; + let (last_round_row, rows) = rows.split_at_mut(trace_width); + let (digest_row, next_block_first_row) = rows.split_at_mut(trace_width); + let mut cols_last_round_row: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut last_round_row[trace_start_col..trace_start_col + C::ROUND_WIDTH], + ); + let mut cols_digest_row: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut digest_row[trace_start_col..trace_start_col + C::ROUND_WIDTH], + ); + let mut cols_next_block_first_row: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut next_block_first_row[trace_start_col..trace_start_col + C::ROUND_WIDTH], + ); + // Fill in the last round row's `intermed_12` with dummy values so the message schedule + // constraints holds on row 16 + Self::generate_intermed_12( + &mut cols_last_round_row, + ShaRoundColsRef::from_mut::(&cols_digest_row), + ); + // Fill in the digest row's `intermed_12` with dummy values so the message schedule + // constraints holds on the next block's row 0 + Self::generate_intermed_12( + &mut cols_digest_row, + ShaRoundColsRef::from_mut::(&cols_next_block_first_row), + ); + // Fill in the next block's first row's `intermed_4` with dummy values so the message + // schedule constraints holds on that row + Self::generate_intermed_4( + ShaRoundColsRef::from_mut::(&cols_digest_row), + &mut cols_next_block_first_row, + ); + } + + /// Fills the `cols` as a padding row + /// Note: we still need to correctly fill in the hash values, carries and intermeds + pub fn generate_default_row(&self, mut cols: ShaRoundColsRefMut) { + *cols.flags.is_round_row = F::ZERO; + *cols.flags.is_first_4_rows = F::ZERO; + *cols.flags.is_digest_row = F::ZERO; + + *cols.flags.is_last_block = F::ZERO; + *cols.flags.global_block_idx = F::ZERO; + cols.flags + .row_idx + .iter_mut() + .zip( + get_flag_pt_array(&self.row_idx_encoder, C::ROWS_PER_BLOCK) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + *cols.flags.local_block_idx = F::ZERO; + + cols.message_schedule + .w + .iter_mut() + .for_each(|x| *x = F::ZERO); + cols.message_schedule + .carry_or_buffer + .iter_mut() + .for_each(|x| *x = F::ZERO); + + let hash = C::get_h() + .iter() + .map(|x| word_into_bits::(*x)) + .map(|x| x.into_iter().map(F::from_canonical_u32).collect::>()) + .collect::>(); + + for i in 0..C::ROUNDS_PER_ROW { + cols.work_vars + .a + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i - 1].clone()) + .for_each(|(x, y)| *x = y); + cols.work_vars + .e + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i + 3].clone()) + .for_each(|(x, y)| *x = y); + } + + cols.work_vars + .carry_a + .iter_mut() + .zip((0..C::ROUNDS_PER_ROW).flat_map(|i| { + (0..C::WORD_U16S) + .map(|j| F::from_canonical_u32(C::get_invalid_carry_a(i)[j])) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + cols.work_vars + .carry_e + .iter_mut() + .zip((0..C::ROUNDS_PER_ROW).flat_map(|i| { + (0..C::WORD_U16S) + .map(|j| F::from_canonical_u32(C::get_invalid_carry_e(i)[j])) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + } + + /// The following functions do the calculations in native field since they will be called on + /// padding rows which can overflow and we need to make sure it matches the AIR constraints + /// Puts the correct carries in the `next_row`, the resulting carries can be out of bounds + pub fn generate_carry_ae( + local_cols: ShaRoundColsRef, + next_cols: &mut ShaRoundColsRefMut, + ) { + let a = [ + local_cols + .work_vars + .a + .rows() + .into_iter() + .collect::>(), + next_cols.work_vars.a.rows().into_iter().collect::>(), + ] + .concat(); + let e = [ + local_cols + .work_vars + .e + .rows() + .into_iter() + .collect::>(), + next_cols.work_vars.e.rows().into_iter().collect::>(), + ] + .concat(); + for i in 0..C::ROUNDS_PER_ROW { + let cur_a = a[i + 4]; + let sig_a = big_sig0_field::(a[i + 3].as_slice().unwrap()); + let maj_abc = maj_field::( + a[i + 3].as_slice().unwrap(), + a[i + 2].as_slice().unwrap(), + a[i + 1].as_slice().unwrap(), + ); + let d = a[i]; + let cur_e = e[i + 4]; + let sig_e = big_sig1_field::(e[i + 3].as_slice().unwrap()); + let ch_efg = ch_field::( + e[i + 3].as_slice().unwrap(), + e[i + 2].as_slice().unwrap(), + e[i + 1].as_slice().unwrap(), + ); + let h = e[i]; + + let t1 = [h.to_vec(), sig_e, ch_efg.to_vec()]; + let t2 = [sig_a, maj_abc]; + for j in 0..C::WORD_U16S { + let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| { + acc + compose::(&x[j * 16..(j + 1) * 16], 1) + }); + let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| { + acc + compose::(&x[j * 16..(j + 1) * 16], 1) + }); + let d_limb = compose::(&d.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let cur_a_limb = compose::(&cur_a.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let cur_e_limb = compose::(&cur_e.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let sum = d_limb + + t1_limb_sum + + if j == 0 { + F::ZERO + } else { + next_cols.work_vars.carry_e[[i, j - 1]] + } + - cur_e_limb; + let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse()); + + let sum = t1_limb_sum + + t2_limb_sum + + if j == 0 { + F::ZERO + } else { + next_cols.work_vars.carry_a[[i, j - 1]] + } + - cur_a_limb; + let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse()); + next_cols.work_vars.carry_e[[i, j]] = carry_e; + next_cols.work_vars.carry_a[[i, j]] = carry_a; + } + } + } + + /// Puts the correct intermed_4 in the `next_row` + fn generate_intermed_4( + local_cols: ShaRoundColsRef, + next_cols: &mut ShaRoundColsRefMut, + ) { + let w = [ + local_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + next_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + ] + .concat(); + let w_limbs: Vec> = w + .iter() + .map(|x| { + (0..C::WORD_U16S) + .map(|i| compose::(&x.as_slice().unwrap()[i * 16..(i + 1) * 16], 1)) + .collect::>() + }) + .collect(); + for i in 0..C::ROUNDS_PER_ROW { + let sig_w = small_sig0_field::(w[i + 1].as_slice().unwrap()); + let sig_w_limbs: Vec = (0..C::WORD_U16S) + .map(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)) + .collect(); + for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { + next_cols.schedule_helper.intermed_4[[i, j]] = w_limbs[i][j] + *sig_w_limb; + } + } + } + + /// Puts the needed intermed_12 in the `local_row` + fn generate_intermed_12( + local_cols: &mut ShaRoundColsRefMut, + next_cols: ShaRoundColsRef, + ) { + let w = [ + local_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + next_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + ] + .concat(); + let w_limbs: Vec> = w + .iter() + .map(|x| { + (0..C::WORD_U16S) + .map(|i| compose::(&x.as_slice().unwrap()[i * 16..(i + 1) * 16], 1)) + .collect::>() + }) + .collect(); + for i in 0..C::ROUNDS_PER_ROW { + // sig_1(w_{t-2}) + let sig_w_2: Vec = (0..C::WORD_U16S) + .map(|j| { + compose::( + &small_sig1_field::(w[i + 2].as_slice().unwrap()) + [j * 16..(j + 1) * 16], + 1, + ) + }) + .collect(); + // w_{t-7} + let w_7 = if i < 3 { + local_cols.schedule_helper.w_3.row(i).to_slice().unwrap() + } else { + w_limbs[i - 3].as_slice() + }; + // w_t + let w_cur = w_limbs[i + 4].as_slice(); + for j in 0..C::WORD_U16S { + let carry = next_cols.message_schedule.carry_or_buffer[[i, j * 2]] + + F::TWO * next_cols.message_schedule.carry_or_buffer[[i, j * 2 + 1]]; + let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j] + + if j > 0 { + next_cols.message_schedule.carry_or_buffer[[i, j * 2 - 2]] + + F::from_canonical_u32(2) + * next_cols.message_schedule.carry_or_buffer[[i, j * 2 - 1]] + } else { + F::ZERO + }; + local_cols.schedule_helper.intermed_12[[i, j]] = -sum; + } + } + } +} + +/// `records` consists of pairs of `(input_block, is_last_block)`. +pub fn generate_trace( + step: &Sha2StepHelper, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + width: usize, + records: Vec<(Vec, bool)>, +) -> RowMajorMatrix { + for (input, _) in &records { + debug_assert!(input.len() == C::BLOCK_U8S); + } + + let non_padded_height = records.len() * C::ROWS_PER_BLOCK; + let height = next_power_of_two_or_zero(non_padded_height); + let mut values = F::zero_vec(height * width); + + struct BlockContext { + prev_hash: Vec, // len is C::HASH_WORDS + local_block_idx: u32, + global_block_idx: u32, + input: Vec, // len is C::BLOCK_U8S + is_last_block: bool, + } + let mut block_ctx: Vec> = Vec::with_capacity(records.len()); + let mut prev_hash = C::get_h().to_vec(); + let mut local_block_idx = 0; + let mut global_block_idx = 1; + for (input, is_last_block) in records { + block_ctx.push(BlockContext { + prev_hash: prev_hash.clone(), + local_block_idx, + global_block_idx, + input: input.clone(), + is_last_block, + }); + global_block_idx += 1; + if is_last_block { + local_block_idx = 0; + prev_hash = C::get_h().to_vec(); + } else { + local_block_idx += 1; + prev_hash = Sha2StepHelper::::get_block_hash(&prev_hash, input); + } + } + // first pass + values + .par_chunks_exact_mut(width * C::ROWS_PER_BLOCK) + .zip(block_ctx) + .for_each(|(block, ctx)| { + let BlockContext { + prev_hash, + local_block_idx, + global_block_idx, + input, + is_last_block, + } = ctx; + let input_words = (0..C::BLOCK_WORDS) + .map(|i| { + le_limbs_into_word::( + &(0..C::WORD_U8S) + .map(|j| input[(i + 1) * C::WORD_U8S - j - 1] as u32) + .collect::>(), + ) + }) + .collect::>(); + step.generate_block_trace( + block, + width, + 0, + &input_words, + bitwise_lookup_chip.clone(), + &prev_hash, + is_last_block, + global_block_idx, + local_block_idx, + ); + }); + // second pass: padding rows + values[width * non_padded_height..] + .par_chunks_mut(width) + .for_each(|row| { + let cols: ShaRoundColsRefMut = ShaRoundColsRefMut::from::(row); + step.generate_default_row(cols); + }); + + // second pass: non-padding rows + values[width..] + .par_chunks_mut(width * C::ROWS_PER_BLOCK) + .take(non_padded_height / C::ROWS_PER_BLOCK) + .for_each(|chunk| { + step.generate_missing_cells(chunk, width, 0); + }); + RowMajorMatrix::new(values, width) +} diff --git a/crates/circuits/sha2-air/src/utils.rs b/crates/circuits/sha2-air/src/utils.rs new file mode 100644 index 0000000000..35d4446318 --- /dev/null +++ b/crates/circuits/sha2-air/src/utils.rs @@ -0,0 +1,289 @@ +pub use openvm_circuit_primitives::utils::compose; +use openvm_circuit_primitives::{ + encoder::Encoder, + utils::{not, select}, +}; +use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra}; +use rand::{rngs::StdRng, Rng}; + +use crate::{RotateRight, Sha2Config}; + +/// Convert a word into a list of 8-bit limbs in little endian +pub fn word_into_u8_limbs(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_U8S) +} + +/// Convert a word into a list of 16-bit limbs in little endian +pub fn word_into_u16_limbs(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_U16S) +} + +/// Convert a word into a list of 1-bit limbs in little endian +pub fn word_into_bits(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_BITS) +} + +/// Convert a word into a list of limbs in little endian +pub fn word_into_limbs(num: C::Word, num_limbs: usize) -> Vec { + let limb_bits = std::mem::size_of::() * 8 / num_limbs; + (0..num_limbs) + .map(|i| { + let shifted = num >> (limb_bits * i); + let mask: C::Word = ((1u32 << limb_bits) - 1).into(); + let masked = shifted & mask; + masked.try_into().unwrap() + }) + .collect() +} + +/// Convert a u32 into a list of 1-bit limbs in little endian +pub fn u32_into_bits(num: u32) -> Vec { + let limb_bits = 32 / C::WORD_BITS; + (0..C::WORD_BITS) + .map(|i| (num >> (limb_bits * i)) & ((1 << limb_bits) - 1)) + .collect() +} + +/// Convert a list of limbs in little endian into a Word +pub fn le_limbs_into_word(limbs: &[u32]) -> C::Word { + let mut limbs = limbs.to_vec(); + limbs.reverse(); + be_limbs_into_word::(&limbs) +} + +/// Convert a list of limbs in big endian into a Word +pub fn be_limbs_into_word(limbs: &[u32]) -> C::Word { + let limb_bits = C::WORD_BITS / limbs.len(); + limbs.iter().fold(C::Word::from(0), |acc, &limb| { + (acc << limb_bits) | limb.into() + }) +} + +/// Convert a list of limbs in little endian into a u32 +pub fn limbs_into_u32(limbs: &[u32]) -> u32 { + let limb_bits = 32 / limbs.len(); + limbs + .iter() + .rev() + .fold(0, |acc, &limb| (acc << limb_bits) | limb) +} + +/// Rotates `bits` right by `n` bits, assumes `bits` is in little-endian +#[inline] +pub(crate) fn rotr(bits: &[impl Into + Clone], n: usize) -> Vec { + (0..bits.len()) + .map(|i| bits[(i + n) % bits.len()].clone().into()) + .collect() +} + +/// Shifts `bits` right by `n` bits, assumes `bits` is in little-endian +#[inline] +pub(crate) fn shr(bits: &[impl Into + Clone], n: usize) -> Vec { + (0..bits.len()) + .map(|i| { + if i + n < bits.len() { + bits[i + n].clone().into() + } else { + F::ZERO + } + }) + .collect() +} + +/// Computes x ^ y ^ z, where x, y, z are assumed to be boolean +#[inline] +pub(crate) fn xor_bit( + x: impl Into, + y: impl Into, + z: impl Into, +) -> F { + let (x, y, z) = (x.into(), y.into(), z.into()); + (x.clone() * y.clone() * z.clone()) + + (x.clone() * not::(y.clone()) * not::(z.clone())) + + (not::(x.clone()) * y.clone() * not::(z.clone())) + + (not::(x) * not::(y) * z) +} + +/// Computes x ^ y ^ z, where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn xor( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| xor_bit(x[i].clone(), y[i].clone(), z[i].clone())) + .collect() +} + +/// Choose function from the SHA spec +#[inline] +pub fn ch(x: C::Word, y: C::Word, z: C::Word) -> C::Word { + (x & y) ^ ((!x) & z) +} + +/// Computes Ch(x,y,z), where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn ch_field( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| select(x[i].clone(), y[i].clone(), z[i].clone())) + .collect() +} + +/// Majority function from the SHA spec +pub fn maj(x: C::Word, y: C::Word, z: C::Word) -> C::Word { + (x & y) ^ (x & z) ^ (y & z) +} + +/// Computes Maj(x,y,z), where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn maj_field( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| { + let (x, y, z) = ( + x[i].clone().into(), + y[i].clone().into(), + z[i].clone().into(), + ); + x.clone() * y.clone() + x.clone() * z.clone() + y.clone() * z.clone() + - F::TWO * x * y * z + }) + .collect() +} + +/// Big sigma_0 function from the SHA spec +pub fn big_sig0(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) + } else { + x.rotate_right(28) ^ x.rotate_right(34) ^ x.rotate_right(39) + } +} + +/// Computes BigSigma0(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn big_sig0_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 2), &rotr::(x, 13), &rotr::(x, 22)) + } else { + xor(&rotr::(x, 28), &rotr::(x, 34), &rotr::(x, 39)) + } +} + +/// Big sigma_1 function from the SHA spec +pub fn big_sig1(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) + } else { + x.rotate_right(14) ^ x.rotate_right(18) ^ x.rotate_right(41) + } +} + +/// Computes BigSigma1(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn big_sig1_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 6), &rotr::(x, 11), &rotr::(x, 25)) + } else { + xor(&rotr::(x, 14), &rotr::(x, 18), &rotr::(x, 41)) + } +} + +/// Small sigma_0 function from the SHA spec +pub fn small_sig0(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) + } else { + x.rotate_right(1) ^ x.rotate_right(8) ^ (x >> 7) + } +} + +/// Computes SmallSigma0(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn small_sig0_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 7), &rotr::(x, 18), &shr::(x, 3)) + } else { + xor(&rotr::(x, 1), &rotr::(x, 8), &shr::(x, 7)) + } +} + +/// Small sigma_1 function from the SHA spec +pub fn small_sig1(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) + } else { + x.rotate_right(19) ^ x.rotate_right(61) ^ (x >> 6) + } +} + +/// Computes SmallSigma1(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn small_sig1_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 17), &rotr::(x, 19), &shr::(x, 10)) + } else { + xor(&rotr::(x, 19), &rotr::(x, 61), &shr::(x, 6)) + } +} + +/// Generate a random message of a given length +pub fn get_random_message(rng: &mut StdRng, len: usize) -> Vec { + let mut random_message: Vec = vec![0u8; len]; + rng.fill(&mut random_message[..]); + random_message +} + +/// Wrapper of `get_flag_pt` to get the flag pointer as an array +pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> Vec { + encoder.get_flag_pt(flag_idx) +} + +/// Constrain the addition of [C::WORD_BITS] bit words in 16-bit limbs +/// It takes in the terms some in bits some in 16-bit limbs, +/// the expected sum in bits and the carries +pub fn constraint_word_addition( + builder: &mut AB, + terms_bits: &[&[impl Into + Clone]], + terms_limb: &[&[impl Into + Clone]], + expected_sum: &[impl Into + Clone], + carries: &[impl Into + Clone], +) { + debug_assert!(terms_bits.iter().all(|x| x.len() == C::WORD_BITS)); + debug_assert!(terms_limb.iter().all(|x| x.len() == C::WORD_U16S)); + assert_eq!(expected_sum.len(), C::WORD_BITS); + assert_eq!(carries.len(), C::WORD_U16S); + + for i in 0..C::WORD_U16S { + let mut limb_sum = if i == 0 { + AB::Expr::ZERO + } else { + carries[i - 1].clone().into() + }; + for term in terms_bits { + limb_sum += compose::(&term[i * 16..(i + 1) * 16], 1); + } + for term in terms_limb { + limb_sum += term[i].clone().into(); + } + let expected_sum_limb = compose::(&expected_sum[i * 16..(i + 1) * 16], 1) + + carries[i].clone().into() * AB::Expr::from_canonical_u32(1 << 16); + builder.assert_eq(limb_sum, expected_sum_limb); + } +} diff --git a/crates/circuits/sha256-air/src/air.rs b/crates/circuits/sha256-air/src/air.rs deleted file mode 100644 index b27af6ffa9..0000000000 --- a/crates/circuits/sha256-air/src/air.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::{array, borrow::Borrow, cmp::max, iter::once}; - -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupBus, - encoder::Encoder, - utils::{not, select}, - SubAir, -}; -use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, - p3_air::{AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra}, - p3_matrix::Matrix, -}; - -use super::{ - big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, - small_sig1_field, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH, SHA256_H, - SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, SHA256_WORD_BITS, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; -use crate::{constraint_word_addition, u32_into_u16s}; - -/// Expects the message to be padded to a multiple of 512 bits -#[derive(Clone, Debug)] -pub struct Sha256Air { - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - pub row_idx_encoder: Encoder, - /// Internal bus for self-interactions in this AIR. - bus: PermutationCheckBus, -} - -impl Sha256Air { - pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self { - Self { - bitwise_lookup_bus, - row_idx_encoder: Encoder::new(18, 2, false), - bus: PermutationCheckBus::new(self_bus_idx), - } - } -} - -impl BaseAir for Sha256Air { - fn width(&self) -> usize { - max( - Sha256RoundCols::::width(), - Sha256DigestCols::::width(), - ) - } -} - -impl SubAir for Sha256Air { - /// The start column for the sub-air to use - type AirContext<'a> - = usize - where - Self: 'a, - AB: 'a, - ::Var: 'a, - ::Expr: 'a; - - fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>) - where - ::Var: 'a, - ::Expr: 'a, - { - self.eval_row(builder, start_col); - self.eval_transitions(builder, start_col); - } -} - -impl Sha256Air { - /// Implements the single row constraints (i.e. imposes constraints only on local) - /// Implements some sanity constraints on the row index, flags, and work variables - fn eval_row(&self, builder: &mut AB, start_col: usize) { - let main = builder.main(); - let local = main.row_slice(0); - - // Doesn't matter which column struct we use here as we are only interested in the common - // columns - let local_cols: &Sha256DigestCols = - local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - let flags = &local_cols.flags; - builder.assert_bool(flags.is_round_row); - builder.assert_bool(flags.is_first_4_rows); - builder.assert_bool(flags.is_digest_row); - builder.assert_bool(flags.is_round_row + flags.is_digest_row); - builder.assert_bool(flags.is_last_block); - - self.row_idx_encoder - .eval(builder, &local_cols.flags.row_idx); - builder.assert_one( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=17), - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=3), - flags.is_first_4_rows, - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=15), - flags.is_round_row, - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag::(&local_cols.flags.row_idx, &[16]), - flags.is_digest_row, - ); - // If padding row we want the row_idx to be 17 - builder.assert_eq( - self.row_idx_encoder - .contains_flag::(&local_cols.flags.row_idx, &[17]), - flags.is_padding_row(), - ); - - // Constrain a, e, being composed of bits: we make sure a and e are always in the same place - // in the trace matrix Note: this has to be true for every row, even padding rows - for i in 0..SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_WORD_BITS { - builder.assert_bool(local_cols.hash.a[i][j]); - builder.assert_bool(local_cols.hash.e[i][j]); - } - } - } - - /// Implements constraints for a digest row that ensure proper state transitions between blocks - /// This validates that: - /// The work variables are correctly initialized for the next message block - /// For the last message block, the initial state matches SHA256_H constants - fn eval_digest_row( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256DigestCols, - ) { - // Check that if this is the last row of a message or an inpadding row, the hash should be - // the [SHA256_H] - for i in 0..SHA256_ROUNDS_PER_ROW { - let a = next.hash.a[i].map(|x| x.into()); - let e = next.hash.e[i].map(|x| x.into()); - for j in 0..SHA256_WORD_U16S { - let a_limb = compose::(&a[j * 16..(j + 1) * 16], 1); - let e_limb = compose::(&e[j * 16..(j + 1) * 16], 1); - - // If it is a padding row or the last row of a message, the `hash` should be the - // [SHA256_H] - builder - .when( - next.flags.is_padding_row() - + next.flags.is_last_block * next.flags.is_digest_row, - ) - .assert_eq( - a_limb, - AB::Expr::from_canonical_u32( - u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j], - ), - ); - - builder - .when( - next.flags.is_padding_row() - + next.flags.is_last_block * next.flags.is_digest_row, - ) - .assert_eq( - e_limb, - AB::Expr::from_canonical_u32( - u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j], - ), - ); - } - } - - // Check if last row of a non-last block, the `hash` should be equal to the final hash of - // the current block - for i in 0..SHA256_ROUNDS_PER_ROW { - let prev_a = next.hash.a[i].map(|x| x.into()); - let prev_e = next.hash.e[i].map(|x| x.into()); - let cur_a = next.final_hash[SHA256_ROUNDS_PER_ROW - i - 1].map(|x| x.into()); - - let cur_e = next.final_hash[SHA256_ROUNDS_PER_ROW - i + 3].map(|x| x.into()); - for j in 0..SHA256_WORD_U8S { - let prev_a_limb = compose::(&prev_a[j * 8..(j + 1) * 8], 1); - let prev_e_limb = compose::(&prev_e[j * 8..(j + 1) * 8], 1); - - builder - .when(not(next.flags.is_last_block) * next.flags.is_digest_row) - .assert_eq(prev_a_limb, cur_a[j].clone()); - - builder - .when(not(next.flags.is_last_block) * next.flags.is_digest_row) - .assert_eq(prev_e_limb, cur_e[j].clone()); - } - } - - // Assert that the previous hash + work vars == final hash. - // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]` - // where addition is done modulo 2^32 - for i in 0..SHA256_HASH_WORDS { - let mut carry = AB::Expr::ZERO; - for j in 0..SHA256_WORD_U16S { - let work_var_limb = if i < SHA256_ROUNDS_PER_ROW { - compose::( - &local.work_vars.a[SHA256_ROUNDS_PER_ROW - 1 - i][j * 16..(j + 1) * 16], - 1, - ) - } else { - compose::( - &local.work_vars.e[SHA256_ROUNDS_PER_ROW + 3 - i][j * 16..(j + 1) * 16], - 1, - ) - }; - let final_hash_limb = - compose::(&next.final_hash[i][j * 2..(j + 1) * 2], 8); - - carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse()) - * (next.prev_hash[i][j] + work_var_limb + carry - final_hash_limb); - builder - .when(next.flags.is_digest_row) - .assert_bool(carry.clone()); - } - // constrain the final hash limbs two at a time since we can do two checks per - // interaction - for chunk in next.final_hash[i].chunks(2) { - self.bitwise_lookup_bus - .send_range(chunk[0], chunk[1]) - .eval(builder, next.flags.is_digest_row); - } - } - } - - fn eval_transitions(&self, builder: &mut AB, start_col: usize) { - let main = builder.main(); - let local = main.row_slice(0); - let next = main.row_slice(1); - - // Doesn't matter what column structs we use here - let local_cols: &Sha256RoundCols = - local[start_col..start_col + SHA256_ROUND_WIDTH].borrow(); - let next_cols: &Sha256RoundCols = - next[start_col..start_col + SHA256_ROUND_WIDTH].borrow(); - - let local_is_padding_row = local_cols.flags.is_padding_row(); - // Note that there will always be a padding row in the trace since the unpadded height is a - // multiple of 17. So the next row is padding iff the current block is the last - // block in the trace. - let next_is_padding_row = next_cols.flags.is_padding_row(); - - // We check that the very last block has `is_last_block` set to true, which guarantees that - // there is at least one complete message. If other digest rows have `is_last_block` set to - // true, then the trace will be interpreted as containing multiple messages. - builder - .when(next_is_padding_row.clone()) - .when(local_cols.flags.is_digest_row) - .assert_one(local_cols.flags.is_last_block); - // If we are in a round row, the next row cannot be a padding row - builder - .when(local_cols.flags.is_round_row) - .assert_zero(next_is_padding_row.clone()); - // The first row must be a round row - builder - .when_first_row() - .assert_one(local_cols.flags.is_round_row); - // If we are in a padding row, the next row must also be a padding row - builder - .when_transition() - .when(local_is_padding_row.clone()) - .assert_one(next_is_padding_row.clone()); - // If we are in a digest row, the next row cannot be a digest row - builder - .when(local_cols.flags.is_digest_row) - .assert_zero(next_cols.flags.is_digest_row); - // Constrain how much the row index changes by - // round->round: 1 - // round->digest: 1 - // digest->round: -16 - // digest->padding: 1 - // padding->padding: 0 - // Other transitions are not allowed by the above constraints - let delta = local_cols.flags.is_round_row * AB::Expr::ONE - + local_cols.flags.is_digest_row - * next_cols.flags.is_round_row - * AB::Expr::from_canonical_u32(16) - * AB::Expr::NEG_ONE - + local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE; - - let local_row_idx = self.row_idx_encoder.flag_with_val::( - &local_cols.flags.row_idx, - &(0..18).map(|i| (i, i)).collect::>(), - ); - let next_row_idx = self.row_idx_encoder.flag_with_val::( - &next_cols.flags.row_idx, - &(0..18).map(|i| (i, i)).collect::>(), - ); - - builder - .when_transition() - .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone()); - builder.when_first_row().assert_zero(local_row_idx); - - // Constrain the global block index - // We set the global block index to 0 for padding rows - // Starting with 1 so it is not the same as the padding rows - - // Global block index is 1 on first row - builder - .when_first_row() - .assert_one(local_cols.flags.global_block_idx); - - // Global block index is constant on all rows in a block - builder.when(local_cols.flags.is_round_row).assert_eq( - local_cols.flags.global_block_idx, - next_cols.flags.global_block_idx, - ); - // Global block index increases by 1 between blocks - builder - .when_transition() - .when(local_cols.flags.is_digest_row) - .when(next_cols.flags.is_round_row) - .assert_eq( - local_cols.flags.global_block_idx + AB::Expr::ONE, - next_cols.flags.global_block_idx, - ); - // Global block index is 0 on padding rows - builder - .when(local_is_padding_row.clone()) - .assert_zero(local_cols.flags.global_block_idx); - - // Constrain the local block index - // We set the local block index to 0 for padding rows - - // Local block index is constant on all rows in a block - // and its value on padding rows is equal to its value on the first block - builder.when(not(local_cols.flags.is_digest_row)).assert_eq( - local_cols.flags.local_block_idx, - next_cols.flags.local_block_idx, - ); - // Local block index increases by 1 between blocks in the same message - builder - .when(local_cols.flags.is_digest_row) - .when(not(local_cols.flags.is_last_block)) - .assert_eq( - local_cols.flags.local_block_idx + AB::Expr::ONE, - next_cols.flags.local_block_idx, - ); - // Local block index is 0 on padding rows - // Combined with the above, this means that the local block index is 0 in the first block - builder - .when(local_cols.flags.is_digest_row) - .when(local_cols.flags.is_last_block) - .assert_zero(next_cols.flags.local_block_idx); - - self.eval_message_schedule::(builder, local_cols, next_cols); - self.eval_work_vars::(builder, local_cols, next_cols); - let next_cols: &Sha256DigestCols = - next[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - self.eval_digest_row(builder, local_cols, next_cols); - let local_cols: &Sha256DigestCols = - local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - self.eval_prev_hash::(builder, local_cols, next_is_padding_row); - } - - /// Constrains that the next block's `prev_hash` is equal to the current block's `hash` - /// Note: the constraining is done by interactions with the chip itself on every digest row - fn eval_prev_hash( - &self, - builder: &mut AB, - local: &Sha256DigestCols, - is_last_block_of_trace: AB::Expr, /* note this indicates the last block of the trace, - * not the last block of the message */ - ) { - // Constrain that next block's `prev_hash` is equal to the current block's `hash` - let composed_hash: [[::Expr; SHA256_WORD_U16S]; SHA256_HASH_WORDS] = - array::from_fn(|i| { - let hash_bits = if i < SHA256_ROUNDS_PER_ROW { - local.hash.a[SHA256_ROUNDS_PER_ROW - 1 - i].map(|x| x.into()) - } else { - local.hash.e[SHA256_ROUNDS_PER_ROW + 3 - i].map(|x| x.into()) - }; - array::from_fn(|j| compose::(&hash_bits[j * 16..(j + 1) * 16], 1)) - }); - // Need to handle the case if this is the very last block of the trace matrix - let next_global_block_idx = select( - is_last_block_of_trace, - AB::Expr::ONE, - local.flags.global_block_idx + AB::Expr::ONE, - ); - // The following interactions constrain certain values from block to block - self.bus.send( - builder, - composed_hash - .into_iter() - .flatten() - .chain(once(next_global_block_idx)), - local.flags.is_digest_row, - ); - - self.bus.receive( - builder, - local - .prev_hash - .into_iter() - .flatten() - .map(|x| x.into()) - .chain(once(local.flags.global_block_idx.into())), - local.flags.is_digest_row, - ); - } - - /// Constrain the message schedule additions for `next` row - /// Note: For every addition we need to constrain the following for each of [SHA256_WORD_U16S] - /// limbs sig_1(w_{t-2})[i] + w_{t-7}[i] + sig_0(w_{t-15})[i] + w_{t-16}[i] + - /// carry_w[t][i-1] - carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] - fn eval_message_schedule( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256RoundCols, - ) { - // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx - let w = [local.message_schedule.w, next.message_schedule.w].concat(); - - // Constrain `w_3` for `next` row - for i in 0..SHA256_ROUNDS_PER_ROW - 1 { - // here we constrain the w_3 of the i_th word of the next row - // w_3 of next is w[i+4-3] = w[i+1] - let w_3 = w[i + 1].map(|x| x.into()); - let expected_w_3 = next.schedule_helper.w_3[i]; - for j in 0..SHA256_WORD_U16S { - let w_3_limb = compose::(&w_3[j * 16..(j + 1) * 16], 1); - builder - .when(local.flags.is_round_row) - .assert_eq(w_3_limb, expected_w_3[j].into()); - } - } - - // Constrain intermed for `next` row - // We will only constrain intermed_12 for rows [3, 14], and let it be unconstrained for - // other rows Other rows should put the needed value in intermed_12 to make the - // below summation constraint hold - let is_row_3_14 = self - .row_idx_encoder - .contains_flag_range::(&next.flags.row_idx, 3..=14); - // We will only constrain intermed_8 for rows [2, 13], and let it unconstrained for other - // rows - let is_row_2_13 = self - .row_idx_encoder - .contains_flag_range::(&next.flags.row_idx, 2..=13); - for i in 0..SHA256_ROUNDS_PER_ROW { - // w_idx - let w_idx = w[i].map(|x| x.into()); - // sig_0(w_{idx+1}) - let sig_w = small_sig0_field::(&w[i + 1]); - for j in 0..SHA256_WORD_U16S { - let w_idx_limb = compose::(&w_idx[j * 16..(j + 1) * 16], 1); - let sig_w_limb = compose::(&sig_w[j * 16..(j + 1) * 16], 1); - - // We would like to constrain this only on rows 0..16, but we can't do a conditional - // check because the degree is already 3. So we must fill in - // `intermed_4` with dummy values on rows 0 and 16 to ensure the constraint holds on - // these rows. - builder.when_transition().assert_eq( - next.schedule_helper.intermed_4[i][j], - w_idx_limb + sig_w_limb, - ); - - builder.when(is_row_2_13.clone()).assert_eq( - next.schedule_helper.intermed_8[i][j], - local.schedule_helper.intermed_4[i][j], - ); - - builder.when(is_row_3_14.clone()).assert_eq( - next.schedule_helper.intermed_12[i][j], - local.schedule_helper.intermed_8[i][j], - ); - } - } - - // Constrain the message schedule additions for `next` row - for i in 0..SHA256_ROUNDS_PER_ROW { - // Note, here by w_{t} we mean the i_th word of the `next` row - // w_{t-7} - let w_7 = if i < 3 { - local.schedule_helper.w_3[i].map(|x| x.into()) - } else { - let w_3 = w[i - 3].map(|x| x.into()); - array::from_fn(|j| compose::(&w_3[j * 16..(j + 1) * 16], 1)) - }; - // sig_0(w_{t-15}) + w_{t-16} - let intermed_16 = local.schedule_helper.intermed_12[i].map(|x| x.into()); - - let carries = array::from_fn(|j| { - next.message_schedule.carry_or_buffer[i][j * 2] - + AB::Expr::TWO * next.message_schedule.carry_or_buffer[i][j * 2 + 1] - }); - - // Constrain `W_{idx} = sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}` - // We would like to constrain this only on rows 4..16, but we can't do a conditional - // check because the degree of sum is already 3 So we must fill in - // `intermed_12` with dummy values on rows 0..3 and 15 and 16 to ensure the constraint - // holds on rows 0..4 and 16. Note that the dummy value goes in the previous - // row to make the current row's constraint hold. - constraint_word_addition( - // Note: here we can't do a conditional check because the degree of sum is already - // 3 - &mut builder.when_transition(), - &[&small_sig1_field::(&w[i + 2])], - &[&w_7, &intermed_16], - &w[i + 4], - &carries, - ); - - for j in 0..SHA256_WORD_U16S { - // When on rows 4..16 message schedule carries should be 0 or 1 - let is_row_4_15 = next.flags.is_round_row - next.flags.is_first_4_rows; - builder - .when(is_row_4_15.clone()) - .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2]); - builder - .when(is_row_4_15) - .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2 + 1]); - } - // Constrain w being composed of bits - for j in 0..SHA256_WORD_BITS { - builder - .when(next.flags.is_round_row) - .assert_bool(next.message_schedule.w[i][j]); - } - } - } - - /// Constrain the work vars on `next` row according to the sha256 documentation - /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] - fn eval_work_vars( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256RoundCols, - ) { - let a = [local.work_vars.a, next.work_vars.a].concat(); - let e = [local.work_vars.e, next.work_vars.e].concat(); - for i in 0..SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_WORD_U16S { - // Although we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in - // [0, 2^8) is enough to prevent overflow and ensure the soundness - // of the addition we want to check - self.bitwise_lookup_bus - .send_range(local.work_vars.carry_a[i][j], local.work_vars.carry_e[i][j]) - .eval(builder, local.flags.is_round_row); - } - - let w_limbs = array::from_fn(|j| { - compose::(&next.message_schedule.w[i][j * 16..(j + 1) * 16], 1) - * next.flags.is_round_row - }); - let k_limbs = array::from_fn(|j| { - self.row_idx_encoder.flag_with_val::( - &next.flags.row_idx, - &(0..16) - .map(|rw_idx| { - ( - rw_idx, - u32_into_u16s(SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i])[j] - as usize, - ) - }) - .collect::>(), - ) - }); - - // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)` - // We have to enforce this constraint on all rows since the degree of the constraint is - // already 3. So, we must fill in `carry_a` with dummy values on digest rows - // to ensure the constraint holds. - constraint_word_addition( - builder, - &[ - &e[i].map(|x| x.into()), // previous `h` - &big_sig1_field::(&e[i + 3]), // sig_1 of previous `e` - &ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous - * `e`, `f`, `g` */ - &big_sig0_field::(&a[i + 3]), // sig_0 of previous `a` - &maj_field::(&a[i + 3], &a[i + 2], &a[i + 1]), /* Maj of previous - * a, b, c */ - ], - &[&w_limbs, &k_limbs], // K and W - &a[i + 4], // new `a` - &next.work_vars.carry_a[i], // carries of addition - ); - - // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W` - // We have to enforce this constraint on all rows since the degree of the constraint is - // already 3. So, we must fill in `carry_e` with dummy values on digest rows - // to ensure the constraint holds. - constraint_word_addition( - builder, - &[ - &a[i].map(|x| x.into()), // previous `d` - &e[i].map(|x| x.into()), // previous `h` - &big_sig1_field::(&e[i + 3]), /* sig_1 of previous - * `e` */ - &ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous - * `e`, `f`, `g` */ - ], - &[&w_limbs, &k_limbs], // K and W - &e[i + 4], // new `e` - &next.work_vars.carry_e[i], // carries of addition - ); - } - } -} diff --git a/crates/circuits/sha256-air/src/columns.rs b/crates/circuits/sha256-air/src/columns.rs deleted file mode 100644 index 1c735394c3..0000000000 --- a/crates/circuits/sha256-air/src/columns.rs +++ /dev/null @@ -1,140 +0,0 @@ -//! WARNING: the order of fields in the structs is important, do not change it - -use openvm_circuit_primitives::{utils::not, AlignedBorrow}; -use openvm_stark_backend::p3_field::FieldAlgebra; - -use super::{ - SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROW_VAR_CNT, SHA256_WORD_BITS, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -/// In each SHA256 block: -/// - First 16 rows use Sha256RoundCols -/// - Final row uses Sha256DigestCols -/// -/// Note that for soundness, we require that there is always a padding row after the last digest row -/// in the trace. Right now, this is true because the unpadded height is a multiple of 17, and thus -/// not a power of 2. -/// -/// Sha256RoundCols and Sha256DigestCols share the same first 3 fields: -/// - flags -/// - work_vars/hash (same type, different name) -/// - schedule_helper -/// -/// This design allows for: -/// 1. Common constraints to work on either struct type by accessing these shared fields -/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional -/// constraints -/// -/// Note that the `Sha256WorkVarsCols` field it is used for different purposes in the two structs. -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256RoundCols { - pub flags: Sha256FlagsCols, - /// Stores the current state of the working variables - pub work_vars: Sha256WorkVarsCols, - pub schedule_helper: Sha256MessageHelperCols, - pub message_schedule: Sha256MessageScheduleCols, -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256DigestCols { - pub flags: Sha256FlagsCols, - /// Will serve as previous hash values for the next block. - /// - on non-last blocks, this is the final hash of the current block - /// - on last blocks, this is the initial state constants, SHA256_H. - /// The work variables constraints are applied on all rows, so `carry_a` and `carry_e` - /// must be filled in with dummy values to ensure these constraints hold. - pub hash: Sha256WorkVarsCols, - pub schedule_helper: Sha256MessageHelperCols, - /// The actual final hash values of the given block - /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block - pub final_hash: [[T; SHA256_WORD_U8S]; SHA256_HASH_WORDS], - /// The final hash of the previous block - /// Note: will be constrained using interactions with the chip itself - pub prev_hash: [[T; SHA256_WORD_U16S]; SHA256_HASH_WORDS], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256MessageScheduleCols { - /// The message schedule words as 32-bit integers - /// The first 16 words will be the message data - pub w: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - /// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used - /// freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells as - /// individual bits - pub carry_or_buffer: [[T; SHA256_WORD_U8S]; SHA256_ROUNDS_PER_ROW], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256WorkVarsCols { - /// `a` and `e` after each iteration as 32-bits - pub a: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - pub e: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - /// The carry's used for addition during each iteration when computing `a` and `e` - pub carry_a: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub carry_e: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], -} - -/// These are the columns that are used to help with the message schedule additions -/// Note: these need to be correctly assigned for every row even on padding rows -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256MessageHelperCols { - /// The following are used to move data forward to constrain the message schedule additions - /// The value of `w` (message schedule word) from 3 rounds ago - /// In general, `w_i` means `w` from `i` rounds ago - pub w_3: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW - 1], - /// Here intermediate(i) = w_i + sig_0(w_{i+1}) - /// Intermed_t represents the intermediate t rounds ago - /// This is needed to constrain the message schedule, since we can only constrain on two rows - /// at a time - pub intermed_4: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub intermed_8: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub intermed_12: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256FlagsCols { - /// A flag that indicates if the current row is among the first 16 rows of a block. - pub is_round_row: T, - /// A flag that indicates if the current row is among the first 4 rows of a block. - pub is_first_4_rows: T, - /// A flag that indicates if the current row is the last (17th) row of a block. - pub is_digest_row: T, - // A flag that indicates if the current row is the last block of the message. - // This flag is only used in digest rows. - pub is_last_block: T, - /// We will encode the row index [0..17) using 5 cells - pub row_idx: [T; SHA256_ROW_VAR_CNT], - /// The index of the current block in the trace starting at 1. - /// Set to 0 on padding rows. - pub global_block_idx: T, - /// The index of the current block in the current message starting at 0. - /// Resets after every message. - /// Set to 0 on padding rows. - pub local_block_idx: T, -} - -impl> Sha256FlagsCols { - // This refers to the padding rows that are added to the air to make the trace length a power of - // 2. Not to be confused with the padding added to messages as part of the SHA hash - // function. - pub fn is_not_padding_row(&self) -> O { - self.is_round_row + self.is_digest_row - } - - // This refers to the padding rows that are added to the air to make the trace length a power of - // 2. Not to be confused with the padding added to messages as part of the SHA hash - // function. - pub fn is_padding_row(&self) -> O - where - O: FieldAlgebra, - { - not(self.is_not_padding_row()) - } -} diff --git a/crates/circuits/sha256-air/src/trace.rs b/crates/circuits/sha256-air/src/trace.rs deleted file mode 100644 index 656f8340c1..0000000000 --- a/crates/circuits/sha256-air/src/trace.rs +++ /dev/null @@ -1,558 +0,0 @@ -use std::{array, borrow::BorrowMut, ops::Range}; - -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupChip, encoder::Encoder, - utils::next_power_of_two_or_zero, -}; -use openvm_stark_backend::{ - p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, -}; -use sha2::{compress256, digest::generic_array::GenericArray}; - -use super::{ - big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, get_flag_pt_array, - maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, -}; -use crate::{ - big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1, - u32_into_bits_field, u32_into_u16s, SHA256_BLOCK_U8S, SHA256_H, SHA256_INVALID_CARRY_A, - SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -/// A helper struct for the SHA256 trace generation. -/// Also, separates the inner AIR from the trace generation. -pub struct Sha256StepHelper { - pub row_idx_encoder: Encoder, -} - -impl Default for Sha256StepHelper { - fn default() -> Self { - Self::new() - } -} - -/// The trace generation of SHA256 should be done in two passes. -/// The first pass should do `get_block_trace` for every block and generate the invalid rows through -/// `get_default_row` The second pass should go through all the blocks and call -/// `generate_missing_cells` -impl Sha256StepHelper { - pub fn new() -> Self { - Self { - row_idx_encoder: Encoder::new(18, 2, false), - } - } - /// This function takes the input_message (padding not handled), the previous hash, - /// and returns the new hash after processing the block input - pub fn get_block_hash( - prev_hash: &[u32; SHA256_HASH_WORDS], - input: [u8; SHA256_BLOCK_U8S], - ) -> [u32; SHA256_HASH_WORDS] { - let mut new_hash = *prev_hash; - let input_array = [GenericArray::from(input)]; - compress256(&mut new_hash, &input_array); - new_hash - } - - /// This function takes a 512-bit chunk of the input message (padding not handled), the previous - /// hash, a flag indicating if it's the last block, the global block index, the local block - /// index, and the buffer values that will be put in rows 0..4. - /// Will populate the given `trace` with the trace of the block, where the width of the trace is - /// `trace_width` and the starting column for the `Sha256Air` is `trace_start_col`. - /// **Note**: this function only generates some of the required trace. Another pass is required, - /// refer to [`Self::generate_missing_cells`] for details. - #[allow(clippy::too_many_arguments)] - pub fn generate_block_trace( - &self, - trace: &mut [F], - trace_width: usize, - trace_start_col: usize, - input: &[u32; SHA256_BLOCK_WORDS], - bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, - prev_hash: &[u32; SHA256_HASH_WORDS], - is_last_block: bool, - global_block_idx: u32, - local_block_idx: u32, - ) { - #[cfg(debug_assertions)] - { - assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK); - assert!(trace_start_col + super::SHA256_WIDTH <= trace_width); - if local_block_idx == 0 { - assert!(*prev_hash == SHA256_H); - } - } - let get_range = |start: usize, len: usize| -> Range { start..start + len }; - let mut message_schedule = [0u32; 64]; - message_schedule[..input.len()].copy_from_slice(input); - let mut work_vars = *prev_hash; - for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() { - // doing the 64 rounds in 16 rows - if i < 16 { - let cols: &mut Sha256RoundCols = - row[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - cols.flags.is_round_row = F::ONE; - cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO }; - cols.flags.is_digest_row = F::ZERO; - cols.flags.is_last_block = F::from_bool(is_last_block); - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, i).map(F::from_canonical_u32); - cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); - cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); - - // W_idx = M_idx - if i < 4 { - for j in 0..SHA256_ROUNDS_PER_ROW { - cols.message_schedule.w[j] = - u32_into_bits_field::(input[i * SHA256_ROUNDS_PER_ROW + j]); - } - } - // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} - else { - for j in 0..SHA256_ROUNDS_PER_ROW { - let idx = i * SHA256_ROUNDS_PER_ROW + j; - let nums: [u32; 4] = [ - small_sig1(message_schedule[idx - 2]), - message_schedule[idx - 7], - small_sig0(message_schedule[idx - 15]), - message_schedule[idx - 16], - ]; - let w: u32 = nums.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - cols.message_schedule.w[j] = u32_into_bits_field::(w); - - let nums_limbs = nums.map(u32_into_u16s); - let w_limbs = u32_into_u16s(w); - - // fill in the carrys - for k in 0..SHA256_WORD_U16S { - let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]); - if k > 0 { - sum += (cols.message_schedule.carry_or_buffer[j][k * 2 - 2] - + F::TWO * cols.message_schedule.carry_or_buffer[j][k * 2 - 1]) - .as_canonical_u32(); - } - let carry = (sum - w_limbs[k]) >> 16; - cols.message_schedule.carry_or_buffer[j][k * 2] = - F::from_canonical_u32(carry & 1); - cols.message_schedule.carry_or_buffer[j][k * 2 + 1] = - F::from_canonical_u32(carry >> 1); - } - // update the message schedule - message_schedule[idx] = w; - } - } - // fill in the work variables - for j in 0..SHA256_ROUNDS_PER_ROW { - // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx - let t1 = [ - work_vars[7], - big_sig1(work_vars[4]), - ch(work_vars[4], work_vars[5], work_vars[6]), - SHA256_K[i * SHA256_ROUNDS_PER_ROW + j], - limbs_into_u32(cols.message_schedule.w[j].map(|f| f.as_canonical_u32())), - ]; - let t1_sum: u32 = t1.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - - // t2 = SIG0(a) + maj(a, b, c) - let t2 = [ - big_sig0(work_vars[0]), - maj(work_vars[0], work_vars[1], work_vars[2]), - ]; - - let t2_sum: u32 = t2.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - - // e = d + t1 - let e = work_vars[3].wrapping_add(t1_sum); - cols.work_vars.e[j] = u32_into_bits_field::(e); - let e_limbs = u32_into_u16s(e); - // a = t1 + t2 - let a = t1_sum.wrapping_add(t2_sum); - cols.work_vars.a[j] = u32_into_bits_field::(a); - let a_limbs = u32_into_u16s(a); - // fill in the carrys - for k in 0..SHA256_WORD_U16S { - let t1_limb = t1.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); - let t2_limb = t2.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); - - let mut e_limb = t1_limb + u32_into_u16s(work_vars[3])[k]; - let mut a_limb = t1_limb + t2_limb; - if k > 0 { - a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32(); - e_limb += cols.work_vars.carry_e[j][k - 1].as_canonical_u32(); - } - let carry_a = (a_limb - a_limbs[k]) >> 16; - let carry_e = (e_limb - e_limbs[k]) >> 16; - cols.work_vars.carry_a[j][k] = F::from_canonical_u32(carry_a); - cols.work_vars.carry_e[j][k] = F::from_canonical_u32(carry_e); - bitwise_lookup_chip.request_range(carry_a, carry_e); - } - - // update working variables - work_vars[7] = work_vars[6]; - work_vars[6] = work_vars[5]; - work_vars[5] = work_vars[4]; - work_vars[4] = e; - work_vars[3] = work_vars[2]; - work_vars[2] = work_vars[1]; - work_vars[1] = work_vars[0]; - work_vars[0] = a; - } - - // filling w_3 and intermed_4 here and the rest later - if i > 0 { - for j in 0..SHA256_ROUNDS_PER_ROW { - let idx = i * SHA256_ROUNDS_PER_ROW + j; - let w_4 = u32_into_u16s(message_schedule[idx - 4]); - let sig_0_w_3 = u32_into_u16s(small_sig0(message_schedule[idx - 3])); - cols.schedule_helper.intermed_4[j] = - array::from_fn(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])); - if j < SHA256_ROUNDS_PER_ROW - 1 { - let w_3 = message_schedule[idx - 3]; - cols.schedule_helper.w_3[j] = - u32_into_u16s(w_3).map(F::from_canonical_u32); - } - } - } - } - // generate the digest row - else { - let cols: &mut Sha256DigestCols = - row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut(); - for j in 0..SHA256_ROUNDS_PER_ROW - 1 { - let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3]; - cols.schedule_helper.w_3[j] = u32_into_u16s(w_3).map(F::from_canonical_u32); - } - cols.flags.is_round_row = F::ZERO; - cols.flags.is_first_4_rows = F::ZERO; - cols.flags.is_digest_row = F::ONE; - cols.flags.is_last_block = F::from_bool(is_last_block); - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, 16).map(F::from_canonical_u32); - cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); - - cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); - let final_hash: [u32; SHA256_HASH_WORDS] = - array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i])); - let final_hash_limbs: [[u8; SHA256_WORD_U8S]; SHA256_HASH_WORDS] = - array::from_fn(|i| final_hash[i].to_le_bytes()); - // need to ensure final hash limbs are bytes, in order for - // prev_hash[i] + work_vars[i] == final_hash[i] - // to be constrained correctly - for word in final_hash_limbs.iter() { - for chunk in word.chunks(2) { - bitwise_lookup_chip.request_range(chunk[0] as u32, chunk[1] as u32); - } - } - cols.final_hash = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u8(final_hash_limbs[i][j])) - }); - cols.prev_hash = prev_hash.map(|f| u32_into_u16s(f).map(F::from_canonical_u32)); - let hash = if is_last_block { - SHA256_H.map(u32_into_bits_field::) - } else { - cols.final_hash - .map(|f| u32::from_le_bytes(f.map(|x| x.as_canonical_u32() as u8))) - .map(u32_into_bits_field::) - }; - - for i in 0..SHA256_ROUNDS_PER_ROW { - cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; - cols.hash.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3]; - } - } - } - - for i in 0..SHA256_ROWS_PER_BLOCK - 1 { - let rows = &mut trace[i * trace_width..(i + 2) * trace_width]; - let (local, next) = rows.split_at_mut(trace_width); - let local_cols: &mut Sha256RoundCols = - local[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - let next_cols: &mut Sha256RoundCols = - next[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - if i > 0 { - for j in 0..SHA256_ROUNDS_PER_ROW { - next_cols.schedule_helper.intermed_8[j] = - local_cols.schedule_helper.intermed_4[j]; - if (2..SHA256_ROWS_PER_BLOCK - 3).contains(&i) { - next_cols.schedule_helper.intermed_12[j] = - local_cols.schedule_helper.intermed_8[j]; - } - } - } - if i == SHA256_ROWS_PER_BLOCK - 2 { - // `next` is a digest row. - // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and - // `e` hold. - Self::generate_carry_ae(local_cols, next_cols); - // Fill in row 16's `intermed_4` with dummy values so the message schedule - // constraints holds on that row - Self::generate_intermed_4(local_cols, next_cols); - } - if i <= 2 { - // i is in 0..3. - // Fill in `local.intermed_12` with dummy values so the message schedule constraints - // hold on rows 1..4. - Self::generate_intermed_12(local_cols, next_cols); - } - } - } - - /// This function will fill in the cells that we couldn't do during the first pass. - /// This function should be called only after `generate_block_trace` was called for all blocks - /// And [`Self::generate_default_row`] is called for all invalid rows - /// Will populate the missing values of `trace`, where the width of the trace is `trace_width` - /// and the starting column for the `Sha256Air` is `trace_start_col`. - /// Note: `trace` needs to be the rows 1..17 of a block and the first row of the next block - pub fn generate_missing_cells( - &self, - trace: &mut [F], - trace_width: usize, - trace_start_col: usize, - ) { - // Here row_17 = next blocks row 0 - let rows_15_17 = &mut trace[14 * trace_width..17 * trace_width]; - let (row_15, row_16_17) = rows_15_17.split_at_mut(trace_width); - let (row_16, row_17) = row_16_17.split_at_mut(trace_width); - let cols_15: &mut Sha256RoundCols = - row_15[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - let cols_16: &mut Sha256RoundCols = - row_16[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - let cols_17: &mut Sha256RoundCols = - row_17[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - // Fill in row 15's `intermed_12` with dummy values so the message schedule constraints - // holds on row 16 - Self::generate_intermed_12(cols_15, cols_16); - // Fill in row 16's `intermed_12` with dummy values so the message schedule constraints - // holds on the next block's row 0 - Self::generate_intermed_12(cols_16, cols_17); - // Fill in row 0's `intermed_4` with dummy values so the message schedule constraints holds - // on that row - Self::generate_intermed_4(cols_16, cols_17); - } - - /// Fills the `cols` as a padding row - /// Note: we still need to correctly fill in the hash values, carries and intermeds - pub fn generate_default_row( - self: &Sha256StepHelper, - cols: &mut Sha256RoundCols, - ) { - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, 17).map(F::from_canonical_u32); - - let hash = SHA256_H.map(u32_into_bits_field::); - - for i in 0..SHA256_ROUNDS_PER_ROW { - cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; - cols.work_vars.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3]; - } - - cols.work_vars.carry_a = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_A[i][j])) - }); - cols.work_vars.carry_e = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_E[i][j])) - }); - } - - /// The following functions do the calculations in native field since they will be called on - /// padding rows which can overflow and we need to make sure it matches the AIR constraints - /// Puts the correct carrys in the `next_row`, the resulting carrys can be out of bound - fn generate_carry_ae( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, - ) { - let a = [local_cols.work_vars.a, next_cols.work_vars.a].concat(); - let e = [local_cols.work_vars.e, next_cols.work_vars.e].concat(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let cur_a = a[i + 4]; - let sig_a = big_sig0_field::(&a[i + 3]); - let maj_abc = maj_field::(&a[i + 3], &a[i + 2], &a[i + 1]); - let d = a[i]; - let cur_e = e[i + 4]; - let sig_e = big_sig1_field::(&e[i + 3]); - let ch_efg = ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]); - let h = e[i]; - - let t1 = [h, sig_e, ch_efg]; - let t2 = [sig_a, maj_abc]; - for j in 0..SHA256_WORD_U16S { - let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| { - acc + compose::(&x[j * 16..(j + 1) * 16], 1) - }); - let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| { - acc + compose::(&x[j * 16..(j + 1) * 16], 1) - }); - let d_limb = compose::(&d[j * 16..(j + 1) * 16], 1); - let cur_a_limb = compose::(&cur_a[j * 16..(j + 1) * 16], 1); - let cur_e_limb = compose::(&cur_e[j * 16..(j + 1) * 16], 1); - let sum = d_limb - + t1_limb_sum - + if j == 0 { - F::ZERO - } else { - next_cols.work_vars.carry_e[i][j - 1] - } - - cur_e_limb; - let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse()); - - let sum = t1_limb_sum - + t2_limb_sum - + if j == 0 { - F::ZERO - } else { - next_cols.work_vars.carry_a[i][j - 1] - } - - cur_a_limb; - let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse()); - next_cols.work_vars.carry_e[i][j] = carry_e; - next_cols.work_vars.carry_a[i][j] = carry_a; - } - } - } - - /// Puts the correct intermed_4 in the `next_row` - fn generate_intermed_4( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, - ) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let sig_w = small_sig0_field::(&w[i + 1]); - let sig_w_limbs: [F; SHA256_WORD_U16S] = - array::from_fn(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)); - for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { - next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb; - } - } - } - - /// Puts the needed intermed_12 in the `local_row` - fn generate_intermed_12( - local_cols: &mut Sha256RoundCols, - next_cols: &Sha256RoundCols, - ) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - // sig_1(w_{t-2}) - let sig_w_2: [F; SHA256_WORD_U16S] = array::from_fn(|j| { - compose::(&small_sig1_field::(&w[i + 2])[j * 16..(j + 1) * 16], 1) - }); - // w_{t-7} - let w_7 = if i < 3 { - local_cols.schedule_helper.w_3[i] - } else { - w_limbs[i - 3] - }; - // w_t - let w_cur = w_limbs[i + 4]; - for j in 0..SHA256_WORD_U16S { - let carry = next_cols.message_schedule.carry_or_buffer[i][j * 2] - + F::TWO * next_cols.message_schedule.carry_or_buffer[i][j * 2 + 1]; - let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j] - + if j > 0 { - next_cols.message_schedule.carry_or_buffer[i][j * 2 - 2] - + F::from_canonical_u32(2) - * next_cols.message_schedule.carry_or_buffer[i][j * 2 - 1] - } else { - F::ZERO - }; - local_cols.schedule_helper.intermed_12[i][j] = -sum; - } - } - } -} - -/// Generates a trace for a standalone SHA256 computation (currently only used for testing) -/// `records` consists of pairs of `(input_block, is_last_block)`. -pub fn generate_trace( - step: &Sha256StepHelper, - bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, - width: usize, - records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, -) -> RowMajorMatrix { - let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK; - let height = next_power_of_two_or_zero(non_padded_height); - let mut values = F::zero_vec(height * width); - - struct BlockContext { - prev_hash: [u32; 8], - local_block_idx: u32, - global_block_idx: u32, - input: [u8; SHA256_BLOCK_U8S], - is_last_block: bool, - } - let mut block_ctx: Vec = Vec::with_capacity(records.len()); - let mut prev_hash = SHA256_H; - let mut local_block_idx = 0; - let mut global_block_idx = 1; - for (input, is_last_block) in records { - block_ctx.push(BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - }); - global_block_idx += 1; - if is_last_block { - local_block_idx = 0; - prev_hash = SHA256_H; - } else { - local_block_idx += 1; - prev_hash = Sha256StepHelper::get_block_hash(&prev_hash, input); - } - } - // first pass - values - .par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK) - .zip(block_ctx) - .for_each(|(block, ctx)| { - let BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - } = ctx; - let input_words = array::from_fn(|i| { - limbs_into_u32::(array::from_fn(|j| { - input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32 - })) - }); - step.generate_block_trace( - block, - width, - 0, - &input_words, - bitwise_lookup_chip, - &prev_hash, - is_last_block, - global_block_idx, - local_block_idx, - ); - }); - // second pass: padding rows - values[width * non_padded_height..] - .par_chunks_mut(width) - .for_each(|row| { - let cols: &mut Sha256RoundCols = row.borrow_mut(); - step.generate_default_row(cols); - }); - // second pass: non-padding rows - values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - step.generate_missing_cells(chunk, width, 0); - }); - RowMajorMatrix::new(values, width) -} diff --git a/crates/circuits/sha256-air/src/utils.rs b/crates/circuits/sha256-air/src/utils.rs deleted file mode 100644 index ba598f2604..0000000000 --- a/crates/circuits/sha256-air/src/utils.rs +++ /dev/null @@ -1,271 +0,0 @@ -use std::array; - -pub use openvm_circuit_primitives::utils::compose; -use openvm_circuit_primitives::{ - encoder::Encoder, - utils::{not, select}, -}; -use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra}; - -use super::{Sha256DigestCols, Sha256RoundCols}; - -// ==== Do not change these constants! ==== -/// Number of bits in a SHA256 word -pub const SHA256_WORD_BITS: usize = 32; -/// Number of 16-bit limbs in a SHA256 word -pub const SHA256_WORD_U16S: usize = SHA256_WORD_BITS / 16; -/// Number of 8-bit limbs in a SHA256 word -pub const SHA256_WORD_U8S: usize = SHA256_WORD_BITS / 8; -/// Number of words in a SHA256 block -pub const SHA256_BLOCK_WORDS: usize = 16; -/// Number of cells in a SHA256 block -pub const SHA256_BLOCK_U8S: usize = SHA256_BLOCK_WORDS * SHA256_WORD_U8S; -/// Number of bits in a SHA256 block -pub const SHA256_BLOCK_BITS: usize = SHA256_BLOCK_WORDS * SHA256_WORD_BITS; -/// Number of rows per block -pub const SHA256_ROWS_PER_BLOCK: usize = 17; -/// Number of rounds per row -pub const SHA256_ROUNDS_PER_ROW: usize = 4; -/// Number of words in a SHA256 hash -pub const SHA256_HASH_WORDS: usize = 8; -/// Number of vars needed to encode the row index with [Encoder] -pub const SHA256_ROW_VAR_CNT: usize = 5; -/// Width of the Sha256RoundCols -pub const SHA256_ROUND_WIDTH: usize = Sha256RoundCols::::width(); -/// Width of the Sha256DigestCols -pub const SHA256_DIGEST_WIDTH: usize = Sha256DigestCols::::width(); -/// Size of the buffer of the first 4 rows of a block (each row's size) -pub const SHA256_BUFFER_SIZE: usize = SHA256_ROUNDS_PER_ROW * SHA256_WORD_U16S * 2; -/// Width of the Sha256Cols -pub const SHA256_WIDTH: usize = if SHA256_ROUND_WIDTH > SHA256_DIGEST_WIDTH { - SHA256_ROUND_WIDTH -} else { - SHA256_DIGEST_WIDTH -}; -/// We can notice that `carry_a`'s and `carry_e`'s are always the same on invalid rows -/// To optimize the trace generation of invalid rows, we have those values precomputed here -pub(crate) const SHA256_INVALID_CARRY_A: [[u32; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW] = [ - [1230919683, 1162494304], - [266373122, 1282901987], - [1519718403, 1008990871], - [923381762, 330807052], -]; -pub(crate) const SHA256_INVALID_CARRY_E: [[u32; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW] = [ - [204933122, 1994683449], - [443873282, 1544639095], - [719953922, 1888246508], - [194580482, 1075725211], -]; -/// SHA256 constant K's -pub const SHA256_K: [u32; 64] = [ - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, - 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, - 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, - 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, - 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, -]; - -/// SHA256 initial hash values -pub const SHA256_H: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, -]; - -/// Returns the number of blocks required to hash a message of length `len` -pub fn get_sha256_num_blocks(len: u32) -> u32 { - // need to pad with one 1 bit, 64 bits for the message length and then pad until the length - // is divisible by [SHA256_BLOCK_BITS] - ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS) as u32 -} - -/// Convert a u32 into a list of bits in little endian then convert each bit into a field element -pub fn u32_into_bits_field(num: u32) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| F::from_bool((num >> i) & 1 == 1)) -} - -/// Convert a u32 into a an array of 2 16-bit limbs in little endian -pub fn u32_into_u16s(num: u32) -> [u32; 2] { - [num & 0xffff, num >> 16] -} - -/// Convert a list of limbs in little endian into a u32 -pub fn limbs_into_u32(limbs: [u32; NUM_LIMBS]) -> u32 { - let limb_bits = 32 / NUM_LIMBS; - limbs - .iter() - .rev() - .fold(0, |acc, &limb| (acc << limb_bits) | limb) -} - -/// Rotates `bits` right by `n` bits, assumes `bits` is in little-endian -#[inline] -pub(crate) fn rotr( - bits: &[impl Into + Clone; SHA256_WORD_BITS], - n: usize, -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| bits[(i + n) % SHA256_WORD_BITS].clone().into()) -} - -/// Shifts `bits` right by `n` bits, assumes `bits` is in little-endian -#[inline] -pub(crate) fn shr( - bits: &[impl Into + Clone; SHA256_WORD_BITS], - n: usize, -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| { - if i + n < SHA256_WORD_BITS { - bits[i + n].clone().into() - } else { - F::ZERO - } - }) -} - -/// Computes x ^ y ^ z, where x, y, z are assumed to be boolean -#[inline] -pub(crate) fn xor_bit( - x: impl Into, - y: impl Into, - z: impl Into, -) -> F { - let (x, y, z) = (x.into(), y.into(), z.into()); - (x.clone() * y.clone() * z.clone()) - + (x.clone() * not::(y.clone()) * not::(z.clone())) - + (not::(x.clone()) * y.clone() * not::(z.clone())) - + (not::(x) * not::(y) * z) -} - -/// Computes x ^ y ^ z, where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn xor( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| xor_bit(x[i].clone(), y[i].clone(), z[i].clone())) -} - -/// Choose function from SHA256 -#[inline] -pub fn ch(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ ((!x) & z) -} - -/// Computes Ch(x,y,z), where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn ch_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| select(x[i].clone(), y[i].clone(), z[i].clone())) -} - -/// Majority function from SHA256 -pub fn maj(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (x & z) ^ (y & z) -} - -/// Computes Maj(x,y,z), where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn maj_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| { - let (x, y, z) = ( - x[i].clone().into(), - y[i].clone().into(), - z[i].clone().into(), - ); - x.clone() * y.clone() + x.clone() * z.clone() + y.clone() * z.clone() - F::TWO * x * y * z - }) -} - -/// Big sigma_0 function from SHA256 -pub fn big_sig0(x: u32) -> u32 { - x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) -} - -/// Computes BigSigma0(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn big_sig0_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 2), &rotr::(x, 13), &rotr::(x, 22)) -} - -/// Big sigma_1 function from SHA256 -pub fn big_sig1(x: u32) -> u32 { - x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) -} - -/// Computes BigSigma1(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn big_sig1_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 6), &rotr::(x, 11), &rotr::(x, 25)) -} - -/// Small sigma_0 function from SHA256 -pub fn small_sig0(x: u32) -> u32 { - x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) -} - -/// Computes SmallSigma0(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn small_sig0_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 7), &rotr::(x, 18), &shr::(x, 3)) -} - -/// Small sigma_1 function from SHA256 -pub fn small_sig1(x: u32) -> u32 { - x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) -} - -/// Computes SmallSigma1(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn small_sig1_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 17), &rotr::(x, 19), &shr::(x, 10)) -} - -/// Wrapper of `get_flag_pt` to get the flag pointer as an array -pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> [u32; N] { - encoder.get_flag_pt(flag_idx).try_into().unwrap() -} - -/// Constrain the addition of [SHA256_WORD_BITS] bit words in 16-bit limbs -/// It takes in the terms some in bits some in 16-bit limbs, -/// the expected sum in bits and the carries -pub fn constraint_word_addition( - builder: &mut AB, - terms_bits: &[&[impl Into + Clone; SHA256_WORD_BITS]], - terms_limb: &[&[impl Into + Clone; SHA256_WORD_U16S]], - expected_sum: &[impl Into + Clone; SHA256_WORD_BITS], - carries: &[impl Into + Clone; SHA256_WORD_U16S], -) { - for i in 0..SHA256_WORD_U16S { - let mut limb_sum = if i == 0 { - AB::Expr::ZERO - } else { - carries[i - 1].clone().into() - }; - for term in terms_bits { - limb_sum += compose::(&term[i * 16..(i + 1) * 16], 1); - } - for term in terms_limb { - limb_sum += term[i].clone().into(); - } - let expected_sum_limb = compose::(&expected_sum[i * 16..(i + 1) * 16], 1) - + carries[i].clone().into() * AB::Expr::from_canonical_u32(1 << 16); - builder.assert_eq(limb_sum, expected_sum_limb); - } -} diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index f42223dd4e..cf5b46d930 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -18,8 +18,8 @@ openvm-ecc-circuit = { workspace = true } openvm-ecc-transpiler = { workspace = true } openvm-keccak256-circuit = { workspace = true } openvm-keccak256-transpiler = { workspace = true } -openvm-sha256-circuit = { workspace = true } -openvm-sha256-transpiler = { workspace = true } +openvm-sha2-circuit = { workspace = true } +openvm-sha2-transpiler = { workspace = true } openvm-pairing-circuit = { workspace = true } openvm-pairing-transpiler = { workspace = true } openvm-native-circuit = { workspace = true } diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index 0da07c63c4..fd6655f2bb 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -15,9 +15,7 @@ use openvm_circuit::{ circuit_derive::{Chip, ChipUsageGetter}, derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}, }; -use openvm_ecc_circuit::{ - WeierstrassExtension, WeierstrassExtensionExecutor, WeierstrassExtensionPeriphery, -}; +use openvm_ecc_circuit::{EccExtension, EccExtensionExecutor, EccExtensionPeriphery}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_keccak256_circuit::{Keccak256, Keccak256Executor, Keccak256Periphery}; use openvm_keccak256_transpiler::Keccak256TranspilerExtension; @@ -37,8 +35,8 @@ use openvm_rv32im_circuit::{ use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; -use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; +use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_backend::p3_field::PrimeField32; use openvm_transpiler::transpiler::Transpiler; use serde::{Deserialize, Serialize}; @@ -53,7 +51,7 @@ pub struct SdkVmConfig { pub rv32i: Option, pub io: Option, pub keccak: Option, - pub sha256: Option, + pub sha2: Option, pub native: Option, pub castf: Option, @@ -62,7 +60,7 @@ pub struct SdkVmConfig { pub modular: Option, pub fp2: Option, pub pairing: Option, - pub ecc: Option, + pub ecc: Option, } #[derive( @@ -78,7 +76,7 @@ pub enum SdkVmConfigExecutor { #[any_enum] Keccak(Keccak256Executor), #[any_enum] - Sha256(Sha256Executor), + Sha2(Sha2Executor), #[any_enum] Native(NativeExecutor), #[any_enum] @@ -92,7 +90,7 @@ pub enum SdkVmConfigExecutor { #[any_enum] Pairing(PairingExtensionExecutor), #[any_enum] - Ecc(WeierstrassExtensionExecutor), + Ecc(EccExtensionExecutor), #[any_enum] CastF(CastFExtensionExecutor), } @@ -108,7 +106,7 @@ pub enum SdkVmConfigPeriphery { #[any_enum] Keccak(Keccak256Periphery), #[any_enum] - Sha256(Sha256Periphery), + Sha2(Sha2Periphery), #[any_enum] Native(NativePeriphery), #[any_enum] @@ -122,7 +120,7 @@ pub enum SdkVmConfigPeriphery { #[any_enum] Pairing(PairingExtensionPeriphery), #[any_enum] - Ecc(WeierstrassExtensionPeriphery), + Ecc(EccExtensionPeriphery), #[any_enum] CastF(CastFExtensionPeriphery), } @@ -139,8 +137,8 @@ impl SdkVmConfig { if self.keccak.is_some() { transpiler = transpiler.with_extension(Keccak256TranspilerExtension); } - if self.sha256.is_some() { - transpiler = transpiler.with_extension(Sha256TranspilerExtension); + if self.sha2.is_some() { + transpiler = transpiler.with_extension(Sha2TranspilerExtension); } if self.native.is_some() { transpiler = transpiler.with_extension(LongFormTranspilerExtension); @@ -193,8 +191,8 @@ impl VmConfig for SdkVmConfig { if self.keccak.is_some() { complex = complex.extend(&Keccak256)?; } - if self.sha256.is_some() { - complex = complex.extend(&Sha256)?; + if self.sha2.is_some() { + complex = complex.extend(&Sha2)?; } if self.native.is_some() { complex = complex.extend(&Native)?; @@ -264,7 +262,7 @@ impl InitFileGenerator for SdkVmConfig { } if let Some(ecc_config) = &self.ecc { - contents.push_str(&ecc_config.generate_sw_init()); + contents.push_str(&ecc_config.generate_ecc_init()); contents.push('\n'); } @@ -320,8 +318,8 @@ impl From for UnitStruct { } } -impl From for UnitStruct { - fn from(_: Sha256) -> Self { +impl From for UnitStruct { + fn from(_: Sha2) -> Self { UnitStruct {} } } @@ -346,7 +344,7 @@ struct SdkVmConfigWithDefaultDeser { pub rv32i: Option, pub io: Option, pub keccak: Option, - pub sha256: Option, + pub sha2: Option, pub native: Option, pub castf: Option, @@ -355,7 +353,7 @@ struct SdkVmConfigWithDefaultDeser { pub modular: Option, pub fp2: Option, pub pairing: Option, - pub ecc: Option, + pub ecc: Option, } impl From for SdkVmConfig { @@ -371,7 +369,7 @@ impl From for SdkVmConfig { rv32i: config.rv32i, io: config.io, keccak: config.keccak, - sha256: config.sha256, + sha2: config.sha2, native: config.native, castf: config.castf, rv32m: config.rv32m, diff --git a/crates/toolchain/tests/tests/transpiler_tests.rs b/crates/toolchain/tests/tests/transpiler_tests.rs index 42b62be437..8118414809 100644 --- a/crates/toolchain/tests/tests/transpiler_tests.rs +++ b/crates/toolchain/tests/tests/transpiler_tests.rs @@ -19,7 +19,7 @@ use openvm_circuit::{ derive::VmConfig, utils::air_test, }; -use openvm_ecc_circuit::{SECP256K1_MODULUS, SECP256K1_ORDER}; +use openvm_ecc_circuit::SECP256K1_CONFIG; use openvm_instructions::exe::VmExe; use openvm_platform::memory::MEM_SIZE; use openvm_rv32im_circuit::{ @@ -133,8 +133,14 @@ impl InitFileGenerator for Rv32ModularFp2Int256Config { #[test_case("tests/data/rv32im-intrin-from-as")] fn test_intrinsic_runtime(elf_path: &str) -> Result<()> { let config = Rv32ModularFp2Int256Config::new( - vec![SECP256K1_MODULUS.clone(), SECP256K1_ORDER.clone()], - vec![("Secp256k1Coord".to_string(), SECP256K1_MODULUS.clone())], + vec![ + SECP256K1_CONFIG.modulus.clone(), + SECP256K1_CONFIG.scalar.clone(), + ], + vec![( + SECP256K1_CONFIG.struct_name.clone(), + SECP256K1_CONFIG.modulus.clone(), + )], ); let elf = get_elf(elf_path)?; let openvm_exe = VmExe::from_elf( diff --git a/docs/specs/ISA.md b/docs/specs/ISA.md index 4b412425ba..24fd26f89e 100644 --- a/docs/specs/ISA.md +++ b/docs/specs/ISA.md @@ -6,7 +6,7 @@ This specification describes the overall architecture and default VM extensions - [RV32IM](#rv32im-extension): An extension supporting the 32-bit RISC-V ISA with multiplication. - [Native](#native-extension): An extension supporting native field arithmetic for proof recursion and aggregation. - [Keccak-256](#keccak-extension): An extension implementing the Keccak-256 hash function compatibly with RISC-V memory. -- [SHA2-256](#sha2-256-extension): An extension implementing the SHA2-256 hash function compatibly with RISC-V memory. +- [SHA2](#sha-2-extension): An extension implementing the SHA-256, SHA-512, and SHA-384 hash functions compatibly with RISC-V memory. - [BigInt](#bigint-extension): An extension supporting 256-bit signed and unsigned integer arithmetic, including multiplication. This extension respects the RISC-V memory format. - [Algebra](#algebra-extension): An extension supporting modular arithmetic over arbitrary fields and their complex @@ -541,14 +541,16 @@ all memory cells are constrained to be bytes. | -------------- | ----------- | ----------------------------------------------------------------------------------------------------------------- | | KECCAK256_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = keccak256([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Performs memory accesses with block size `4`. | -### SHA2-256 Extension +### SHA-2 Extension -The SHA2-256 extension supports the SHA2-256 hash function. The extension operates on address spaces `1` and `2`, +The SHA-2 extension supports the SHA-256 and SHA-512 hash functions. The extension operates on address spaces `1` and `2`, meaning all memory cells are constrained to be bytes. | Name | Operands | Description | | ----------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | SHA256_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = sha256([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `16` and writes with block size `32`. | +| SHA512_RV32 | `a,b,c,1,2` | `[r32{0}(a):64]_2 = sha512([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `32` and writes with block size `32`. | +| SHA384_RV32 | `a,b,c,1,2` | `[r32{0}(a):64]_2 = sha384([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `32` and writes with block size `32`. Writes 64 bytes to memory: the first 48 are the SHA-384 digest and the last 16 are zeros. | ### BigInt Extension @@ -677,12 +679,13 @@ r32_fp2(a) -> Fp2 { ### Elliptic Curve Extension -The elliptic curve extension supports arithmetic over elliptic curves `C` in Weierstrass form given by -equation `C: y^2 = x^3 + C::A * x + C::B` where `C::A` and `C::B` are constants in the coordinate field. We note that -the definitions of the -curve arithmetic operations do not depend on `C::B`. The VM configuration will specify a list of supported curves. For -each Weierstrass curve `C` there will be associated configuration parameters `C::COORD_SIZE` and `C::BLOCK_SIZE` ( -defined below). The extension operates on address spaces `1` and `2`, meaning all memory cells are constrained to be +The elliptic curve extension supports arithmetic over elliptic curves `C` in the following forms: +- in short Weierstrass form given by equation `C: y^2 = x^3 + C::A * x + C::B` where `C::A` and `C::B` are constants in the coordinate field +- in twisted Edwards form given by equation `C: C::A * x^2 + y^2 = 1 + C::D * x^2 * y^2` where `C::A` and `C::D` are constants in the coordinate field + +We note that +the definitions of the curve arithmetic operations for short Weierstrass curves do not depend on `C::B`. The VM configuration will specify a list of supported curves. For +each curve `C` (of either form) there will be associated configuration parameters `C::COORD_SIZE` and `C::BLOCK_SIZE` (defined below). The extension operates on address spaces `1` and `2`, meaning all memory cells are constrained to be bytes. An affine curve point `EcPoint(x, y)` is a pair of `x,y` where each element is an array of `C::COORD_SIZE` elements each @@ -700,12 +703,16 @@ r32_ec_point(a) -> EcPoint { } ``` +The instructions that have prefix `SW_` perform short Weierstrass curve operations, and those with prefix `TE_` perform twisted Edwards curve operations. + | Name | Operands | Description | | -------------------- | ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| EC_ADD_NE\ | `a,b,c,1,2` | Set `r32_ec_point(a) = r32_ec_point(b) + r32_ec_point(c)` (curve addition). Assumes that `r32_ec_point(b), r32_ec_point(c)` both lie on the curve and are not the identity point. Further assumes that `r32_ec_point(b).x, r32_ec_point(c).x` are not equal in the coordinate field. | -| SETUP_EC_ADD_NE\ | `a,b,c,1,2` | `assert(r32_ec_point(b).x == C::MODULUS)` in the chip for EC ADD. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. It is required for proper functionality that `assert(r32_ec_point(b).x != r32_ec_point(c).x)` | -| EC_DOUBLE\ | `a,b,_,1,2` | Set `r32_ec_point(a) = 2 * r32_ec_point(b)`. This doubles the input point. Assumes that `r32_ec_point(b)` lies on the curve and is not the identity point. | -| SETUP_EC_DOUBLE\ | `a,b,_,1,2` | `assert(r32_ec_point(b).x == C::MODULUS)` in the chip for EC DOUBLE. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. It is required for proper functionality that `assert(r32_ec_point(b).y != 0 mod C::MODULUS)` | +| SW_ADD_NE\ | `a,b,c,1,2` | Set `r32_ec_point(a) = r32_ec_point(b) + r32_ec_point(c)` (curve addition). Assumes that `r32_ec_point(b), r32_ec_point(c)` both lie on the curve and are not the identity point. Further assumes that `r32_ec_point(b).x, r32_ec_point(c).x` are not equal in the coordinate field. | +| SETUP_SW_ADD_NE\ | `a,b,c,1,2` | `assert(r32_ec_point(b).x == C::MODULUS && r32_ec_point(b).y == C::A)` in the chip for SW ADD. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. It is required for proper functionality that `assert(r32_ec_point(b).x != r32_ec_point(c).x)` | +| SW_DOUBLE\ | `a,b,_,1,2` | Set `r32_ec_point(a) = 2 * r32_ec_point(b)`. This doubles the input point. Assumes that `r32_ec_point(b)` lies on the curve and is not the identity point. | +| SETUP_SW_DOUBLE\ | `a,b,_,1,2` | `assert(r32_ec_point(b).x == C::MODULUS)` in the chip for SW DOUBLE. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. It is required for proper functionality that `assert(r32_ec_point(b).y != 0 mod C::MODULUS)` | +| TE_ADD\ | `a,b,c,1,2` | Set `r32_ec_point(a) = r32_ec_point(b) + r32_ec_point(c)` (curve addition). Assumes that `r32_ec_point(b), r32_ec_point(c)` both lie on the curve. | +| SETUP_TE_ADD\ | `a,b,c,1,2` | `assert(r32_ec_point(b).x == C::MODULUS && r32_ec_point(b).y == C::A && r32_ec_point(c).x == C::D)` in the chip for TE ADD. For the sake of implementation convenience it also writes something (can be anything) into `[r32{0}(a): 2*C::COORD_SIZE]_2`. | ### Pairing Extension diff --git a/docs/specs/RISCV.md b/docs/specs/RISCV.md index 32d0cc63fa..d60cc04d9c 100644 --- a/docs/specs/RISCV.md +++ b/docs/specs/RISCV.md @@ -5,7 +5,7 @@ The default VM extensions that support transpilation are: - [RV32IM](#rv32im-extension): An extension supporting the 32-bit RISC-V ISA with multiplication. - [Keccak-256](#keccak-extension): An extension implementing the Keccak-256 hash function compatibly with RISC-V memory. -- [SHA2-256](#sha2-256-extension): An extension implementing the SHA2-256 hash function compatibly with RISC-V memory. +- [SHA2](#sha-2-extension): An extension implementing the SHA-256, SHA-512, and SHA-384 hash functions compatibly with RISC-V memory. - [BigInt](#bigint-extension): An extension supporting 256-bit signed and unsigned integer arithmetic, including multiplication. This extension respects the RISC-V memory format. - [Algebra](#algebra-extension): An extension supporting modular arithmetic over arbitrary fields and their complex field extensions. This extension respects the RISC-V memory format. - [Elliptic curve](#elliptic-curve-extension): An extension for elliptic curve operations over Weierstrass curves, including addition and doubling. This can be used to implement multi-scalar multiplication and ECDSA scalar multiplication. This extension respects the RISC-V memory format. @@ -85,11 +85,13 @@ implementation is here. But we use `funct3 = 111` because the native extension h | ----------- | --- | ----------- | ------ | ------ | ------------------------------------------- | | keccak256 | R | 0001011 | 100 | 0x0 | `[rd:32]_2 = keccak256([rs1..rs1 + rs2]_2)` | -## SHA2-256 Extension +## SHA-2 Extension | RISC-V Inst | FMT | opcode[6:0] | funct3 | funct7 | RISC-V description and notes | | ----------- | --- | ----------- | ------ | ------ | ---------------------------------------- | | sha256 | R | 0001011 | 100 | 0x1 | `[rd:32]_2 = sha256([rs1..rs1 + rs2]_2)` | +| sha512 | R | 0001011 | 100 | 0x2 | `[rd:64]_2 = sha512([rs1..rs1 + rs2]_2)` | +| sha384 | R | 0001011 | 100 | 0x3 | `[rd:64]_2 = sha384([rs1..rs1 + rs2]_2)`. Last 16 bytes will be set to zeros. | ## BigInt Extension @@ -176,13 +178,16 @@ Complex extension field arithmetic over `Fp2` depends on `Fp` where `-1` is not ## Elliptic Curve Extension -The elliptic curve extension supports arithmetic over short Weierstrass curves, which requires specification of the elliptic curve `C`. The extension must be configured to support a fixed ordered list of supported curves. We use `config.curve_idx(C)` to denote the index of `C` in this list. In the list below, `idx` denotes `config.curve_idx(C)`. +The elliptic curve extension supports arithmetic over short Weierstrass curves and twisted Edwards curves, which requires specification of the elliptic curve `C`. The extension must be configured to support two fixed ordered lists of supported curves: one list of short Weierstrass curves and one list of twisted Edwards curves. Instructions prefixed with `sw_` are for short Weierstrass curves and instructions prefixed with `te_` are for twisted Edwards curves. We use `config.curve_idx(C)` to denote the index of `C` in the appropriate list. In the list below, `idx` denotes `config.curve_idx(C)`. | RISC-V Inst | FMT | opcode[6:0] | funct3 | funct7 | RISC-V description and notes | | --------------- | --- | ----------- | ------ | --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | sw_add_ne\ | R | 0101011 | 001 | `idx*8` | `EcPoint([rd:2*C::COORD_SIZE]_2) = EcPoint([rs1:2*C::COORD_SIZE]_2) + EcPoint([rs2:2*C::COORD_SIZE]_2)`. Assumes that input affine points are not identity and do not have same x-coordinate. | | sw_double\ | R | 0101011 | 001 | `idx*8+1` | `EcPoint([rd:2*C::COORD_SIZE]_2) = 2 * EcPoint([rs1:2*C::COORD_SIZE]_2)`. Assumes that input affine point is not identity. `rs2` is unused and must be set to `x0`. | -| setup\ | R | 0101011 | 001 | `idx*8+2` | `assert([rs1: C::COORD_SIZE]_2 == C::MODULUS)` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: 2*C::COORD_SIZE]_2`. If `ind(rs2) != 0`, then this instruction is setup for `sw_add_ne`. Otherwise it is setup for `sw_double`. When `ind(rs2) != 0` (add_ne), it is required for proper functionality that `[rs2: C::COORD_SIZE]_2 != [rs1: C::COORD_SIZE]_2`; otherwise (double), it is required that `[rs1 + C::COORD_SIZE: C::COORD_SIZE]_2 != C::Fp::ZERO` | +| sw_setup\ | R | 0101011 | 001 | `idx*8+2` | `assert([rs1: 2*C::COORD_SIZE]_2 == [C::MODULUS, CURVE_A])` in the chip defined by the register index of `rs2`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: 2*C::COORD_SIZE]_2`. If `ind(rs2) != 0`, then this instruction is setup for `sw_add_ne`. Otherwise it is setup for `sw_double`. When `ind(rs2) != 0` (add_ne), it is required for proper functionality that `[rs2: C::COORD_SIZE]_2 != [rs1: C::COORD_SIZE]_2`; otherwise (double), it is required that `[rs1 + C::COORD_SIZE: C::COORD_SIZE]_2 != C::Fp::ZERO` | +| te_add\ | R | 0101011 | 100 | `idx*8` | `EcPoint([rd:2*C::COORD_SIZE]_2) = EcPoint([rs1:2*C::COORD_SIZE]_2) + EcPoint([rs2:2*C::COORD_SIZE]_2)`. | +| te_setup\ | R | 0101011 | 100 | `idx*8+1` | `assert([rs1: 2*C::COORD_SIZE]_2 == [C::MODULUS, C::CURVE_A] && [rs2: C::COORD_SIZE]_2 == C::CURVE_D])`. For the sake of implementation convenience it also writes an unconstrained value into `[rd: 2*C::COORD_SIZE]_2`. | + Since `funct7` is 7-bits, up to 16 curves can be supported simultaneously. We use `idx*8` to leave some room for future expansion. diff --git a/docs/specs/circuit.md b/docs/specs/circuit.md index 4238c7c27b..bd34344674 100644 --- a/docs/specs/circuit.md +++ b/docs/specs/circuit.md @@ -104,7 +104,7 @@ The chips that fall into these categories are: | FriReducedOpeningChip | – | – | Case 1. | | NativePoseidon2Chip | – | – | Case 1. | | Rv32HintStoreChip | – | – | Case 1. | -| Sha256VmChip | – | – | Case 1. | +| Sha2VmChip | – | – | Case 1. | The PhantomChip satisfies the condition because `1 < 3`. diff --git a/docs/specs/isa-table.md b/docs/specs/isa-table.md index 7b7f374065..fc76462a00 100644 --- a/docs/specs/isa-table.md +++ b/docs/specs/isa-table.md @@ -130,13 +130,15 @@ In the tables below, we provide the mapping between the `LocalOpcode` and `Phant | ------------- | ---------- | ------------- | | Keccak | `Rv32KeccakOpcode::KECCAK256` | KECCAK256_RV32 | -## SHA2-256 Extension +## SHA-2 Extension #### Instructions | VM Extension | `LocalOpcode` | ISA Instruction | | ------------- | ---------- | ------------- | -| SHA2-256 | `Rv32Sha256Opcode::SHA256` | SHA256_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA256` | SHA256_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA512` | SHA512_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA384` | SHA384_RV32 | ## BigInt Extension diff --git a/docs/specs/transpiler.md b/docs/specs/transpiler.md index fded65b6d8..ae89dc418f 100644 --- a/docs/specs/transpiler.md +++ b/docs/specs/transpiler.md @@ -151,11 +151,13 @@ Each VM extension's behavior is specified below. | ----------- | -------------------------------------------------- | | keccak256 | KECCAK256_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | -### SHA2-256 Extension +### SHA-2 Extension | RISC-V Inst | OpenVM Instruction | | ----------- | ----------------------------------------------- | | sha256 | SHA256_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| sha512 | SHA512_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| sha384 | SHA384_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | ### BigInt Extension @@ -205,7 +207,9 @@ Each VM extension's behavior is specified below. | --------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | | sw_add_ne\ | EC_ADD_NE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` | | sw_double\ | EC_DOUBLE_RV32\ `ind(rd), ind(rs1), 0, 1, 2` | -| setup\ | SETUP_EC_ADD_NE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `ind(rs2) != 0`, SETUP_EC_DOUBLE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `ind(rs2) = 0` | +| sw_setup\ | SETUP_EC_ADD_NE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `ind(rs2) != 0`, SETUP_EC_DOUBLE_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` if `ind(rs2) = 0` | +| te_add\ | TE_ADD_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| te_setup\ | SETUP_TE_ADD_RV32\ `ind(rd), ind(rs1), ind(rs2), 1, 2` | ### Pairing Extension diff --git a/examples/algebra/openvm/app.vmexe b/examples/algebra/openvm/app.vmexe new file mode 100644 index 0000000000..801ce82638 Binary files /dev/null and b/examples/algebra/openvm/app.vmexe differ diff --git a/examples/ecc/Cargo.toml b/examples/ecc/Cargo.toml index 3e0dcdbcfc..0206cba6a0 100644 --- a/examples/ecc/Cargo.toml +++ b/examples/ecc/Cargo.toml @@ -11,9 +11,11 @@ openvm = { git = "https://github.com/openvm-org/openvm.git", features = [ "std", ] } openvm-algebra-guest = { git = "https://github.com/openvm-org/openvm.git" } -openvm-ecc-guest = { git = "https://github.com/openvm-org/openvm.git" } +openvm-ecc-guest = { git = "https://github.com/openvm-org/openvm.git", features = ["ed25519"]} openvm-k256 = { git = "https://github.com/openvm-org/openvm.git", package = "k256" } hex-literal = { version = "0.4.1", default-features = false } +serde = { version = "1.0", default-features = false, features = [ "derive" ] } +num-bigint = { version = "0.4.6", default-features = false } [features] default = [] diff --git a/examples/ecc/openvm.toml b/examples/ecc/openvm.toml index 1dc6cf25f2..db8e420efc 100644 --- a/examples/ecc/openvm.toml +++ b/examples/ecc/openvm.toml @@ -2,11 +2,22 @@ [app_vm_config.rv32m] [app_vm_config.io] [app_vm_config.modular] -supported_moduli = ["115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337"] +supported_moduli = ["115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "57896044618658097711785492504343953926634992332820282019728792003956564819949"] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" + +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" -b = "7" \ No newline at end of file +b = "7" + +[[app_vm_config.ecc.supported_te_curves]] +struct_name = "Ed25519Point" +modulus = "57896044618658097711785492504343953926634992332820282019728792003956564819949" +scalar = "7237005577332262213973186563042994240857116359379907606001950938285454250989" + +[app_vm_config.ecc.supported_te_curves.coeffs] +a = "57896044618658097711785492504343953926634992332820282019728792003956564819948" +d = "37095705934669439343138083508754565189542113879843219016388785533085940283555" diff --git a/examples/ecc/openvm/app.vmexe b/examples/ecc/openvm/app.vmexe new file mode 100644 index 0000000000..910f3a4efd Binary files /dev/null and b/examples/ecc/openvm/app.vmexe differ diff --git a/examples/ecc/openvm_init.rs b/examples/ecc/openvm_init.rs index bec9f527e9..eb3bca4373 100644 --- a/examples/ecc/openvm_init.rs +++ b/examples/ecc/openvm_init.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. -openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } +openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "57896044618658097711785492504343953926634992332820282019728792003956564819949" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { Ed25519Point } diff --git a/examples/ecc/src/main.rs b/examples/ecc/src/main.rs index f95b6272ad..7b903397db 100644 --- a/examples/ecc/src/main.rs +++ b/examples/ecc/src/main.rs @@ -1,7 +1,11 @@ // ANCHOR: imports use hex_literal::hex; -use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::weierstrass::WeierstrassPoint; +use openvm_ecc_guest::{ + algebra::IntMod, + ed25519::{Ed25519Coord, Ed25519Point}, + edwards::TwistedEdwardsPoint, + weierstrass::WeierstrassPoint, +}; use openvm_k256::{Secp256k1Coord, Secp256k1Point}; // ANCHOR_END: imports @@ -9,13 +13,11 @@ use openvm_k256::{Secp256k1Coord, Secp256k1Point}; openvm::init!(); /* The init! macro will expand to the following openvm_algebra_guest::moduli_macros::moduli_init! { - "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F", - "0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141" -} - -openvm_ecc_guest::sw_macros::sw_init! { - Secp256k1Point, +"115792089237316195423570985008687907853269984665640564039457584007908834671663", +"115792089237316195423570985008687907852837564279074904382605163141518161494337" } +openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { Ed25519Point } */ // ANCHOR_END: init @@ -35,5 +37,22 @@ pub fn main() { #[allow(clippy::op_ref)] let _p3 = &p1 + &p2; + + let x1 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "216936D3CD6E53FEC0A4E231FDD6DC5C692CC7609525A7B2C9562D608F25D51A" + )); + let y1 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "6666666666666666666666666666666666666666666666666666666666666658" + )); + let p1 = Ed25519Point::from_xy(x1, y1).unwrap(); + + let x2 = Ed25519Coord::from_u32(2); + let y2 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "1A43BF127BDDC4D71FF910403C11DDB5BA2BCDD2815393924657EF111E712631" + )); + let p2 = Ed25519Point::from_xy(x2, y2).unwrap(); + + #[allow(clippy::op_ref)] + let _p3 = &p1 + &p2; } // ANCHOR_END: main diff --git a/examples/i256/openvm/app.vmexe b/examples/i256/openvm/app.vmexe new file mode 100644 index 0000000000..e45a699ef3 Binary files /dev/null and b/examples/i256/openvm/app.vmexe differ diff --git a/examples/keccak/Cargo.toml b/examples/keccak/Cargo.toml index 74f15e6234..3c5cd8a26e 100644 --- a/examples/keccak/Cargo.toml +++ b/examples/keccak/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" members = [] [dependencies] -openvm = { git = "https://github.com/openvm-org/openvm.git", features = [ +openvm = { git = "https://github.com/openvm-org/openvm.git", branch = "develop", features = [ "std", ] } openvm-keccak256 = { git = "https://github.com/openvm-org/openvm.git" } diff --git a/examples/sha256/Cargo.toml b/examples/sha2/Cargo.toml similarity index 89% rename from examples/sha256/Cargo.toml rename to examples/sha2/Cargo.toml index 0b5a44bc3e..adfc269750 100644 --- a/examples/sha256/Cargo.toml +++ b/examples/sha2/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "sha256-example" +name = "sha2-example" version = "0.0.0" edition = "2021" @@ -7,6 +7,7 @@ edition = "2021" members = [] [dependencies] +# TODO: update rev after PR is merged openvm = { git = "https://github.com/openvm-org/openvm.git", features = [ "std", ] } diff --git a/examples/sha256/openvm.toml b/examples/sha2/openvm.toml similarity index 73% rename from examples/sha256/openvm.toml rename to examples/sha2/openvm.toml index 656bf52414..35f92b7195 100644 --- a/examples/sha256/openvm.toml +++ b/examples/sha2/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/examples/sha2/src/main.rs b/examples/sha2/src/main.rs new file mode 100644 index 0000000000..4fa1539ab6 --- /dev/null +++ b/examples/sha2/src/main.rs @@ -0,0 +1,39 @@ +// ANCHOR: imports +use core::hint::black_box; + +use hex::FromHex; +use openvm_sha2::{sha256, sha384, sha512}; +// ANCHOR_END: imports + +// ANCHOR: main +openvm::entry!(main); + +pub fn main() { + let test_vectors = [( + "", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", + "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", + )]; + for (input, expected_output_sha256, expected_output_sha512, expected_output_sha384) in + test_vectors.iter() + { + let input = Vec::from_hex(input).unwrap(); + let expected_output_sha256 = Vec::from_hex(expected_output_sha256).unwrap(); + let expected_output_sha512 = Vec::from_hex(expected_output_sha512).unwrap(); + let expected_output_sha384 = Vec::from_hex(expected_output_sha384).unwrap(); + let output = sha256(black_box(&input)); + if output != *expected_output_sha256 { + panic!(); + } + let output = sha512(black_box(&input)); + if output != *expected_output_sha512 { + panic!(); + } + let output = sha384(black_box(&input)); + if output != *expected_output_sha384 { + panic!(); + } + } +} +// ANCHOR_END: main diff --git a/examples/sha256/src/main.rs b/examples/sha256/src/main.rs deleted file mode 100644 index 6389aaa1dc..0000000000 --- a/examples/sha256/src/main.rs +++ /dev/null @@ -1,27 +0,0 @@ -openvm::entry!(main); - -// ANCHOR: imports -use core::hint::black_box; - -use hex::FromHex; -use openvm_sha2::sha256; -// ANCHOR_END: imports - -// ANCHOR: main -openvm::entry!(main); - -pub fn main() { - let test_vectors = [( - "", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - )]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = sha256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} -// ANCHOR_END: main diff --git a/extensions/ecc/circuit/Cargo.toml b/extensions/ecc/circuit/Cargo.toml index 81798e207a..a169102156 100644 --- a/extensions/ecc/circuit/Cargo.toml +++ b/extensions/ecc/circuit/Cargo.toml @@ -19,8 +19,11 @@ openvm-rv32im-circuit = { workspace = true } openvm-algebra-circuit = { workspace = true } openvm-rv32-adapters = { workspace = true } openvm-ecc-transpiler = { workspace = true } +openvm-ecc-guest = { workspace = true, features = ["ed25519"] } +openvm-sha2-circuit = { workspace = true } num-bigint = { workspace = true } +num-integer = { workspace = true } num-traits = { workspace = true } strum = { workspace = true } derive_more = { workspace = true } diff --git a/extensions/ecc/circuit/src/config.rs b/extensions/ecc/circuit/src/config.rs index a959938be9..c251e29b12 100644 --- a/extensions/ecc/circuit/src/config.rs +++ b/extensions/ecc/circuit/src/config.rs @@ -2,13 +2,14 @@ use openvm_algebra_circuit::*; use openvm_circuit::arch::{InitFileGenerator, SystemConfig}; use openvm_circuit_derive::VmConfig; use openvm_rv32im_circuit::*; +use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use super::*; #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] -pub struct Rv32WeierstrassConfig { +pub struct Rv32EccConfig { #[system] pub system: SystemConfig, #[extension] @@ -20,32 +21,43 @@ pub struct Rv32WeierstrassConfig { #[extension] pub modular: ModularExtension, #[extension] - pub weierstrass: WeierstrassExtension, + pub ecc: EccExtension, + #[extension] + pub sha2: Sha2, } -impl Rv32WeierstrassConfig { - pub fn new(curves: Vec) -> Self { - let primes: Vec<_> = curves +impl Rv32EccConfig { + pub fn new( + sw_curves: Vec>, + te_curves: Vec>, + ) -> Self { + let sw_primes: Vec<_> = sw_curves + .iter() + .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) + .collect(); + let te_primes: Vec<_> = te_curves .iter() .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) .collect(); + let primes = sw_primes.into_iter().chain(te_primes).collect(); Self { system: SystemConfig::default().with_continuations(), base: Default::default(), mul: Default::default(), io: Default::default(), modular: ModularExtension::new(primes), - weierstrass: WeierstrassExtension::new(curves), + ecc: EccExtension::new(sw_curves, te_curves), + sha2: Default::default(), } } } -impl InitFileGenerator for Rv32WeierstrassConfig { +impl InitFileGenerator for Rv32EccConfig { fn generate_init_file_contents(&self) -> Option { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", self.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() + self.ecc.generate_ecc_init() )) } } diff --git a/extensions/ecc/circuit/src/ecc_extension.rs b/extensions/ecc/circuit/src/ecc_extension.rs new file mode 100644 index 0000000000..648a6ea00b --- /dev/null +++ b/extensions/ecc/circuit/src/ecc_extension.rs @@ -0,0 +1,348 @@ +use derive_more::derive::From; +use hex_literal::hex; +use num_bigint::BigUint; +use num_traits::{FromPrimitive, Zero}; +use once_cell::sync::Lazy; +use openvm_circuit::{ + arch::{ + ExecutionBridge, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + }, + system::phantom::PhantomChip, +}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +}; +use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_ecc_guest::{ + algebra::IntMod, + ed25519::{CURVE_A as ED25519_A, CURVE_D as ED25519_D, ED25519_MODULUS, ED25519_ORDER}, +}; +use openvm_ecc_transpiler::{Rv32EdwardsOpcode, Rv32WeierstrassOpcode}; +use openvm_instructions::{LocalOpcode, VmOpcode}; +use openvm_mod_circuit_builder::ExprBuilderConfig; +use openvm_stark_backend::p3_field::PrimeField32; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, DisplayFromStr}; +use strum::EnumCount; + +use super::{SwAddNeChip, SwDoubleChip, TeAddChip}; + +#[serde_as] +#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +pub struct CurveConfig { + /// The name of the curve struct as defined by moduli_declare. + pub struct_name: String, + /// The coordinate modulus of the curve. + #[serde_as(as = "DisplayFromStr")] + pub modulus: BigUint, + /// The scalar field modulus of the curve. + #[serde_as(as = "DisplayFromStr")] + pub scalar: BigUint, + // curve-specific coefficients + pub coeffs: T, +} + +#[serde_as] +#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +pub struct SwCurveCoeffs { + /// The coefficient a of y^2 = x^3 + ax + b. + #[serde_as(as = "DisplayFromStr")] + pub a: BigUint, + /// The coefficient b of y^2 = x^3 + ax + b. + #[serde_as(as = "DisplayFromStr")] + pub b: BigUint, +} + +#[serde_as] +#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +pub struct TeCurveCoeffs { + /// The coefficient a of ax^2 + y^2 = 1 + dx^2y^2 + #[serde_as(as = "DisplayFromStr")] + pub a: BigUint, + /// The coefficient d of ax^2 + y^2 = 1 + dx^2y^2 + #[serde_as(as = "DisplayFromStr")] + pub d: BigUint, +} + +pub static SECP256K1_CONFIG: Lazy> = Lazy::new(|| CurveConfig { + struct_name: "Secp256k1Point".to_string(), + modulus: BigUint::from_bytes_be(&hex!( + "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F" + )), + scalar: BigUint::from_bytes_be(&hex!( + "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141" + )), + coeffs: SwCurveCoeffs { + a: BigUint::zero(), + b: BigUint::from_u8(7u8).unwrap(), + }, +}); + +pub static P256_CONFIG: Lazy> = Lazy::new(|| CurveConfig { + struct_name: "P256Point".to_string(), + modulus: BigUint::from_bytes_be(&hex!( + "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff" + )), + scalar: BigUint::from_bytes_be(&hex!( + "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551" + )), + coeffs: SwCurveCoeffs { + a: BigUint::from_bytes_le(&hex!( + "fcffffffffffffffffffffff00000000000000000000000001000000ffffffff" + )), + b: BigUint::from_bytes_le(&hex!( + "4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e7933aaad835c65a" + )), + }, +}); + +pub static ED25519_CONFIG: Lazy> = Lazy::new(|| CurveConfig { + struct_name: "Ed25519Point".to_string(), + modulus: ED25519_MODULUS.clone(), + scalar: ED25519_ORDER.clone(), + coeffs: TeCurveCoeffs { + a: BigUint::from_bytes_le(ED25519_A.as_le_bytes()), + d: BigUint::from_bytes_le(ED25519_D.as_le_bytes()), + }, +}); + +#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] +pub struct EccExtension { + #[serde(default)] + pub supported_sw_curves: Vec>, + #[serde(default)] + pub supported_te_curves: Vec>, +} + +impl EccExtension { + pub fn generate_ecc_init(&self) -> String { + let supported_sw_curves = self + .supported_sw_curves + .iter() + .map(|curve_config| curve_config.struct_name.to_string()) + .collect::>() + .join(", "); + + let supported_te_curves = self + .supported_te_curves + .iter() + .map(|curve_config| curve_config.struct_name.to_string()) + .collect::>() + .join(", "); + + format!( + "openvm_ecc_guest::sw_macros::sw_init! {{ {supported_sw_curves} }}\nopenvm_ecc_guest::te_macros::te_init! {{ {supported_te_curves} }}" + ) + } +} + +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, InsExecutorE1, InsExecutorE2)] +pub enum EccExtensionExecutor { + // 32 limbs prime + SwEcAddNeRv32_32(SwAddNeChip), + SwEcDoubleRv32_32(SwDoubleChip), + // 48 limbs prime + SwEcAddNeRv32_48(SwAddNeChip), + SwEcDoubleRv32_48(SwDoubleChip), + // 32 limbs prime + TeEcAddRv32_32(TeAddChip), + // 48 limbs prime + TeEcAddRv32_48(TeAddChip), +} + +#[derive(ChipUsageGetter, Chip, AnyEnum, From)] +pub enum EccExtensionPeriphery { + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), + // We put this only to get the generic to work + Phantom(PhantomChip), +} + +impl VmExtension for EccExtension { + type Executor = EccExtensionExecutor; + type Periphery = EccExtensionPeriphery; + + fn build( + &self, + builder: &mut VmInventoryBuilder, + ) -> Result, VmInventoryError> { + let mut inventory = VmInventory::new(); + let SystemPort { + execution_bus, + program_bus, + memory_bridge, + } = builder.system_port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = builder.system_base().range_checker_chip.clone(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() + .first() + { + chip.clone() + } else { + let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); + inventory.add_periphery_chip(chip.clone()); + chip + }; + + let sw_add_ne_opcodes = (Rv32WeierstrassOpcode::SW_ADD_NE as usize) + ..=(Rv32WeierstrassOpcode::SETUP_SW_ADD_NE as usize); + let sw_double_opcodes = (Rv32WeierstrassOpcode::SW_DOUBLE as usize) + ..=(Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize); + + let te_add_opcodes = + (Rv32EdwardsOpcode::TE_ADD as usize)..=(Rv32EdwardsOpcode::SETUP_TE_ADD as usize); + + for (sw_idx, curve) in self.supported_sw_curves.iter().enumerate() { + // TODO: Better support for different limb sizes. Currently only 32 or 48 limbs are + // supported. + let sw_start_offset = + Rv32WeierstrassOpcode::CLASS_OFFSET + sw_idx * Rv32WeierstrassOpcode::COUNT; + let bytes = curve.modulus.bits().div_ceil(8); + let config32 = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + let config48 = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + if bytes <= 32 { + let sw_add_ne_chip = SwAddNeChip::new( + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, + config32.clone(), + sw_start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), + ); + inventory.add_executor( + EccExtensionExecutor::SwEcAddNeRv32_32(sw_add_ne_chip), + sw_add_ne_opcodes + .clone() + .map(|x| VmOpcode::from_usize(x + sw_start_offset)), + )?; + let sw_double_chip = SwDoubleChip::new( + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, + config32.clone(), + sw_start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), + curve.coeffs.a.clone(), + ); + inventory.add_executor( + EccExtensionExecutor::SwEcDoubleRv32_32(sw_double_chip), + sw_double_opcodes + .clone() + .map(|x| VmOpcode::from_usize(x + sw_start_offset)), + )?; + } else if bytes <= 48 { + let sw_add_ne_chip = SwAddNeChip::new( + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, + config48.clone(), + sw_start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), + ); + inventory.add_executor( + EccExtensionExecutor::SwEcAddNeRv32_48(sw_add_ne_chip), + sw_add_ne_opcodes + .clone() + .map(|x| VmOpcode::from_usize(x + sw_start_offset)), + )?; + let sw_double_chip = SwDoubleChip::new( + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, + config48.clone(), + sw_start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), + curve.coeffs.a.clone(), + ); + inventory.add_executor( + EccExtensionExecutor::SwEcDoubleRv32_48(sw_double_chip), + sw_double_opcodes + .clone() + .map(|x| VmOpcode::from_usize(x + sw_start_offset)), + )?; + } else { + panic!("Modulus too large"); + } + } + + for (te_idx, curve) in self.supported_te_curves.iter().enumerate() { + let bytes = curve.modulus.bits().div_ceil(8); + let config32 = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 32, + limb_bits: 8, + }; + let config48 = ExprBuilderConfig { + modulus: curve.modulus.clone(), + num_limbs: 48, + limb_bits: 8, + }; + let te_start_offset = + Rv32EdwardsOpcode::CLASS_OFFSET + te_idx * Rv32EdwardsOpcode::COUNT; + if bytes <= 32 { + let te_add_chip = TeAddChip::new( + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, + config32.clone(), + te_start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), + curve.coeffs.a.clone(), + curve.coeffs.d.clone(), + ); + inventory.add_executor( + EccExtensionExecutor::TeEcAddRv32_32(te_add_chip), + te_add_opcodes + .clone() + .map(|x| VmOpcode::from_usize(x + te_start_offset)), + )?; + } else if bytes <= 48 { + let te_add_chip = TeAddChip::new( + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, + config48.clone(), + te_start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), + curve.coeffs.a.clone(), + curve.coeffs.d.clone(), + ); + inventory.add_executor( + EccExtensionExecutor::TeEcAddRv32_48(te_add_chip), + te_add_opcodes + .clone() + .map(|x| VmOpcode::from_usize(x + te_start_offset)), + )?; + } else { + panic!("Modulus too large"); + } + } + + Ok(inventory) + } +} diff --git a/extensions/ecc/circuit/src/edwards_chip/README.md b/extensions/ecc/circuit/src/edwards_chip/README.md new file mode 100644 index 0000000000..24167e062a --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/README.md @@ -0,0 +1,17 @@ +# Twisted Edwards (TE) curve operations + +The `te_add` instruction is implemented in the `edwards_chip` module. + +### 1. `te_add` + +**Assumptions:** + +- Both points `(x1, y1)` and `(x2, y2)` lie on the curve. + +**Circuit statements:** + +- The chip takes two inputs: `(x1, y1)` and `(x2, y2)`, and returns `(x3, y3)` where: + - `x3 = (x1 * y2 + x2 * y1) / (1 + d * x1 * x2 * y1 * y2)` + - `y3 = (y1 * y2 - a * x1 * x2) / (1 - d * x1 * x2 * y1 * y2)` + +- The `TeAddChip` constrains that these field expressions are computed correctly over the field `C::Fp`. The coefficients `a` and `d` are taken from the `CurveConfig`. diff --git a/extensions/ecc/circuit/src/edwards_chip/add.rs b/extensions/ecc/circuit/src/edwards_chip/add.rs new file mode 100644 index 0000000000..87b84cffb9 --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/add.rs @@ -0,0 +1,114 @@ +use std::{cell::RefCell, rc::Rc}; + +use num_bigint::BigUint; +use num_traits::One; +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, +}; +use openvm_ecc_transpiler::Rv32EdwardsOpcode; +use openvm_mod_circuit_builder::{ + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, +}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{utils::jacobi, EdwardsAir, EdwardsChip, EdwardsStep}; + +pub fn te_add_expr( + config: ExprBuilderConfig, // The coordinate field. + range_bus: VariableRangeCheckerBus, + a_biguint: BigUint, + d_biguint: BigUint, +) -> FieldExpr { + config.check_valid(); + let builder = ExprBuilder::new(config, range_bus.range_max_bits); + let builder = Rc::new(RefCell::new(builder)); + + let x1 = ExprBuilder::new_input(builder.clone()); + let y1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let y2 = ExprBuilder::new_input(builder.clone()); + let a = ExprBuilder::new_const(builder.clone(), a_biguint.clone()); + let d = ExprBuilder::new_const(builder.clone(), d_biguint.clone()); + let one = ExprBuilder::new_const(builder.clone(), BigUint::one()); + + let x1y2 = x1.clone() * y2.clone(); + let x2y1 = x2.clone() * y1.clone(); + let y1y2 = y1 * y2; + let x1x2 = x1 * x2; + let dx1x2y1y2 = d * x1x2.clone() * y1y2.clone(); + + let mut x3 = (x1y2 + x2y1) / (one.clone() + dx1x2y1y2.clone()); + let mut y3 = (y1y2 - a * x1x2) / (one - dx1x2y1y2); + + x3.save_output(); + y3.save_output(); + + let builder = builder.borrow().clone(); + + FieldExpr::new_with_setup_values(builder, range_bus, true, vec![a_biguint, d_biguint]) +} + +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] +pub struct TeAddChip( + pub EdwardsChip, +); + +#[allow(clippy::too_many_arguments)] +impl + TeAddChip +{ + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + range_checker: SharedVariableRangeCheckerChip, + a: BigUint, + d: BigUint, + ) -> Self { + // Ensure that the addition operation is complete + assert!(jacobi(&a.clone().into(), &config.modulus.clone().into()) == 1); + assert!(jacobi(&d.clone().into(), &config.modulus.clone().into()) == -1); + + let expr = te_add_expr(config, range_checker.bus(), a, d); + + let local_opcode_idx = vec![ + Rv32EdwardsOpcode::TE_ADD as usize, + Rv32EdwardsOpcode::SETUP_TE_ADD as usize, + ]; + + let air = EdwardsAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ); + + let step = EdwardsStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + vec![], + range_checker, + "TeEcAdd", + true, + ); + + Self(EdwardsChip::new(air, step, mem_helper)) + } +} diff --git a/extensions/ecc/circuit/src/edwards_chip/mod.rs b/extensions/ecc/circuit/src/edwards_chip/mod.rs new file mode 100644 index 0000000000..66ab7c8639 --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/mod.rs @@ -0,0 +1,33 @@ +mod add; +pub use add::*; + +mod utils; + +#[cfg(test)] +mod tests; + +use openvm_algebra_circuit::FieldExprVecHeapStep; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; +use openvm_mod_circuit_builder::FieldExpressionCoreAir; +use openvm_rv32_adapters::Rv32VecHeapAdapterAir; + +pub(crate) type EdwardsAir = + VmAirWrapper< + Rv32VecHeapAdapterAir, + FieldExpressionCoreAir, + >; + +pub(crate) type EdwardsStep = + FieldExprVecHeapStep; + +pub(crate) type EdwardsChip< + F, + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +> = NewVmChipWrapper< + F, + EdwardsAir, + EdwardsStep, + MatrixRecordArena, +>; diff --git a/extensions/ecc/circuit/src/edwards_chip/tests.rs b/extensions/ecc/circuit/src/edwards_chip/tests.rs new file mode 100644 index 0000000000..cefa42d275 --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/tests.rs @@ -0,0 +1,186 @@ +use std::str::FromStr; + +use num_bigint::BigUint; +use num_traits::FromPrimitive; +use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; +use openvm_circuit_primitives::{ + bigint::utils::big_uint_to_limbs, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, +}; +use openvm_ecc_transpiler::Rv32EdwardsOpcode; +use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_mod_circuit_builder::{test_utils::biguint_to_limbs, ExprBuilderConfig, FieldExpr}; +use openvm_rv32_adapters::rv32_write_heap_default; +use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_sdk::p3_baby_bear::BabyBear; + +use super::TeAddChip; + +const NUM_LIMBS: usize = 32; +const LIMB_BITS: usize = 8; +const BLOCK_SIZE: usize = 32; +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; + +lazy_static::lazy_static! { + pub static ref SampleEcPoints: Vec<(BigUint, BigUint)> = { + // Base point of edwards25519 + let x1 = BigUint::from_str( + "15112221349535400772501151409588531511454012693041857206046113283949847762202", + ) + .unwrap(); + let y1 = BigUint::from_str( + "46316835694926478169428394003475163141307993866256225615783033603165251855960", + ) + .unwrap(); + + // random point on edwards25519 + let x2 = BigUint::from_u32(2).unwrap(); + let y2 = BigUint::from_str( + "11879831548380997166425477238087913000047176376829905612296558668626594440753", + ) + .unwrap(); + + // This is the sum of (x1, y1) and (x2, y2). + let x3 = BigUint::from_str( + "44969869612046584870714054830543834361257841801051546235130567688769346152934", + ) + .unwrap(); + let y3 = BigUint::from_str( + "50796027728050908782231253190819121962159170739537197094456293084373503699602", + ) + .unwrap(); + + // This is 2 * (x1, y1) + let x4 = BigUint::from_str( + "39226743113244985161159605482495583316761443760287217110659799046557361995496", + ) + .unwrap(); + let y4 = BigUint::from_str( + "12570354238812836652656274015246690354874018829607973815551555426027032771563", + ) + .unwrap(); + + vec![(x1, y1), (x2, y2), (x3, y3), (x4, y4)] + }; + + pub static ref Edwards25519_Prime: BigUint = BigUint::from_str( + "57896044618658097711785492504343953926634992332820282019728792003956564819949", + ) + .unwrap(); + + pub static ref Edwards25519_A: BigUint = BigUint::from_str( + "57896044618658097711785492504343953926634992332820282019728792003956564819948", + ) + .unwrap(); + + pub static ref Edwards25519_D: BigUint = BigUint::from_str( + "37095705934669439343138083508754565189542113879843219016388785533085940283555", + ) + .unwrap(); + + pub static ref Edwards25519_A_LIMBS: [BabyBear; NUM_LIMBS] = + big_uint_to_limbs(&Edwards25519_A, LIMB_BITS) + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect::>() + .try_into() + .unwrap(); + pub static ref Edwards25519_D_LIMBS: [BabyBear; NUM_LIMBS] = + big_uint_to_limbs(&Edwards25519_D, LIMB_BITS) + .into_iter() + .map(BabyBear::from_canonical_usize) + .collect::>() + .try_into() + .unwrap(); +} + +fn prime_limbs(expr: &FieldExpr) -> Vec { + expr.prime_limbs + .iter() + .map(|n| BabyBear::from_canonical_usize(*n)) + .collect::>() +} + +#[test] +fn test_add() { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: Edwards25519_Prime.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = TeAddChip::::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config, + Rv32EdwardsOpcode::CLASS_OFFSET, + bitwise_chip.clone(), + tester.range_checker(), + Edwards25519_A.clone(), + Edwards25519_D.clone(), + ); + chip.0.set_trace_buffer_height(MAX_INS_CAPACITY); + + assert_eq!(chip.0.step.0.expr.builder.num_variables, 12); + + let (p1_x, p1_y) = SampleEcPoints[0].clone(); + let (p2_x, p2_y) = SampleEcPoints[1].clone(); + + let p1_x_limbs = + biguint_to_limbs::(p1_x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let p1_y_limbs = + biguint_to_limbs::(p1_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let p2_x_limbs = + biguint_to_limbs::(p2_x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let p2_y_limbs = + biguint_to_limbs::(p2_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + + let r = chip + .0 + .step + .0 + .expr + .execute(vec![p1_x, p1_y, p2_x, p2_y], vec![true]); + assert_eq!(r.len(), 12); + + let outputs = chip + .0 + .step + .0 + .output_indices() + .iter() + .map(|i| &r[*i]) + .collect::>(); + assert_eq!(outputs[0], &SampleEcPoints[2].0); + assert_eq!(outputs[1], &SampleEcPoints[2].1); + + let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.step.0.expr).try_into().unwrap(); + let mut one_limbs = [BabyBear::ONE; NUM_LIMBS]; + one_limbs[0] = BabyBear::ONE; + let setup_instruction = rv32_write_heap_default( + &mut tester, + vec![prime_limbs, *Edwards25519_A_LIMBS], + vec![*Edwards25519_D_LIMBS], + chip.0.step.0.offset + Rv32EdwardsOpcode::SETUP_TE_ADD as usize, + ); + tester.execute(&mut chip, &setup_instruction); + + let instruction = rv32_write_heap_default( + &mut tester, + vec![p1_x_limbs, p1_y_limbs], + vec![p2_x_limbs, p2_y_limbs], + chip.0.step.0.offset + Rv32EdwardsOpcode::TE_ADD as usize, + ); + + tester.execute(&mut chip, &instruction); + + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + + tester.simple_test().expect("Verification failed"); +} diff --git a/extensions/ecc/circuit/src/edwards_chip/utils.rs b/extensions/ecc/circuit/src/edwards_chip/utils.rs new file mode 100644 index 0000000000..ce7711519f --- /dev/null +++ b/extensions/ecc/circuit/src/edwards_chip/utils.rs @@ -0,0 +1,101 @@ +use num_bigint::BigInt; +use num_integer::Integer; +use num_traits::{sign::Signed, One, Zero}; + +/// Jacobi returns the Jacobi symbol (x/y), either +1, -1, or 0. +/// The y argument must be an odd integer. +pub fn jacobi(x: &BigInt, y: &BigInt) -> isize { + if !y.is_odd() { + panic!( + "invalid arguments, y must be an odd integer,but got {:?}", + y + ); + } + + let mut a = x.clone(); + let mut b = y.clone(); + let mut j = 1; + + if b.is_negative() { + if a.is_negative() { + j = -1; + } + b = -b; + } + + loop { + if b.is_one() { + return j; + } + if a.is_zero() { + return 0; + } + + a = a.mod_floor(&b); + if a.is_zero() { + return 0; + } + + // a > 0 + + // handle factors of 2 in a + let s = a.trailing_zeros().unwrap(); + if s & 1 != 0 { + //let bmod8 = b.get_limb(0) & 7; + let bmod8 = mod_2_to_the_k(&b, 3); + if bmod8 == BigInt::from(3) || bmod8 == BigInt::from(5) { + j = -j; + } + } + + let c = &a >> s; // a = 2^s*c + + // swap numerator and denominator + if mod_2_to_the_k(&b, 2) == BigInt::from(3) && mod_2_to_the_k(&c, 2) == BigInt::from(3) { + j = -j + } + + a = b; + b = c; + } +} + +fn mod_2_to_the_k(x: &BigInt, k: u32) -> BigInt { + x & BigInt::from(2u32.pow(k) - 1) +} +#[cfg(test)] +mod tests { + use num_traits::FromPrimitive; + + use super::*; + + #[test] + fn test_jacobi() { + let cases = [ + [0, 1, 1], + [0, -1, 1], + [1, 1, 1], + [1, -1, 1], + [0, 5, 0], + [1, 5, 1], + [2, 5, -1], + [-2, 5, -1], + [2, -5, -1], + [-2, -5, 1], + [3, 5, -1], + [5, 5, 0], + [-5, 5, 0], + [6, 5, 1], + [6, -5, 1], + [-6, 5, 1], + [-6, -5, -1], + ]; + + for case in cases.iter() { + let x = BigInt::from_i64(case[0]).unwrap(); + let y = BigInt::from_i64(case[1]).unwrap(); + + assert_eq!(case[2] as isize, jacobi(&x, &y), "jacobi({}, {})", x, y); + } + } +} diff --git a/extensions/ecc/circuit/src/lib.rs b/extensions/ecc/circuit/src/lib.rs index c1ec864636..050bf57a3b 100644 --- a/extensions/ecc/circuit/src/lib.rs +++ b/extensions/ecc/circuit/src/lib.rs @@ -1,8 +1,11 @@ mod weierstrass_chip; pub use weierstrass_chip::*; -mod weierstrass_extension; -pub use weierstrass_extension::*; +mod ecc_extension; +pub use ecc_extension::*; + +mod edwards_chip; +pub use edwards_chip::*; mod config; pub use config::*; diff --git a/extensions/ecc/circuit/src/weierstrass_chip/README.md b/extensions/ecc/circuit/src/weierstrass_chip/README.md index 94d8df6847..ba7119b0fc 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/README.md +++ b/extensions/ecc/circuit/src/weierstrass_chip/README.md @@ -1,8 +1,8 @@ # Short Weierstrass (SW) Curve Operations -The `ec_add_ne` and `ec_double` instructions are implemented in the `weierstrass_chip` module. +The `sw_add_ne` and `sw_double` instructions are implemented in the `weierstrass_chip` module. -### 1. `ec_add_ne` +### 1. `sw_add_ne` **Assumptions:** @@ -16,9 +16,9 @@ The `ec_add_ne` and `ec_double` instructions are implemented in the `weierstrass - `x3 = lambda^2 - x1 - x2` - `y3 = lambda * (x1 - x3) - y1` -- The `EcAddNeChip` constrains that these field expressions are computed correctly over the field `C::Fp`. +- The `SwAddNeChip` constrains that these field expressions are computed correctly over the field `C::Fp`. -### 2. `ec_double` +### 2. `sw_double` **Assumptions:** @@ -31,4 +31,4 @@ The `ec_add_ne` and `ec_double` instructions are implemented in the `weierstrass - `x3 = lambda^2 - 2 * x1` - `y3 = lambda * (x1 - x3) - y1` -- The `EcDoubleChip` constrains that these expressions are computed correctly over the field `C::Fp`. The coefficient `a` is taken from the `CurveConfig`. +- The `SwDoubleChip` constrains that these expressions are computed correctly over the field `C::Fp`. The coefficient `a` is taken from the `CurveConfig`. diff --git a/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs b/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs index f45766c6a4..5a85510603 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs @@ -22,7 +22,7 @@ use super::{WeierstrassAir, WeierstrassChip, WeierstrassStep}; // Assumes that (x1, y1), (x2, y2) both lie on the curve and are not the identity point. // Further assumes that x1, x2 are not equal in the coordinate field. -pub fn ec_add_ne_expr( +pub fn sw_add_ne_expr( config: ExprBuilderConfig, // The coordinate field. range_bus: VariableRangeCheckerBus, ) -> FieldExpr { @@ -50,12 +50,12 @@ pub fn ec_add_ne_expr( /// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. #[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] -pub struct EcAddNeChip( +pub struct SwAddNeChip( pub WeierstrassChip, ); impl - EcAddNeChip + SwAddNeChip { #[allow(clippy::too_many_arguments)] pub fn new( @@ -68,11 +68,11 @@ impl bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker: SharedVariableRangeCheckerChip, ) -> Self { - let expr = ec_add_ne_expr(config, range_checker.bus()); + let expr = sw_add_ne_expr(config, range_checker.bus()); let local_opcode_idx = vec![ - Rv32WeierstrassOpcode::EC_ADD_NE as usize, - Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, + Rv32WeierstrassOpcode::SW_ADD_NE as usize, + Rv32WeierstrassOpcode::SETUP_SW_ADD_NE as usize, ]; let air = WeierstrassAir::new( diff --git a/extensions/ecc/circuit/src/weierstrass_chip/double.rs b/extensions/ecc/circuit/src/weierstrass_chip/double.rs index b804ba8931..eeef7fc2ea 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/double.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/double.rs @@ -22,7 +22,7 @@ use openvm_stark_backend::p3_field::PrimeField32; use super::{WeierstrassAir, WeierstrassChip, WeierstrassStep}; -pub fn ec_double_ne_expr( +pub fn sw_double_ne_expr( config: ExprBuilderConfig, // The coordinate field. range_bus: VariableRangeCheckerBus, a_biguint: BigUint, @@ -58,12 +58,12 @@ pub fn ec_double_ne_expr( /// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. #[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] -pub struct EcDoubleChip( +pub struct SwDoubleChip( pub WeierstrassChip, ); impl - EcDoubleChip + SwDoubleChip { #[allow(clippy::too_many_arguments)] pub fn new( @@ -77,11 +77,11 @@ impl range_checker: SharedVariableRangeCheckerChip, a_biguint: BigUint, ) -> Self { - let expr = ec_double_ne_expr(config, range_checker.bus(), a_biguint); + let expr = sw_double_ne_expr(config, range_checker.bus(), a_biguint); let local_opcode_idx = vec![ - Rv32WeierstrassOpcode::EC_DOUBLE as usize, - Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + Rv32WeierstrassOpcode::SW_DOUBLE as usize, + Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize, ]; let air = WeierstrassAir::new( diff --git a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs index 99051550c8..e16abe9106 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs @@ -14,7 +14,7 @@ use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use super::{EcAddNeChip, EcDoubleChip}; +use super::{SwAddNeChip, SwDoubleChip}; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; @@ -89,7 +89,7 @@ fn test_add_ne() { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = EcAddNeChip::::new( + let mut chip = SwAddNeChip::::new( tester.execution_bridge(), tester.memory_bridge(), tester.memory_helper(), @@ -132,7 +132,7 @@ fn test_add_ne() { &mut tester, vec![prime_limbs, one_limbs], // inputs[0] = prime, others doesn't matter vec![one_limbs, one_limbs], - chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_SW_ADD_NE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -140,7 +140,7 @@ fn test_add_ne() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![p2_x_limbs, p2_y_limbs], - chip.0.step.0.offset + Rv32WeierstrassOpcode::EC_ADD_NE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SW_ADD_NE as usize, ); tester.execute(&mut chip, &instruction); @@ -160,7 +160,7 @@ fn test_double() { }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = EcDoubleChip::::new( + let mut chip = SwDoubleChip::::new( tester.execution_bridge(), tester.memory_bridge(), tester.memory_helper(), @@ -193,7 +193,7 @@ fn test_double() { vec![prime_limbs, a_limbs], /* inputs[0] = prime, inputs[1] = a coeff of weierstrass * equation */ vec![], - chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -201,7 +201,7 @@ fn test_double() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![], - chip.0.step.0.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SW_DOUBLE as usize, ); tester.execute(&mut chip, &instruction); @@ -226,7 +226,7 @@ fn test_p256_double() { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = EcDoubleChip::::new( + let mut chip = SwDoubleChip::::new( tester.execution_bridge(), tester.memory_bridge(), tester.memory_helper(), @@ -280,7 +280,7 @@ fn test_p256_double() { vec![prime_limbs, a_limbs], /* inputs[0] = prime, inputs[1] = a coeff of weierstrass * equation */ vec![], - chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_SW_DOUBLE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -288,7 +288,7 @@ fn test_p256_double() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![], - chip.0.step.0.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SW_DOUBLE as usize, ); tester.execute(&mut chip, &instruction); diff --git a/extensions/ecc/circuit/src/weierstrass_extension.rs b/extensions/ecc/circuit/src/weierstrass_extension.rs deleted file mode 100644 index a0a41fbe18..0000000000 --- a/extensions/ecc/circuit/src/weierstrass_extension.rs +++ /dev/null @@ -1,254 +0,0 @@ -use derive_more::derive::From; -use hex_literal::hex; -use lazy_static::lazy_static; -use num_bigint::BigUint; -use num_traits::{FromPrimitive, Zero}; -use once_cell::sync::Lazy; -use openvm_circuit::{ - arch::{ - ExecutionBridge, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, - }, - system::phantom::PhantomChip, -}; -use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_ecc_transpiler::Rv32WeierstrassOpcode; -use openvm_instructions::{LocalOpcode, VmOpcode}; -use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_stark_backend::p3_field::PrimeField32; -use serde::{Deserialize, Serialize}; -use serde_with::{serde_as, DisplayFromStr}; -use strum::EnumCount; - -use super::{EcAddNeChip, EcDoubleChip}; - -// TODO: this should be decided after e2 execution - -#[serde_as] -#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] -pub struct CurveConfig { - /// The name of the curve struct as defined by moduli_declare. - pub struct_name: String, - /// The coordinate modulus of the curve. - #[serde_as(as = "DisplayFromStr")] - pub modulus: BigUint, - /// The scalar field modulus of the curve. - #[serde_as(as = "DisplayFromStr")] - pub scalar: BigUint, - /// The coefficient a of y^2 = x^3 + ax + b. - #[serde_as(as = "DisplayFromStr")] - pub a: BigUint, - /// The coefficient b of y^2 = x^3 + ax + b. - #[serde_as(as = "DisplayFromStr")] - pub b: BigUint, -} - -pub static SECP256K1_CONFIG: Lazy = Lazy::new(|| CurveConfig { - struct_name: SECP256K1_ECC_STRUCT_NAME.to_string(), - modulus: SECP256K1_MODULUS.clone(), - scalar: SECP256K1_ORDER.clone(), - a: BigUint::zero(), - b: BigUint::from_u8(7u8).unwrap(), -}); - -pub static P256_CONFIG: Lazy = Lazy::new(|| CurveConfig { - struct_name: P256_ECC_STRUCT_NAME.to_string(), - modulus: P256_MODULUS.clone(), - scalar: P256_ORDER.clone(), - a: BigUint::from_bytes_le(&P256_A), - b: BigUint::from_bytes_le(&P256_B), -}); - -#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] -pub struct WeierstrassExtension { - pub supported_curves: Vec, -} - -impl WeierstrassExtension { - pub fn generate_sw_init(&self) -> String { - let supported_curves = self - .supported_curves - .iter() - .map(|curve_config| curve_config.struct_name.to_string()) - .collect::>() - .join(", "); - - format!("openvm_ecc_guest::sw_macros::sw_init! {{ {supported_curves} }}") - } -} - -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, InsExecutorE1, InsExecutorE2)] -pub enum WeierstrassExtensionExecutor { - // 32 limbs prime - EcAddNeRv32_32(EcAddNeChip), - EcDoubleRv32_32(EcDoubleChip), - // 48 limbs prime - EcAddNeRv32_48(EcAddNeChip), - EcDoubleRv32_48(EcDoubleChip), -} - -#[derive(ChipUsageGetter, Chip, AnyEnum, From)] -pub enum WeierstrassExtensionPeriphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - Phantom(PhantomChip), -} - -impl VmExtension for WeierstrassExtension { - type Executor = WeierstrassExtensionExecutor; - type Periphery = WeierstrassExtensionPeriphery; - - fn build( - &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); - let SystemPort { - execution_bus, - program_bus, - memory_bridge, - } = builder.system_port(); - - let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); - let range_checker = builder.system_base().range_checker_chip.clone(); - let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; - - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip - }; - - let ec_add_ne_opcodes = (Rv32WeierstrassOpcode::EC_ADD_NE as usize) - ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize); - let ec_double_opcodes = (Rv32WeierstrassOpcode::EC_DOUBLE as usize) - ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize); - - for (i, curve) in self.supported_curves.iter().enumerate() { - let start_offset = - Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT; - let bytes = curve.modulus.bits().div_ceil(8); - let config32 = ExprBuilderConfig { - modulus: curve.modulus.clone(), - num_limbs: 32, - limb_bits: 8, - }; - let config48 = ExprBuilderConfig { - modulus: curve.modulus.clone(), - num_limbs: 48, - limb_bits: 8, - }; - if bytes <= 32 { - let add_ne_chip = EcAddNeChip::new( - execution_bridge, - memory_bridge, - builder.system_base().memory_controller.helper(), - pointer_max_bits, - config32.clone(), - start_offset, - bitwise_lu_chip.clone(), - range_checker.clone(), - ); - - inventory.add_executor( - WeierstrassExtensionExecutor::EcAddNeRv32_32(add_ne_chip), - ec_add_ne_opcodes - .clone() - .map(|x| VmOpcode::from_usize(x + start_offset)), - )?; - let double_chip = EcDoubleChip::new( - execution_bridge, - memory_bridge, - builder.system_base().memory_controller.helper(), - pointer_max_bits, - config32.clone(), - start_offset, - bitwise_lu_chip.clone(), - range_checker.clone(), - curve.a.clone(), - ); - inventory.add_executor( - WeierstrassExtensionExecutor::EcDoubleRv32_32(double_chip), - ec_double_opcodes - .clone() - .map(|x| VmOpcode::from_usize(x + start_offset)), - )?; - } else if bytes <= 48 { - let add_ne_chip = EcAddNeChip::new( - execution_bridge, - memory_bridge, - builder.system_base().memory_controller.helper(), - pointer_max_bits, - config48.clone(), - start_offset, - bitwise_lu_chip.clone(), - range_checker.clone(), - ); - - inventory.add_executor( - WeierstrassExtensionExecutor::EcAddNeRv32_48(add_ne_chip), - ec_add_ne_opcodes - .clone() - .map(|x| VmOpcode::from_usize(x + start_offset)), - )?; - let double_chip = EcDoubleChip::new( - execution_bridge, - memory_bridge, - builder.system_base().memory_controller.helper(), - pointer_max_bits, - config48.clone(), - start_offset, - bitwise_lu_chip.clone(), - range_checker.clone(), - curve.a.clone(), - ); - inventory.add_executor( - WeierstrassExtensionExecutor::EcDoubleRv32_48(double_chip), - ec_double_opcodes - .clone() - .map(|x| VmOpcode::from_usize(x + start_offset)), - )?; - } else { - panic!("Modulus too large"); - } - } - - Ok(inventory) - } -} - -// Convenience constants for constructors -lazy_static! { - // The constants are taken from: https://en.bitcoin.it/wiki/Secp256k1 - pub static ref SECP256K1_MODULUS: BigUint = BigUint::from_bytes_be(&hex!( - "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F" - )); - pub static ref SECP256K1_ORDER: BigUint = BigUint::from_bytes_be(&hex!( - "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141" - )); -} - -lazy_static! { - // The constants are taken from: https://neuromancer.sk/std/secg/secp256r1 - pub static ref P256_MODULUS: BigUint = BigUint::from_bytes_be(&hex!( - "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff" - )); - pub static ref P256_ORDER: BigUint = BigUint::from_bytes_be(&hex!( - "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551" - )); -} -// little-endian -const P256_A: [u8; 32] = hex!("fcffffffffffffffffffffff00000000000000000000000001000000ffffffff"); -// little-endian -const P256_B: [u8; 32] = hex!("4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e7933aaad835c65a"); - -pub const SECP256K1_ECC_STRUCT_NAME: &str = "Secp256k1Point"; -pub const P256_ECC_STRUCT_NAME: &str = "P256Point"; diff --git a/extensions/ecc/guest/Cargo.toml b/extensions/ecc/guest/Cargo.toml index e5251eb366..aa3bd000e6 100644 --- a/extensions/ecc/guest/Cargo.toml +++ b/extensions/ecc/guest/Cargo.toml @@ -16,15 +16,24 @@ elliptic-curve = { workspace = true, features = ["arithmetic", "sec1"] } openvm-custom-insn = { workspace = true } openvm-rv32im-guest = { workspace = true } openvm-algebra-guest = { workspace = true } +openvm-algebra-moduli-macros = { workspace = true } openvm-ecc-sw-macros = { workspace = true } +openvm-ecc-te-macros = { workspace = true } once_cell = { workspace = true, features = ["race", "alloc"] } +num-bigint = { workspace = true } +hex-literal = { workspace = true } +openvm-sha2 = { workspace = true } # Used for `halo2curves` feature halo2curves-axiom = { workspace = true, optional = true } group = "0.13.0" +[target.'cfg(not(target_os = "zkvm"))'.dependencies] +lazy_static = { workspace = true } + [features] default = [] +ed25519 = [] halo2curves = ["dep:halo2curves-axiom", "openvm-algebra-guest/halo2curves"] std = ["alloc"] alloc = [] diff --git a/extensions/ecc/guest/src/ecdsa.rs b/extensions/ecc/guest/src/ecdsa.rs index 07fc6d44fc..7c60e575d3 100644 --- a/extensions/ecc/guest/src/ecdsa.rs +++ b/extensions/ecc/guest/src/ecdsa.rs @@ -20,10 +20,7 @@ use elliptic_curve::{ }; use openvm_algebra_guest::{DivUnsafe, IntMod, Reduce}; -use crate::{ - weierstrass::{FromCompressed, IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, Group, -}; +use crate::{weierstrass::WeierstrassPoint, CyclicGroup, FromCompressed, Group, IntrinsicCurve}; type Coordinate = <::Point as WeierstrassPoint>::Coordinate; type Scalar = ::Scalar; diff --git a/extensions/ecc/guest/src/ed25519.rs b/extensions/ecc/guest/src/ed25519.rs new file mode 100644 index 0000000000..32f48cba58 --- /dev/null +++ b/extensions/ecc/guest/src/ed25519.rs @@ -0,0 +1,85 @@ +use core::ops::Add; + +use hex_literal::hex; +#[cfg(not(target_os = "zkvm"))] +use lazy_static::lazy_static; +#[cfg(not(target_os = "zkvm"))] +use num_bigint::BigUint; +use openvm_algebra_guest::IntMod; + +use super::group::{CyclicGroup, Group}; +use crate::IntrinsicCurve; + +#[cfg(not(target_os = "zkvm"))] +lazy_static! { + pub static ref ED25519_MODULUS: BigUint = BigUint::from_bytes_be(&hex!( + "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED" + )); + pub static ref ED25519_ORDER: BigUint = BigUint::from_bytes_be(&hex!( + "1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED" + )); + pub static ref ED25519_A: BigUint = BigUint::from_bytes_be(&hex!( + "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEC" + )); + pub static ref ED25519_D: BigUint = BigUint::from_bytes_be(&hex!( + "52036CEE2B6FFE738CC740797779E89800700A4D4141D8AB75EB4DCA135978A3" + )); +} + +openvm_algebra_moduli_macros::moduli_declare! { + Ed25519Coord { modulus = "0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED" }, + Ed25519Scalar { modulus = "0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED" }, +} + +pub const ED25519_NUM_LIMBS: usize = 32; +pub const ED25519_LIMB_BITS: usize = 8; +pub const ED25519_BLOCK_SIZE: usize = 32; +// from_const_bytes is little endian +pub const CURVE_A: Ed25519Coord = Ed25519Coord::from_const_bytes(hex!( + "ECFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF7F" +)); +pub const CURVE_D: Ed25519Coord = Ed25519Coord::from_const_bytes(hex!( + "A3785913CA4DEB75ABD841414D0A700098E879777940C78C73FE6F2BEE6C0352" +)); + +openvm_ecc_te_macros::te_declare! { + Ed25519Point { mod_type = Ed25519Coord, a = CURVE_A, d = CURVE_D }, +} + +impl CyclicGroup for Ed25519Point { + // from_const_bytes is little endian + const GENERATOR: Self = Ed25519Point { + x: Ed25519Coord::from_const_bytes(hex!( + "1AD5258F602D56C9B2A7259560C72C695CDCD6FD31E2A4C0FE536ECDD3366921" + )), + y: Ed25519Coord::from_const_bytes(hex!( + "5866666666666666666666666666666666666666666666666666666666666666" + )), + }; + const NEG_GENERATOR: Self = Ed25519Point { + x: Ed25519Coord::from_const_bytes([ + 211, 42, 218, 112, 159, 210, 169, 54, 77, 88, 218, 106, 159, 56, 211, 150, 163, 35, 41, + 2, 206, 29, 91, 63, 1, 172, 145, 50, 44, 201, 150, 94, + ]), + y: Ed25519Coord::from_const_bytes(hex!( + "5866666666666666666666666666666666666666666666666666666666666666" + )), + }; +} + +impl IntrinsicCurve for Ed25519Point { + type Scalar = Ed25519Scalar; + type Point = Ed25519Point; + + fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point + where + for<'a> &'a Self::Point: Add<&'a Self::Point, Output = Self::Point>, + { + if coeffs.len() < 25 { + let table = crate::edwards::CachedMulTable::::new(bases, 4); + table.windowed_mul(coeffs) + } else { + crate::msm(coeffs, bases) + } + } +} diff --git a/extensions/ecc/guest/src/eddsa.rs b/extensions/ecc/guest/src/eddsa.rs new file mode 100644 index 0000000000..17929d81d1 --- /dev/null +++ b/extensions/ecc/guest/src/eddsa.rs @@ -0,0 +1,204 @@ +// Implementation of the EdDSA signature verification algorithm. +// The code is generic over the twisted Edwards curve, but currently only instantiated with Ed25519. +// The implementation is based on the RFC: https://datatracker.ietf.org/doc/html/rfc8032 +// We support both the prehash variant (Ed25519ph) and the non-prehash variant (Ed25519). +// Note: our implementation is not intended to be safe against timing attacks. + +extern crate alloc; +use alloc::vec::Vec; + +use openvm_sha2::sha512; + +use crate::{ + algebra::{IntMod, Reduce}, + edwards::TwistedEdwardsPoint, + CyclicGroup, FromCompressed, IntrinsicCurve, +}; + +type Coordinate = <::Point as TwistedEdwardsPoint>::Coordinate; +type Scalar = ::Scalar; +type Point = ::Point; + +#[repr(C)] +#[derive(Clone)] +pub struct VerifyingKey { + /// Affine point + point: Point, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VerificationError { + InvalidSignature, + InvalidContext, + FailedToVerify, +} + +impl VerifyingKey +where + Point: TwistedEdwardsPoint + FromCompressed> + CyclicGroup, + Coordinate: IntMod, + C::Scalar: IntMod + Reduce, +{ + /// Assumes the point is encoded as in https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.2 + pub fn from_bytes(bytes: &[u8]) -> Option { + if bytes.len() != Coordinate::::NUM_LIMBS { + return None; + } + Some(Self { + point: decode_point::(bytes)?, + }) + } + + pub fn verify(&self, message: &[u8], sig: &[u8]) -> Result<(), VerificationError> { + self.verify_prehashed(message, sig, &[]) + } + + /// The verify function for the prehash variant of Ed25519. + /// message should be the message to be verified, before the prehash is applied. + /// context is the optional context bytes that are shared between a signer and verifier, as per + /// the Ed25519ph specification. If no context is provided, the empty slice will be used. + /// The context can be up to 255 bytes. + pub fn verify_ph( + &self, + message: &[u8], + context: Option<&[u8]>, + sig: &[u8], + ) -> Result<(), VerificationError> { + let prehash = sha512(message); + + // dom2(F, C) domain separator + // RFC reference: https://datatracker.ietf.org/doc/html/rfc8032#section-2 + // See definition of dom2 in the RFC. Note that the RFC refers to the prehash + // version of Ed25519 as Ed25519ph, and the non-prehash version as Ed25519. + let mut dom2 = Vec::new(); + dom2.extend_from_slice(b"SigEd25519 no Ed25519 collisions"); + dom2.push(1); // phflag = 1 + + // The RFC specifies optional "context" bytes that are shared between a signer and verifier. + // See: https://datatracker.ietf.org/doc/html/rfc8032#section-5.1 + if let Some(context) = context { + if context.len() > 255 { + return Err(VerificationError::InvalidContext); + } + dom2.push(context.len() as u8); + dom2.extend_from_slice(context); + } else { + dom2.push(0); // context len = 0 + } + + self.verify_prehashed(&prehash, sig, &dom2) + } + + // Shared verify function for both the prehash and non-prehash variants of Ed25519. + // prehash is either SHA512(message) or message, for Ed25519ph and Ed25519 respectively. + // dom2 is the domain separator for the Ed25519ph and Ed25519 variants. It should be empty for + // Ed25519. + // See RFC reference: https://datatracker.ietf.org/doc/html/rfc8032#section-2 + fn verify_prehashed( + &self, + prehash: &[u8], + sig: &[u8], + dom2: &[u8], + ) -> Result<(), VerificationError> { + let Some(sig) = Signature::::from_bytes(sig) else { + return Err(VerificationError::InvalidSignature); + }; + + // h = SHA512(dom2(F, C) || R || A || PH(M)) + // RFC reference: https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.7 + let mut sha_input = Vec::new(); + + sha_input.extend_from_slice(dom2); + + sha_input.extend_from_slice(&encode_point::(&sig.r)); + sha_input.extend_from_slice(&encode_point::(&self.point)); + sha_input.extend_from_slice(prehash); + + let h = sha512(&sha_input); + + let h = C::Scalar::reduce_le_bytes(&h); + + // assert s * B = R + h * A + // <=> R + h * A - s * B = 0 + // <=> [1, h, s] * [R, A, -B] = 0 + let res = C::msm( + &[C::Scalar::ONE, h, sig.s], + &[ + sig.r, + self.point.clone(), + as CyclicGroup>::NEG_GENERATOR, + ], + ); + if res == as TwistedEdwardsPoint>::IDENTITY { + Ok(()) + } else { + Err(VerificationError::FailedToVerify) + } + } +} + +// Internal struct used for decoding the signature from bytes +struct Signature { + r: C::Point, + s: C::Scalar, +} + +impl Signature +where + C::Point: TwistedEdwardsPoint + FromCompressed>, + Coordinate: IntMod, + C::Scalar: IntMod, +{ + pub fn from_bytes(bytes: &[u8]) -> Option { + if bytes.len() != Coordinate::::NUM_LIMBS + Scalar::::NUM_LIMBS { + return None; + } + // from_le_bytes checks that s is reduced + let s = Scalar::::from_le_bytes(&bytes[Coordinate::::NUM_LIMBS..])?; + Some(Self { + r: decode_point::(&bytes[..Coordinate::::NUM_LIMBS])?, + s, + }) + } +} + +/// RFC reference: https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.3 +/// We require that the most significant bit in the little-endian encoding of +/// elements of the coordinate field is always 0, because we pack the parity +/// of the x-coordinate there. +fn decode_point(bytes: &[u8]) -> Option> +where + Point: TwistedEdwardsPoint + FromCompressed>, + Coordinate: IntMod, +{ + if bytes.len() != Coordinate::::NUM_LIMBS { + return None; + } + let mut y_bytes = bytes.to_vec(); + // most significant bit stores the parity of the x-coordinate + let rec_id = (y_bytes[Coordinate::::NUM_LIMBS - 1] & 0b10000000) >> 7; + y_bytes[Coordinate::::NUM_LIMBS - 1] &= 0b01111111; + // from_le_bytes checks that y is reduced + let y = Coordinate::::from_le_bytes(&y_bytes)?; + Point::::decompress(y, &rec_id) +} + +/// RFC reference: https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.2 +/// We require that the most significant bit in the little-endian encoding of +/// elements of the coordinate field is always 0, because we pack the parity +/// of the x-coordinate there. +fn encode_point(p: &Point) -> Vec +where + Point: TwistedEdwardsPoint, + Coordinate: IntMod, +{ + let mut y_bytes = p.y().as_le_bytes().to_vec(); + if p.x().as_le_bytes()[0] & 1u8 == 1 { + // We pack the parity of the x-coordinate in the most significant bit of the last byte, as + // per the Ed25519 spec, so the Coordinate type must have enough limbs so that the most + // significant bit of the last byte is always 0. + debug_assert!(y_bytes[Coordinate::::NUM_LIMBS - 1] & 0b10000000 == 0); + y_bytes[Coordinate::::NUM_LIMBS - 1] |= 0b10000000; + } + y_bytes +} diff --git a/extensions/ecc/guest/src/edwards.rs b/extensions/ecc/guest/src/edwards.rs new file mode 100644 index 0000000000..e9089e49d8 --- /dev/null +++ b/extensions/ecc/guest/src/edwards.rs @@ -0,0 +1,396 @@ +use alloc::vec::Vec; +use core::ops::{AddAssign, Mul}; + +use openvm_algebra_guest::{Field, IntMod}; + +use crate::{Group, IntrinsicCurve}; + +pub trait TwistedEdwardsPoint: Sized { + /// The `a` coefficient in the twisted Edwards curve equation `ax^2 + y^2 = 1 + d x^2 y^2`. + const CURVE_A: Self::Coordinate; + /// The `d` coefficient in the twisted Edwards curve equation `ax^2 + y^2 = 1 + d x^2 y^2`. + const CURVE_D: Self::Coordinate; + const IDENTITY: Self; + + type Coordinate: Field; + + /// The concatenated `x, y` coordinates of the affine point, where + /// coordinates are in little endian. + /// + /// **Warning**: The memory layout of `Self` is expected to pack + /// `x` and `y` contigously with no unallocated space in between. + fn as_le_bytes(&self) -> &[u8]; + + /// Raw constructor without asserting point is on the curve. + fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self; + fn into_coords(self) -> (Self::Coordinate, Self::Coordinate); + fn x(&self) -> &Self::Coordinate; + fn y(&self) -> &Self::Coordinate; + fn x_mut(&mut self) -> &mut Self::Coordinate; + fn y_mut(&mut self) -> &mut Self::Coordinate; + + fn add_impl(&self, p2: &Self) -> Self; + + #[inline(always)] + fn from_xy(x: Self::Coordinate, y: Self::Coordinate) -> Option + where + for<'a> &'a Self::Coordinate: Mul<&'a Self::Coordinate, Output = Self::Coordinate>, + { + let lhs = Self::CURVE_A * &x * &x + &y * &y; + let rhs = Self::CURVE_D * &x * &x * &y * &y + &Self::Coordinate::ONE; + if lhs != rhs { + return None; + } + Some(Self::from_xy_unchecked(x, y)) + } +} + +/// Macro to generate a newtype wrapper for [AffinePoint](crate::AffinePoint) +/// that implements elliptic curve operations by using the underlying field operations according to +/// the [formulas](https://en.wikipedia.org/wiki/Twisted_Edwards_curve) for twisted Edwards curves. +/// +/// The following imports are required: +/// ```rust +/// use core::ops::AddAssign; +/// +/// use openvm_algebra_guest::{DivUnsafe, Field}; +/// use openvm_ecc_guest::{edwards::TwistedEdwardsPoint, AffinePoint, Group}; +/// ``` +#[macro_export] +macro_rules! impl_te_affine { + ($struct_name:ident, $field:ty, $a:expr, $d:expr) => { + /// A newtype wrapper for [AffinePoint] that implements elliptic curve operations + /// by using the underlying field operations according to the [formulas](https://en.wikipedia.org/wiki/Twisted_Edwards_curve) for twisted Edwards curves. + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)] + #[repr(transparent)] + pub struct $struct_name(AffinePoint<$field>); + + impl TwistedEdwardsPoint for $struct_name { + const CURVE_A: $field = $a; + const CURVE_D: $field = $d; + const IDENTITY: Self = Self(AffinePoint::new(<$field>::ZERO, <$field>::ONE)); + + type Coordinate = $field; + + /// SAFETY: assumes that [$field] has internal representation in little-endian. + fn as_le_bytes(&self) -> &[u8] { + unsafe { + &*core::ptr::slice_from_raw_parts( + self as *const Self as *const u8, + core::mem::size_of::(), + ) + } + } + fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self { + Self(AffinePoint::new(x, y)) + } + fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) { + (self.0.x, self.0.y) + } + fn x(&self) -> &Self::Coordinate { + &self.0.x + } + fn y(&self) -> &Self::Coordinate { + &self.0.y + } + fn x_mut(&mut self) -> &mut Self::Coordinate { + &mut self.0.x + } + fn y_mut(&mut self) -> &mut Self::Coordinate { + &mut self.0.y + } + + fn add_impl(&self, p2: &Self) -> Self { + use ::openvm_algebra_guest::DivUnsafe; + // For twisted Edwards curves: + // x3 = (x1*y2 + y1*x2)/(1 + d*x1*x2*y1*y2) + // y3 = (y1*y2 - a*x1*x2)/(1 - d*x1*x2*y1*y2) + let x1y2 = self.x() * p2.y(); + let y1x2 = self.y() * p2.x(); + let x1x2 = self.x() * p2.x(); + let y1y2 = self.y() * p2.y(); + let dx1x2y1y2 = Self::CURVE_D * x1x2 * y1y2; + + let x3 = (x1y2 + y1x2).div_unsafe(&(Self::Coordinate::ONE + dx1x2y1y2)); + let y3 = (y1y2 - Self::CURVE_A * x1x2).div_unsafe(&(Self::Coordinate::ONE - dx1x2y1y2)); + + Self(AffinePoint::new(x3, y3)) + } + + impl core::ops::Neg for $struct_name { + type Output = Self; + + fn neg(mut self) -> Self::Output { + self.0.x.neg_assign(); + self + } + } + + impl core::ops::Neg for &$struct_name { + type Output = $struct_name; + + fn neg(self) -> Self::Output { + self.clone().neg() + } + } + + impl From<$struct_name> for AffinePoint<$field> { + fn from(value: $struct_name) -> Self { + value.0 + } + } + + impl From> for $struct_name { + fn from(value: AffinePoint<$field>) -> Self { + Self(value) + } + } + } + } +} + +/// Implements `Group` on `$struct_name` assuming that `$struct_name` implements +/// `TwistedEdwardsPoint`. Assumes that `Neg` is implemented for `&$struct_name`. +#[macro_export] +macro_rules! impl_te_group_ops { + ($struct_name:ident, $field:ty) => { + impl Group for $struct_name { + type SelfRef<'a> = &'a Self; + + const IDENTITY: Self = ::IDENTITY; + + fn double(&self) -> Self { + if self.is_identity() { + self.clone() + } else { + self.add_impl(self) + } + } + + fn double_assign(&mut self) { + if !self.is_identity() { + *self = self.add_impl(self) + } + } + + // Note: It was found that implementing `is_identity` in group.rs as a default + // implementation increases the cycle count by 50% on the ecrecover benchmark. For + // this reason, we implement it here instead. We hypothesize that this is due to + // compiler optimizations that are not possible when the `is_identity` function is + // defined in a different source file. + #[inline(always)] + fn is_identity(&self) -> bool { + self == &::IDENTITY + } + } + + impl core::ops::Add<&$struct_name> for $struct_name { + type Output = Self; + + fn add(mut self, p2: &$struct_name) -> Self::Output { + use core::ops::AddAssign; + self.add_assign(p2); + self + } + } + + impl core::ops::Add for $struct_name { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + self.add(&rhs) + } + } + + impl core::ops::Add<&$struct_name> for &$struct_name { + type Output = $struct_name; + + fn add(self, p2: &$struct_name) -> Self::Output { + if self.is_identity() { + p2.clone() + } else if p2.is_identity() { + self.clone() + } else if self.x() + p2.x() == <$field as openvm_algebra_guest::Field>::ZERO + && self.y() == p2.y() + { + <$struct_name as TwistedEdwardsPoint>::IDENTITY + } else { + self.add_impl(p2) + } + } + } + + impl core::ops::AddAssign<&$struct_name> for $struct_name { + fn add_assign(&mut self, p2: &$struct_name) { + if self.is_identity() { + *self = p2.clone(); + } else if p2.is_identity() { + // do nothing + } else if self.x() + p2.x() == <$field as openvm_algebra_guest::Field>::ZERO + && self.y() == p2.y() + { + *self = <$struct_name as TwistedEdwardsPoint>::IDENTITY; + } else { + *self = self.add_impl(p2); + } + } + } + + impl core::ops::AddAssign for $struct_name { + fn add_assign(&mut self, rhs: Self) { + self.add_assign(&rhs); + } + } + + impl core::ops::Sub<&$struct_name> for $struct_name { + type Output = Self; + + fn sub(self, rhs: &$struct_name) -> Self::Output { + core::ops::Sub::sub(&self, rhs) + } + } + + impl core::ops::Sub for $struct_name { + type Output = $struct_name; + + fn sub(self, rhs: Self) -> Self::Output { + self.sub(&rhs) + } + } + + impl core::ops::Sub<&$struct_name> for &$struct_name { + type Output = $struct_name; + + fn sub(self, p2: &$struct_name) -> Self::Output { + use core::ops::Add; + self.add(&-p2) + } + } + + impl core::ops::SubAssign<&$struct_name> for $struct_name { + fn sub_assign(&mut self, p2: &$struct_name) { + use core::ops::AddAssign; + self.add_assign(-p2); + } + } + + impl core::ops::SubAssign for $struct_name { + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign(&rhs); + } + } + }; +} + +// This is the same as the Weierstrass version, but for Edwards curves we use +// TwistedEdwardsPoint::add_impl instead of WeierstrassPoint::add_ne_nonidentity, etc. +// Unlike the Weierstrass version, we do not require the bases to have prime order, since our +// addition formulas are complete. + +// MSM using preprocessed table (windowed method) +// Reference: modified from https://github.com/arkworks-rs/algebra/blob/master/ec/src/scalar_mul/mod.rs + +/// Cached precomputations of scalar multiples of several base points. +/// - `window_bits` is the window size used for the precomputation +/// - `max_scalar_bits` is the maximum size of the scalars that will be multiplied +/// - `table` is the precomputed table +pub struct CachedMulTable<'a, C: IntrinsicCurve> { + /// Window bits. Must be > 0. + /// For alignment, we currently require this to divide 8 (bits in a byte). + pub window_bits: usize, + pub bases: &'a [C::Point], + /// `table[i][j] = (j + 2) * bases[i]` for `j + 2 < 2 ** window_bits` + table: Vec>, + /// Needed to return reference to the identity point. + identity: C::Point, +} + +impl<'a, C: IntrinsicCurve> CachedMulTable<'a, C> +where + C::Point: TwistedEdwardsPoint + Group, + C::Scalar: IntMod, +{ + pub fn new(bases: &'a [C::Point], window_bits: usize) -> Self { + assert!(window_bits > 0); + let window_size = 1 << window_bits; + let table = bases + .iter() + .map(|base| { + if base.is_identity() { + vec![::IDENTITY; window_size - 2] + } else { + let mut multiples = Vec::with_capacity(window_size - 2); + for _ in 0..window_size - 2 { + let multiple = multiples + .last() + .map(|last| TwistedEdwardsPoint::add_impl(last, base)) + .unwrap_or_else(|| base.double()); + multiples.push(multiple); + } + multiples + } + }) + .collect(); + + Self { + window_bits, + bases, + table, + identity: ::IDENTITY, + } + } + + fn get_multiple(&self, base_idx: usize, scalar: usize) -> &C::Point { + if scalar == 0 { + &self.identity + } else if scalar == 1 { + unsafe { self.bases.get_unchecked(base_idx) } + } else { + unsafe { self.table.get_unchecked(base_idx).get_unchecked(scalar - 2) } + } + } + + /// Computes `sum scalars[i] * bases[i]`. + /// + /// For implementation simplicity, currently only implemented when + /// `window_bits` divides 8 (number of bits in a byte). + pub fn windowed_mul(&self, scalars: &[C::Scalar]) -> C::Point { + assert_eq!(8 % self.window_bits, 0); + assert_eq!(scalars.len(), self.bases.len()); + let windows_per_byte = 8 / self.window_bits; + + let num_windows = C::Scalar::NUM_LIMBS * windows_per_byte; + let mask = (1u8 << self.window_bits) - 1; + + // The current byte index (little endian) at the current step of the + // windowed method, across all scalars. + let mut limb_idx = C::Scalar::NUM_LIMBS; + // The current bit (little endian) within the current byte of the windowed + // method. The window will look at bits `bit_idx..bit_idx + window_bits`. + // bit_idx will always be in range [0, 8) + let mut bit_idx = 0; + + let mut res = ::IDENTITY; + for outer in 0..num_windows { + if bit_idx == 0 { + limb_idx -= 1; + bit_idx = 8 - self.window_bits; + } else { + bit_idx -= self.window_bits; + } + + if outer != 0 { + for _ in 0..self.window_bits { + res.double_assign(); + } + } + for (base_idx, scalar) in scalars.iter().enumerate() { + let scalar = (scalar.as_le_bytes()[limb_idx] >> bit_idx) & mask; + let summand = self.get_multiple(base_idx, scalar as usize); + // handles identity + res.add_assign(summand); + } + } + res + } +} diff --git a/extensions/ecc/guest/src/lib.rs b/extensions/ecc/guest/src/lib.rs index c7a9851cfd..a98431b001 100644 --- a/extensions/ecc/guest/src/lib.rs +++ b/extensions/ecc/guest/src/lib.rs @@ -6,6 +6,7 @@ extern crate alloc; pub use once_cell; pub use openvm_algebra_guest as algebra; pub use openvm_ecc_sw_macros as sw_macros; +pub use openvm_ecc_te_macros as te_macros; use strum_macros::FromRepr; mod affine_point; @@ -17,11 +18,18 @@ pub use msm::*; /// Optimized ECDSA implementation with the same functional interface as the `ecdsa` crate pub mod ecdsa; +/// Optimized EDDSA implementation +pub mod eddsa; +/// Edwards curve traits +pub mod edwards; /// Weierstrass curve traits pub mod weierstrass; +#[cfg(feature = "ed25519")] +pub mod ed25519; + /// This is custom-1 defined in RISC-V spec document -pub const OPCODE: u8 = 0x2b; +pub const SW_OPCODE: u8 = 0x2b; pub const SW_FUNCT3: u8 = 0b001; /// Short Weierstrass curves are configurable. @@ -37,3 +45,46 @@ pub enum SwBaseFunct7 { impl SwBaseFunct7 { pub const SHORT_WEIERSTRASS_MAX_KINDS: u8 = 8; } + +/// This is custom-1 defined in RISC-V spec document +pub const TE_OPCODE: u8 = 0x2b; +pub const TE_FUNCT3: u8 = 0b100; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, FromRepr)] +#[repr(u8)] +pub enum TeBaseFunct7 { + TeAdd = 0, + TeSetup, + TeHintDecompress, + TeHintNonQr, +} + +impl TeBaseFunct7 { + pub const TWISTED_EDWARDS_MAX_KINDS: u8 = 8; +} + +/// A trait for elliptic curves that bridges the openvm types and external types with +/// CurveArithmetic etc. Implement this for external curves with corresponding openvm point and +/// scalar types. +pub trait IntrinsicCurve { + type Scalar: Clone; + type Point: Clone; + + /// Multi-scalar multiplication. + /// The implementation may be specialized to use properties of the curve + /// (e.g., if the curve order is prime). + fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point; +} + +pub trait FromCompressed { + /// Given `x`-coordinate, + /// + /// Decompresses a point from its x-coordinate and a recovery identifier which indicates + /// the parity of the y-coordinate. Given the x-coordinate, this function attempts to find the + /// corresponding y-coordinate that satisfies the elliptic curve equation. If successful, it + /// returns the point as an instance of Self. If the point cannot be decompressed, it returns + /// None. + fn decompress(x: Coordinate, rec_id: &u8) -> Option + where + Self: core::marker::Sized; +} diff --git a/extensions/ecc/guest/src/weierstrass.rs b/extensions/ecc/guest/src/weierstrass.rs index 82d5468b04..3c39cbcfe6 100644 --- a/extensions/ecc/guest/src/weierstrass.rs +++ b/extensions/ecc/guest/src/weierstrass.rs @@ -4,6 +4,7 @@ use core::ops::Mul; use openvm_algebra_guest::{Field, IntMod}; use super::group::Group; +use crate::IntrinsicCurve; /// Short Weierstrass curve affine point. pub trait WeierstrassPoint: Clone + Sized { @@ -113,32 +114,6 @@ pub trait WeierstrassPoint: Clone + Sized { } } -pub trait FromCompressed { - /// Given `x`-coordinate, - /// - /// Decompresses a point from its x-coordinate and a recovery identifier which indicates - /// the parity of the y-coordinate. Given the x-coordinate, this function attempts to find the - /// corresponding y-coordinate that satisfies the elliptic curve equation. If successful, it - /// returns the point as an instance of Self. If the point cannot be decompressed, it returns - /// None. - fn decompress(x: Coordinate, rec_id: &u8) -> Option - where - Self: core::marker::Sized; -} - -/// A trait for elliptic curves that bridges the openvm types and external types with -/// CurveArithmetic etc. Implement this for external curves with corresponding openvm point and -/// scalar types. -pub trait IntrinsicCurve { - type Scalar: Clone; - type Point: Clone; - - /// Multi-scalar multiplication. - /// The implementation may be specialized to use properties of the curve - /// (e.g., if the curve order is prime). - fn msm(coeffs: &[Self::Scalar], bases: &[Self::Point]) -> Self::Point; -} - // MSM using preprocessed table (windowed method) // Reference: modified from https://github.com/arkworks-rs/algebra/blob/master/ec/src/scalar_mul/mod.rs // @@ -476,11 +451,11 @@ macro_rules! impl_sw_group_ops { self.double_assign_impl::(); } - // This implementation is the same as the default implementation in the `Group` trait, - // but it was found that overriding the default implementation reduced the cycle count - // by 50% on the ecrecover benchmark. - // We hypothesize that this is due to compiler optimizations that are not possible when - // the `is_identity` function is defined in a different source file. + // Note: It was found that implementing `is_identity` in group.rs as a default + // implementation increases the cycle count by 50% on the ecrecover benchmark. For + // this reason, we implement it here instead. We hypothesize that this is due to + // compiler optimizations that are not possible when the `is_identity` function is + // defined in a different source file. #[inline(always)] fn is_identity(&self) -> bool { self == &::IDENTITY diff --git a/extensions/ecc/sw-macros/README.md b/extensions/ecc/sw-macros/README.md index 71f8d553f4..fbb8989cd7 100644 --- a/extensions/ecc/sw-macros/README.md +++ b/extensions/ecc/sw-macros/README.md @@ -93,7 +93,7 @@ mod openvm_intrinsics_ffi_2 { 3. Again, if using the Rust bindings, then the `sw_setup_extern_func_*` function for every curve is automatically called on first use of any of the curve's intrinsics. -4. The order of the items in `sw_init!` **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of `CurveConfig`s in `WeierstrassExtension::supported_curves`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). +4. The order of the items in `sw_init!` **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of `CurveConfig`s in `EccExtension::supported_sw_curves`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). 5. Note that, due to the nature of function names, the name of the struct used in `sw_init!` must be the same as in `sw_declare!`. To illustrate, the following code will **fail** to compile: diff --git a/extensions/ecc/sw-macros/src/lib.rs b/extensions/ecc/sw-macros/src/lib.rs index 7af9e77daf..8163598558 100644 --- a/extensions/ecc/sw-macros/src/lib.rs +++ b/extensions/ecc/sw-macros/src/lib.rs @@ -214,7 +214,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { use openvm_algebra_guest::IntMod; // Safety: Self::set_up_once() ensures IntMod::set_up_once() has been called. unsafe { - self.x.eq_impl::(&#intmod_type::ZERO) && self.y.eq_impl::(&#intmod_type::ZERO) + self.x.eq_impl::(&<#intmod_type as IntMod>::ZERO) && self.y.eq_impl::(&<#intmod_type as IntMod>::ZERO) } } } @@ -373,7 +373,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream { } mod #group_ops_mod_name { - use ::openvm_ecc_guest::{weierstrass::{WeierstrassPoint, FromCompressed}, impl_sw_group_ops, algebra::IntMod}; + use ::openvm_ecc_guest::{weierstrass::{WeierstrassPoint}, FromCompressed, impl_sw_group_ops, algebra::IntMod}; use super::*; impl_sw_group_ops!(#struct_name, #intmod_type); @@ -457,7 +457,7 @@ pub fn sw_init(input: TokenStream) -> TokenStream { #[no_mangle] extern "C" fn #add_ne_extern_func(rd: usize, rs1: usize, rs2: usize) { openvm::platform::custom_insn_r!( - opcode = OPCODE, + opcode = SW_OPCODE, funct3 = SW_FUNCT3 as usize, funct7 = SwBaseFunct7::SwAddNe as usize + #ec_idx * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize), @@ -470,7 +470,7 @@ pub fn sw_init(input: TokenStream) -> TokenStream { #[no_mangle] extern "C" fn #double_extern_func(rd: usize, rs1: usize) { openvm::platform::custom_insn_r!( - opcode = OPCODE, + opcode = SW_OPCODE, funct3 = SW_FUNCT3 as usize, funct7 = SwBaseFunct7::SwDouble as usize + #ec_idx * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize), @@ -497,7 +497,7 @@ pub fn sw_init(input: TokenStream) -> TokenStream { let p2 = [one.as_ref(), one.as_ref()].concat(); let mut uninit: core::mem::MaybeUninit<[#item; 2]> = core::mem::MaybeUninit::uninit(); openvm::platform::custom_insn_r!( - opcode = ::openvm_ecc_guest::OPCODE, + opcode = ::openvm_ecc_guest::SW_OPCODE, funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize, funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize + #ec_idx @@ -507,7 +507,7 @@ pub fn sw_init(input: TokenStream) -> TokenStream { rs2 = In p2.as_ptr() ); openvm::platform::custom_insn_r!( - opcode = ::openvm_ecc_guest::OPCODE, + opcode = ::openvm_ecc_guest::SW_OPCODE, funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize, funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize + #ec_idx @@ -524,8 +524,8 @@ pub fn sw_init(input: TokenStream) -> TokenStream { TokenStream::from(quote::quote_spanned! { span.into() => #[allow(non_snake_case)] #[cfg(target_os = "zkvm")] - mod openvm_intrinsics_ffi_2 { - use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7}; + mod openvm_intrinsics_ffi_2_sw { + use ::openvm_ecc_guest::{SW_OPCODE, SW_FUNCT3, SwBaseFunct7}; #(#externs)* } diff --git a/extensions/ecc/te-macros/Cargo.toml b/extensions/ecc/te-macros/Cargo.toml new file mode 100644 index 0000000000..de3544ff87 --- /dev/null +++ b/extensions/ecc/te-macros/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "openvm-ecc-te-macros" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +syn = { version = "2.0", features = ["full"] } +quote = "1.0" +openvm-macros-common = { workspace = true, default-features = false } + +[lib] +proc-macro = true diff --git a/extensions/ecc/te-macros/README.md b/extensions/ecc/te-macros/README.md new file mode 100644 index 0000000000..6de5c50110 --- /dev/null +++ b/extensions/ecc/te-macros/README.md @@ -0,0 +1,125 @@ +# `openvm-ecc-te-macros` + +Procedural macros for use in guest program to generate short twisted Edwards elliptic curve struct with custom intrinsics for compile-time modulus. + +The workflow of this macro is very similar to the [`openvm-algebra-moduli-macros`](../moduli-macros/README.md) crate. We recommend reading it first. + +## Example + +```rust +// ... + +moduli_declare! { + Ed25519Coord { modulus = "0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED" }, + Ed25519Scalar { modulus = "0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED" }, +} + +// Note that from_const_bytes is little endian +pub const CURVE_A: Ed25519Coord = Ed25519Coord::from_const_bytes(hex!( + "ECFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF7F" +)); +pub const CURVE_D: Ed25519Coord = Ed25519Coord::from_const_bytes(hex!( + "A3785913CA4DEB75ABD841414D0A700098E879777940C78C73FE6F2BEE6C0352" +)); + +sw_declare! { + Ed25519Point { mod_type = Ed25519Coord, a = CURVE_A, d = CURVE_D }, +} + +openvm_algebra_guest::moduli_macros::moduli_init! { + "0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFED", + "0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED", +} + +openvm_ecc_guest::te_macros::te_init! { + Ed25519Point, +} + +pub fn main() { + setup_all_moduli(); + setup_all_te_curves(); + // ... +} +``` + +## Full story + +Again, the principle is the same as in the [`openvm-algebra-moduli-macros`](../moduli-macros/README.md) crate. Here we emphasize the core differences. + +The crate provides two macros: `te_declare!` and `te_init!`. The signatures are: + +- `te_declare!` receives comma-separated list of moduli classes descriptions. Each description looks like `TeStruct { mod_type = ModulusName, a = a_expr, d = d_expr }`. Here `ModulusName` is the name of any struct that implements `trait IntMod` -- in particular, the ones created by `moduli_declare!` do. Parameters `a` and `d` correspond to the coefficients of the equation defining the curve. They **must be compile-time constants**. Both the parameters `a` and `d` are required. + +- `te_init!` receives comma-separated list of struct names. The struct name must exactly match the name in `te_declare!` -- type defs are not allowed (see point 5 below). + +What happens under the hood: + +1. `te_declare!` macro creates a struct with two field `x` and `y` of type `mod_type`. This struct denotes a point on the corresponding elliptic curve. In the example it would be + +```rust +struct Ed25519Point { + x: Ed25519Coord, + y: Ed25519Coord, +} +``` + +Similar to `moduli_declare!`, this macro also creates extern functions for arithmetic operations -- but in this case they are named after the te type, not after any hexadecimal (since the macro has no way to obtain it from the name of the modulus type anyway): + +```rust +extern "C" { + fn te_add_extern_func_Ed25519Point(rd: usize, rs1: usize, rs2: usize); + fn hint_decompress_extern_func_Ed25519Point(rs1: usize, rs2: usize); +} +``` + +2. Again, `te_init!` macro implements these extern functions and defines the setup functions for the te struct. + +```rust +#[cfg(target_os = "zkvm")] +mod openvm_intrinsics_ffi_2 { + use :openvm_ecc_guest::{OPCODE, TE_FUNCT3, TeBaseFunct7}; + + #[no_mangle] + extern "C" fn te_add_extern_func_Ed25519Point(rd: usize, rs1: usize, rs2: usize) { + // ... + } + // other externs +} +#[allow(non_snake_case)] +pub fn setup_te_Ed25519Point() { + #[cfg(target_os = "zkvm")] + { + // ... + } +} +pub fn setup_all_te_curves() { + setup_te_Ed25519Point(); + // other setups +} +``` + +3. Again, if using the Rust bindings, then the `te_setup_extern_func_*` function for every curve is automatically called on first use of any of the curve's intrinsics. + +4. The order of the items in `te_init!` **must match** the order of the moduli in the chip configuration -- more specifically, in the modular extension parameters (the order of `CurveConfig`s in `EccExtension::supported_te_curves`, which is usually defined with the whole `app_vm_config` in the `openvm.toml` file). + +5. Note that, due to the nature of function names, the name of the struct used in `te_init!` must be the same as in `te_declare!`. To illustrate, the following code will **fail** to compile: + +```rust +// ... + +te_declare! { + Ed25519Point { mod_type = Ed25519Coord, a = CURVE_A, d = CURVE_D }, +} + +pub type Te = Ed25519Point; + +te_init! { + Te, +} +``` + +The reason is that, for example, the function `sw_add_extern_func_Secp256k1Point` remains unimplemented, but we implement `sw_add_extern_func_Sw`. + +6. `cargo openvm build` will automatically generate a call to `te_init!` based on `openvm.toml`. +Note that `openvm.toml` must contain the name of each struct created by `te_declare!` as a string (in the example at the top of this document, its `"Ed25519Point"`). +The SDK also supports this feature. diff --git a/extensions/ecc/te-macros/src/lib.rs b/extensions/ecc/te-macros/src/lib.rs new file mode 100644 index 0000000000..5b0b2fd106 --- /dev/null +++ b/extensions/ecc/te-macros/src/lib.rs @@ -0,0 +1,359 @@ +extern crate proc_macro; + +use openvm_macros_common::MacroArgs; +use proc_macro::TokenStream; +use quote::format_ident; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, Expr, ExprPath, Path, Token, +}; + +/// This macro generates the code to setup a Twisted Edwards elliptic curve for a given modular +/// type. Also it places the curve parameters into a special static variable to be later extracted +/// from the ELF and used by the VM. Usage: +/// ``` +/// te_declare! { +/// [TODO] +/// } +/// ``` +/// +/// For this macro to work, you must import the `elliptic_curve` crate and the `openvm_ecc_guest` +/// crate.. +#[proc_macro] +pub fn te_declare(input: TokenStream) -> TokenStream { + let MacroArgs { items } = parse_macro_input!(input as MacroArgs); + + let mut output = Vec::new(); + + let span = proc_macro::Span::call_site(); + + for item in items.into_iter() { + let struct_name = item.name.to_string(); + let struct_name = syn::Ident::new(&struct_name, span.into()); + let struct_path: syn::Path = syn::parse_quote!(#struct_name); + let mut intmod_type: Option = None; + let mut const_a: Option = None; + let mut const_d: Option = None; + for param in item.params { + match param.name.to_string().as_str() { + "mod_type" => { + if let syn::Expr::Path(ExprPath { path, .. }) = param.value { + intmod_type = Some(path) + } else { + return syn::Error::new_spanned(param.value, "Expected a type") + .to_compile_error() + .into(); + } + } + "a" => { + const_a = Some(param.value); + } + "d" => { + const_d = Some(param.value); + } + _ => { + panic!("Unknown parameter {}", param.name); + } + } + } + + let intmod_type = intmod_type.expect("mod_type parameter is required"); + let const_a = const_a.expect("constant a coefficient is required"); + let const_d = const_d.expect("constant d coefficient is required"); + + macro_rules! create_extern_func { + ($name:ident) => { + let $name = syn::Ident::new( + &format!( + "{}_{}", + stringify!($name), + struct_path + .segments + .iter() + .map(|x| x.ident.to_string()) + .collect::>() + .join("_") + ), + span.into(), + ); + }; + } + create_extern_func!(te_add_extern_func); + create_extern_func!(te_setup_extern_func); + + let group_ops_mod_name = format_ident!("{}_ops", struct_name.to_string().to_lowercase()); + + let result = TokenStream::from(quote::quote_spanned! { span.into() => + extern "C" { + fn #te_add_extern_func(rd: usize, rs1: usize, rs2: usize); + fn #te_setup_extern_func(); + } + + #[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)] + #[repr(C)] + pub struct #struct_name { + x: #intmod_type, + y: #intmod_type, + } + + impl #struct_name { + const fn identity() -> Self { + Self { + x: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO, + y: <#intmod_type as openvm_algebra_guest::IntMod>::ONE, + } + } + // Below are wrapper functions for the intrinsic instructions. + // Should not be called directly. + #[inline(always)] + fn add_chip(p1: &#struct_name, p2: &#struct_name) -> #struct_name { + #[cfg(not(target_os = "zkvm"))] + { + use openvm_algebra_guest::DivUnsafe; + + let x1y2 = p1.x.clone() * p2.y.clone(); + let y1x2 = p1.y.clone() * p2.x.clone(); + let x1x2 = p1.x.clone() * p2.x.clone(); + let y1y2 = p1.y.clone() * p2.y.clone(); + let dx1x2y1y2 = ::CURVE_D * &x1x2 * &y1y2; + + let x3 = (x1y2 + y1x2).div_unsafe(&<#intmod_type as openvm_algebra_guest::IntMod>::ONE + &dx1x2y1y2); + let y3 = (y1y2 - ::CURVE_A * x1x2).div_unsafe(&<#intmod_type as openvm_algebra_guest::IntMod>::ONE - &dx1x2y1y2); + + #struct_name { x: x3, y: y3 } + } + #[cfg(target_os = "zkvm")] + { + Self::set_up_once(); + let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit(); + unsafe { + #te_add_extern_func( + uninit.as_mut_ptr() as usize, + p1 as *const #struct_name as usize, + p2 as *const #struct_name as usize + ) + }; + unsafe { uninit.assume_init() } + } + } + + // Helper function to call the setup instruction on first use + #[cfg(target_os = "zkvm")] + #[inline(always)] + fn set_up_once() { + static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new(); + is_setup.get_or_init(|| { + unsafe { #te_setup_extern_func(); } + <#intmod_type as openvm_algebra_guest::IntMod>::set_up_once(); + true + }); + } + + #[cfg(not(target_os = "zkvm"))] + #[inline(always)] + fn set_up_once() { + // No-op for non-ZKVM targets + } + } + + impl ::openvm_ecc_guest::edwards::TwistedEdwardsPoint for #struct_name { + const CURVE_A: Self::Coordinate = #const_a; + const CURVE_D: Self::Coordinate = #const_d; + + const IDENTITY: Self = Self::identity(); + type Coordinate = #intmod_type; + + /// SAFETY: assumes that #intmod_type has a memory representation + /// such that with repr(C), two coordinates are packed contiguously. + #[inline(always)] + fn as_le_bytes(&self) -> &[u8] { + unsafe { &*core::ptr::slice_from_raw_parts(self as *const Self as *const u8, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS * 2) } + } + + #[inline(always)] + fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self { + Self { x, y } + } + + #[inline(always)] + fn x(&self) -> &Self::Coordinate { + &self.x + } + + #[inline(always)] + fn y(&self) -> &Self::Coordinate { + &self.y + } + + #[inline(always)] + fn x_mut(&mut self) -> &mut Self::Coordinate { + &mut self.x + } + + #[inline(always)] + fn y_mut(&mut self) -> &mut Self::Coordinate { + &mut self.y + } + + #[inline(always)] + fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) { + (self.x, self.y) + } + + #[inline(always)] + fn add_impl(&self, p2: &Self) -> Self { + Self::add_chip(self, p2) + } + } + + impl core::ops::Neg for #struct_name { + type Output = Self; + + fn neg(self) -> Self::Output { + #struct_name { + x: core::ops::Neg::neg(&self.x), + y: self.y, + } + } + } + + impl core::ops::Neg for &#struct_name { + type Output = #struct_name; + + fn neg(self) -> #struct_name { + #struct_name { + x: core::ops::Neg::neg(&self.x), + y: self.y.clone(), + } + } + } + + mod #group_ops_mod_name { + use ::openvm_ecc_guest::{edwards::TwistedEdwardsPoint, FromCompressed, impl_te_group_ops, algebra::{IntMod, DivUnsafe, DivAssignUnsafe, ExpBytes}}; + use super::*; + + impl_te_group_ops!(#struct_name, #intmod_type); + + impl FromCompressed<#intmod_type> for #struct_name { + fn decompress(y: #intmod_type, rec_id: &u8) -> Option { + use openvm_algebra_guest::{Sqrt, DivUnsafe}; + let x_squared = (<#intmod_type as openvm_algebra_guest::IntMod>::ONE - &y * &y).div_unsafe(<#struct_name as ::openvm_ecc_guest::edwards::TwistedEdwardsPoint>::CURVE_A - &<#struct_name as ::openvm_ecc_guest::edwards::TwistedEdwardsPoint>::CURVE_D * &y * &y); + let x = x_squared.sqrt(); + match x { + None => None, + Some(x) => { + let correct_x = if x.as_le_bytes()[0] & 1 == *rec_id & 1 { + x + } else { + -x + }; + // handle the case where x = 0 + if correct_x.as_le_bytes()[0] & 1 != *rec_id & 1 { + return None; + } + // In order for sqrt() to return Some, we are guaranteed that x * x == x_squared, which already proves (correct_x, y) is on the curve + Some(<#struct_name as ::openvm_ecc_guest::edwards::TwistedEdwardsPoint>::from_xy_unchecked(correct_x, y)) + } + } + } + } + } + }); + output.push(result); + } + + TokenStream::from_iter(output) +} + +struct TeDefine { + items: Vec, +} + +impl Parse for TeDefine { + fn parse(input: ParseStream) -> syn::Result { + let items = input.parse_terminated(::parse, Token![,])?; + Ok(Self { + items: items + .into_iter() + .map(|e| { + if let Expr::Path(p) = e { + p.path + } else { + panic!("expected path"); + } + }) + .collect(), + }) + } +} + +#[proc_macro] +pub fn te_init(input: TokenStream) -> TokenStream { + let TeDefine { items } = parse_macro_input!(input as TeDefine); + + let mut externs = Vec::new(); + + let span = proc_macro::Span::call_site(); + + for (ec_idx, item) in items.into_iter().enumerate() { + let str_path = item + .segments + .iter() + .map(|x| x.ident.to_string()) + .collect::>() + .join("_"); + let add_extern_func = + syn::Ident::new(&format!("te_add_extern_func_{}", str_path), span.into()); + let setup_extern_func = + syn::Ident::new(&format!("te_setup_extern_func_{}", str_path), span.into()); + externs.push(quote::quote_spanned! { span.into() => + #[no_mangle] + extern "C" fn #add_extern_func(rd: usize, rs1: usize, rs2: usize) { + openvm::platform::custom_insn_r!( + opcode = TE_OPCODE, + funct3 = TE_FUNCT3 as usize, + funct7 = TeBaseFunct7::TeAdd as usize + #ec_idx + * (TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS as usize), + rd = In rd, + rs1 = In rs1, + rs2 = In rs2 + ); + } + + #[no_mangle] + extern "C" fn #setup_extern_func() { + #[cfg(target_os = "zkvm")] + { + use super::#item; + let modulus_bytes = <<#item as openvm_ecc_guest::edwards::TwistedEdwardsPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS; + let mut zero = [0u8; <<#item as openvm_ecc_guest::edwards::TwistedEdwardsPoint>::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS]; + let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#item as openvm_ecc_guest::edwards::TwistedEdwardsPoint>::CURVE_A); + let curve_d_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#item as openvm_ecc_guest::edwards::TwistedEdwardsPoint>::CURVE_D); + let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat(); + let p2 = [curve_d_bytes.as_ref(), zero.as_ref()].concat(); + let mut uninit: core::mem::MaybeUninit<[#item; 2]> = core::mem::MaybeUninit::uninit(); + openvm::platform::custom_insn_r!( + opcode = ::openvm_ecc_guest::TE_OPCODE, + funct3 = ::openvm_ecc_guest::TE_FUNCT3 as usize, + funct7 = ::openvm_ecc_guest::TeBaseFunct7::TeSetup as usize + + #ec_idx + * (::openvm_ecc_guest::TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS as usize), + rd = In uninit.as_mut_ptr(), + rs1 = In p1.as_ptr(), + rs2 = In p2.as_ptr(), + ); + } + } + }); + } + + TokenStream::from(quote::quote_spanned! { span.into() => + #[allow(non_snake_case)] + #[cfg(target_os = "zkvm")] + mod openvm_intrinsics_ffi_2_te { + use ::openvm_ecc_guest::{TE_OPCODE, TE_FUNCT3, TeBaseFunct7}; + + #(#externs)* + } + }) +} diff --git a/extensions/ecc/tests/Cargo.toml b/extensions/ecc/tests/Cargo.toml index 5f90e77fa4..587f3636d4 100644 --- a/extensions/ecc/tests/Cargo.toml +++ b/extensions/ecc/tests/Cargo.toml @@ -12,6 +12,7 @@ openvm-stark-sdk.workspace = true openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-transpiler.workspace = true openvm-algebra-transpiler.workspace = true +openvm-sha2-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true openvm-rv32im-transpiler.workspace = true @@ -21,8 +22,8 @@ serde.workspace = true serde_with.workspace = true toml.workspace = true eyre.workspace = true -hex-literal.workspace = true num-bigint.workspace = true +hex-literal.workspace = true halo2curves-axiom = { workspace = true } [features] diff --git a/extensions/ecc/tests/programs/Cargo.toml b/extensions/ecc/tests/programs/Cargo.toml index 55fcedd3ee..68ed6601b4 100644 --- a/extensions/ecc/tests/programs/Cargo.toml +++ b/extensions/ecc/tests/programs/Cargo.toml @@ -9,8 +9,9 @@ openvm = { path = "../../../../crates/toolchain/openvm" } openvm-platform = { path = "../../../../crates/toolchain/platform" } openvm-custom-insn = { path = "../../../../crates/toolchain/custom_insn", default-features = false } -openvm-ecc-guest = { path = "../../guest", default-features = false } +openvm-ecc-guest = { path = "../../guest", default-features = false, features = ["ed25519"] } openvm-ecc-sw-macros = { path = "../../../../extensions/ecc/sw-macros", default-features = false } +openvm-ecc-te-macros = { path = "../../../../extensions/ecc/te-macros", default-features = false } openvm-algebra-guest = { path = "../../../algebra/guest", default-features = false } openvm-algebra-moduli-macros = { path = "../../../algebra/moduli-macros", default-features = false } openvm-rv32im-guest = { path = "../../../../extensions/rv32im/guest", default-features = false } @@ -43,6 +44,7 @@ default = [] std = ["serde/std", "openvm/std"] k256 = ["dep:openvm-k256"] p256 = ["dep:openvm-p256"] +ed25519 = ["openvm-ecc-guest/ed25519"] [profile.release] panic = "abort" @@ -63,7 +65,7 @@ required-features = ["k256", "p256"] [[example]] name = "decompress" -required-features = ["k256"] +required-features = ["k256", "ed25519"] [[example]] name = "ecdsa" @@ -81,6 +83,10 @@ required-features = ["k256"] name = "sec1_decode" required-features = ["k256"] +[[example]] +name = "edwards_ec" +required-features = ["ed25519"] + [[example]] name = "invalid_setup" required-features = ["k256", "p256"] diff --git a/extensions/ecc/tests/programs/examples/decompress.rs b/extensions/ecc/tests/programs/examples/decompress.rs index 0148d5d057..f6e9870a3e 100644 --- a/extensions/ecc/tests/programs/examples/decompress.rs +++ b/extensions/ecc/tests/programs/examples/decompress.rs @@ -7,9 +7,11 @@ extern crate alloc; use hex_literal::hex; use openvm::io::read_vec; use openvm_ecc_guest::{ - algebra::IntMod, - weierstrass::{FromCompressed, WeierstrassPoint}, - Group, + algebra::{Field, IntMod}, + ed25519::{Ed25519Coord, Ed25519Point}, + edwards::TwistedEdwardsPoint, + weierstrass::WeierstrassPoint, + FromCompressed, Group, }; use openvm_k256::{Secp256k1Coord, Secp256k1Point}; @@ -22,7 +24,6 @@ openvm_algebra_moduli_macros::moduli_declare! { Fp1mod4 { modulus = "0xffffffffffffffffffffffffffffffff000000000000000000000001" }, } -// const CURVE_B_5MOD8: Fp5mod8 = Fp5mod8::from_const_u8(3); const CURVE_B_5MOD8: Fp5mod8 = Fp5mod8::from_const_u8(6); const CURVE_A_1MOD4: Fp1mod4 = Fp1mod4::from_const_bytes(hex!( @@ -44,7 +45,7 @@ openvm_ecc_sw_macros::sw_declare! { }, } -openvm::init!("openvm_init_decompress_k256.rs"); +openvm::init!("openvm_init_decompress_k256_ed25519.rs"); // test decompression under an honest host pub fn main() { @@ -53,35 +54,43 @@ pub fn main() { let y = Secp256k1Coord::from_le_bytes_unchecked(&bytes[32..64]); let rec_id = y.as_le_bytes()[0] & 1; - test_possible_decompression::(&x, &y, rec_id); + test_possible_sw_decompression::(&x, &y, rec_id); // x = 5 is not on the x-coordinate of any point on the Secp256k1 curve - test_impossible_decompression::(&Secp256k1Coord::from_u8(5), rec_id); + test_impossible_sw_decompression::(&Secp256k1Coord::from_u8(5), rec_id); let x = Fp5mod8::from_le_bytes_unchecked(&bytes[64..96]); let y = Fp5mod8::from_le_bytes_unchecked(&bytes[96..128]); let rec_id = y.as_le_bytes()[0] & 1; - test_possible_decompression::(&x, &y, rec_id); + test_possible_sw_decompression::(&x, &y, rec_id); // x = 0 is not on the x-coordinate of any point on the CurvePoint5mod8 curve - test_impossible_decompression::(&Fp5mod8::ZERO, rec_id); + test_impossible_sw_decompression::(&::ZERO, rec_id); // this x is such that y^2 = x^3 + 6 = 0 // we want to test the case where y^2 = 0 and rec_id = 1 let x = Fp5mod8::from_le_bytes_unchecked(&hex!( "d634a701c3b9b8cbf7797988be3953b442863b74d2d5c4d5f1a9de3c0c256d90" )); - test_possible_decompression::(&x, &Fp5mod8::ZERO, 0); - test_impossible_decompression::(&x, 1); + test_possible_sw_decompression::(&x, &::ZERO, 0); + test_impossible_sw_decompression::(&x, 1); let x = Fp1mod4::from_le_bytes_unchecked(&bytes[128..160]); let y = Fp1mod4::from_le_bytes_unchecked(&bytes[160..192]); let rec_id = y.as_le_bytes()[0] & 1; - test_possible_decompression::(&x, &y, rec_id); + test_possible_sw_decompression::(&x, &y, rec_id); // x = 1 is not on the x-coordinate of any point on the CurvePoint1mod4 curve - test_impossible_decompression::(&Fp1mod4::from_u8(1), rec_id); + test_impossible_sw_decompression::(&Fp1mod4::from_u8(1), rec_id); + + // ed25519 + let x = Ed25519Coord::from_le_bytes_unchecked(&bytes[192..224]); + let y = Ed25519Coord::from_le_bytes_unchecked(&bytes[224..256]); + let rec_id = x.as_le_bytes()[0] & 1; + test_possible_te_decompression::(&x, &y, rec_id); + // y = 2 is not on the y-coordinate of any point on the Ed25519 curve + test_impossible_te_decompression::(&Ed25519Coord::from_u8(2), rec_id); } -fn test_possible_decompression>( +fn test_possible_sw_decompression>( x: &P::Coordinate, y: &P::Coordinate, rec_id: u8, @@ -91,7 +100,25 @@ fn test_possible_decompression>( +fn test_possible_te_decompression>( + x: &P::Coordinate, + y: &P::Coordinate, + rec_id: u8, +) { + let p = P::decompress(y.clone(), &rec_id).unwrap(); + assert_eq!(p.x(), x); + assert_eq!(p.y(), y); +} + +fn test_impossible_sw_decompression>( + x: &P::Coordinate, + rec_id: u8, +) { + let p = P::decompress(x.clone(), &rec_id); + assert!(p.is_none()); +} + +fn test_impossible_te_decompression>( x: &P::Coordinate, rec_id: u8, ) { diff --git a/extensions/ecc/tests/programs/examples/ec.rs b/extensions/ecc/tests/programs/examples/ec.rs index 1b63057c30..71c1194463 100644 --- a/extensions/ecc/tests/programs/examples/ec.rs +++ b/extensions/ecc/tests/programs/examples/ec.rs @@ -2,8 +2,7 @@ #![cfg_attr(not(feature = "std"), no_std)] use hex_literal::hex; -use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{msm, weierstrass::WeierstrassPoint, Group}; +use openvm_ecc_guest::{algebra::IntMod, msm, weierstrass::WeierstrassPoint, Group}; use openvm_k256::{Secp256k1Coord, Secp256k1Point, Secp256k1Scalar}; openvm::init!("openvm_init_ec_k256.rs"); diff --git a/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs b/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs index 41db1ececc..854641c4bf 100644 --- a/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs +++ b/extensions/ecc/tests/programs/examples/ec_nonzero_a.rs @@ -2,8 +2,7 @@ #![cfg_attr(not(feature = "std"), no_std)] use hex_literal::hex; -use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{weierstrass::WeierstrassPoint, CyclicGroup, Group}; +use openvm_ecc_guest::{algebra::IntMod, weierstrass::WeierstrassPoint, CyclicGroup, Group}; use openvm_p256::{P256Coord, P256Point}; openvm::entry!(main); diff --git a/extensions/ecc/tests/programs/examples/ec_two_curves.rs b/extensions/ecc/tests/programs/examples/ec_two_curves.rs index 6412e3184f..681f1c9fe4 100644 --- a/extensions/ecc/tests/programs/examples/ec_two_curves.rs +++ b/extensions/ecc/tests/programs/examples/ec_two_curves.rs @@ -2,8 +2,7 @@ #![cfg_attr(not(feature = "std"), no_std)] use hex_literal::hex; -use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{msm, weierstrass::WeierstrassPoint, Group}; +use openvm_ecc_guest::{algebra::IntMod, msm, weierstrass::WeierstrassPoint, Group}; use openvm_k256::{Secp256k1Coord, Secp256k1Point, Secp256k1Scalar}; use openvm_p256::{P256Coord, P256Point}; diff --git a/extensions/ecc/tests/programs/examples/ed25519.rs b/extensions/ecc/tests/programs/examples/ed25519.rs new file mode 100644 index 0000000000..4266fd3f35 --- /dev/null +++ b/extensions/ecc/tests/programs/examples/ed25519.rs @@ -0,0 +1,146 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +use hex_literal::hex; +use openvm_ecc_guest::{ed25519::Ed25519Point, eddsa::VerifyingKey}; + +openvm::entry!(main); + +openvm::init!("openvm_init_ed25519_ed25519.rs"); + +pub struct Ed25519TestData { + pub msg: &'static [u8], + pub signature: [u8; 64], + pub vk: [u8; 32], +} + +// Test data for the non-prehash variant of Ed25519. +// The first five tests were taken from https://datatracker.ietf.org/doc/html/rfc8032#section-7.1 +// The rest were randomly generated. +const ED25519_TEST_DATA: [Ed25519TestData; 10] = [ + Ed25519TestData { + msg: b"", + signature: hex!("e5564300c360ac729086e2cc806e828a84877f1eb8e5d974d873e065224901555fb8821590a33bacc61e39701cf9b46bd25bf5f0595bbe24655141438e7a100b"), + vk: hex!("d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a"), + }, + Ed25519TestData { + msg: &hex!("72"), + signature: hex!("92a009a9f0d4cab8720e820b5f642540a2b27b5416503f8fb3762223ebdb69da085ac1e43e15996e458f3613d0f11d8c387b2eaeb4302aeeb00d291612bb0c00"), + vk: hex!("3d4017c3e843895a92b70aa74d1b7ebc9c982ccf2ec4968cc0cd55f12af4660c"), + }, + Ed25519TestData { + msg: &hex!("af82"), + signature: hex!("6291d657deec24024827e69c3abe01a30ce548a284743a445e3680d7db5ac3ac18ff9b538d16f290ae67f760984dc6594a7c15e9716ed28dc027beceea1ec40a"), + vk: hex!("fc51cd8e6218a1a38da47ed00230f0580816ed13ba3303ac5deb911548908025"), + }, + Ed25519TestData { + msg: &hex!("08b8b2b733424243760fe426a4b54908632110a66c2f6591eabd3345e3e4eb98fa6e264bf09efe12ee50f8f54e9f77b1e355f6c50544e23fb1433ddf73be84d879de7c0046dc4996d9e773f4bc9efe5738829adb26c81b37c93a1b270b20329d658675fc6ea534e0810a4432826bf58c941efb65d57a338bbd2e26640f89ffbc1a858efcb8550ee3a5e1998bd177e93a7363c344fe6b199ee5d02e82d522c4feba15452f80288a821a579116ec6dad2b3b310da903401aa62100ab5d1a36553e06203b33890cc9b832f79ef80560ccb9a39ce767967ed628c6ad573cb116dbefefd75499da96bd68a8a97b928a8bbc103b6621fcde2beca1231d206be6cd9ec7aff6f6c94fcd7204ed3455c68c83f4a41da4af2b74ef5c53f1d8ac70bdcb7ed185ce81bd84359d44254d95629e9855a94a7c1958d1f8ada5d0532ed8a5aa3fb2d17ba70eb6248e594e1a2297acbbb39d502f1a8c6eb6f1ce22b3de1a1f40cc24554119a831a9aad6079cad88425de6bde1a9187ebb6092cf67bf2b13fd65f27088d78b7e883c8759d2c4f5c65adb7553878ad575f9fad878e80a0c9ba63bcbcc2732e69485bbc9c90bfbd62481d9089beccf80cfe2df16a2cf65bd92dd597b0707e0917af48bbb75fed413d238f5555a7a569d80c3414a8d0859dc65a46128bab27af87a71314f318c782b23ebfe808b82b0ce26401d2e22f04d83d1255dc51addd3b75a2b1ae0784504df543af8969be3ea7082ff7fc9888c144da2af58429ec96031dbcad3dad9af0dcbaaaf268cb8fcffead94f3c7ca495e056a9b47acdb751fb73e666c6c655ade8297297d07ad1ba5e43f1bca32301651339e22904cc8c42f58c30c04aafdb038dda0847dd988dcda6f3bfd15c4b4c4525004aa06eeff8ca61783aacec57fb3d1f92b0fe2fd1a85f6724517b65e614ad6808d6f6ee34dff7310fdc82aebfd904b01e1dc54b2927094b2db68d6f903b68401adebf5a7e08d78ff4ef5d63653a65040cf9bfd4aca7984a74d37145986780fc0b16ac451649de6188a7dbdf191f64b5fc5e2ab47b57f7f7276cd419c17a3ca8e1b939ae49e488acba6b965610b5480109c8b17b80e1b7b750dfc7598d5d5011fd2dcc5600a32ef5b52a1ecc820e308aa342721aac0943bf6686b64b2579376504ccc493d97e6aed3fb0f9cd71a43dd497f01f17c0e2cb3797aa2a2f256656168e6c496afc5fb93246f6b1116398a346f1a641f3b041e989f7914f90cc2c7fff357876e506b50d334ba77c225bc307ba537152f3f1610e4eafe595f6d9d90d11faa933a15ef1369546868a7f3a45a96768d40fd9d03412c091c6315cf4fde7cb68606937380db2eaaa707b4c4185c32eddcdd306705e4dc1ffc872eeee475a64dfac86aba41c0618983f8741c5ef68d3a101e8a3b8cac60c905c15fc910840b94c00a0b9d0"), + signature: hex!("0aab4c900501b3e24d7cdf4663326a3a87df5e4843b2cbdb67cbf6e460fec350aa5371b1508f9f4528ecea23c436d94b5e8fcd4f681e30a6ac00a9704a188a03"), + vk: hex!("278117fc144c72340f67d0f2316e8386ceffbf2b2428c9c51fef7c597f1d426e"), + }, + Ed25519TestData { + msg: &hex!("ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"), + signature: hex!("dc2a4459e7369633a52b1bf277839a00201009a3efbf3ecb69bea2186c26b58909351fc9ac90b3ecfdfbc7c66431e0303dca179c138ac17ad9bef1177331a704"), + vk: hex!("ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf"), + }, + Ed25519TestData { + msg: &hex!("470d6c430959"), + signature: hex!("17ba04a7351648d316c9567cee48bfb568499ee0fea83fd246c44202e9ad9e920d983306ed7a3ac8ea51ebb5a1e57a0b270ca962c812aa8a89e60ce787ac8205"), + vk: hex!("758751992ea75a6736661ac6f6ed4de7e5ed9dfbe33eaa9325780923614341d3"), + }, + Ed25519TestData { + msg: &hex!("a7bf65eac2bbbcb776761f247c3ccd6971396a88b3eb0bbebea592fd68b20a4d9e7cf474bea1eff3a332c9cdecd8fff2fe1e3cc6a3844318c2bc6f78a04a853ed1c535fe5824"), + signature: hex!("a67dd824237e219fb224da3d16bcf5142b5d642e5a62198f3d2ed901eae5bb96e14975fcaeb714516fd0ded27a9fab1bada235d44d65457a96085f3d4eb0230c"), + vk: hex!("7bfb9c93c50bb636b9d916fe3aec6a5ac6e19c47278ea404f7ea1721e3c46ced"), + }, + Ed25519TestData { + msg: &hex!("c781d2fd5640dd66f50c57cb7015feecf44a297c93ceafb611acffb39cb7c69f277491cc39eaa836008194e77860a7716799eca708859188b46d3e44dd3f57f3553244b1a8e5092fe1bdd6e016b67fd94e88187d03efe25d4178266dcac56aa1"), + signature: hex!("63afcc5c9b282e2e7e8871b411cd69e1cad83f057cb764862453af88ed5bb255ebf96dab5ea1b1041bdc6d515e79f4c774e9c87d7b7a681cf399cab3005a580e"), + vk: hex!("7e2467b9b1ae68d0b79e9c8214592022be6c369b2ba771cd7100d4be0db554b3"), + }, + Ed25519TestData { + msg: &hex!("7d89777ab2b5ff2ae46da312ed32a48b22977eb52a11fc3ab355f3ad7ad40641218681eef5add98f01"), + signature: hex!("cf76c790c77166e2db3a28eca5de7f42ffc9dac85895de1f929c72714d23a9fdb92017432ef7424ff14acb815c76881d55dcc80cca1ed8630473baa9b9b9d005"), + vk: hex!("4ff316d580e2330d99f92fdc149d4c88f36981be132a4f3fa065e649cf3571a8"), + }, + Ed25519TestData { + msg: &hex!("3d98781c525e466626b418"), + signature: hex!("185a5fd9b6e82e07c72bc81296cb1e4f7a5bfd4f5226961f52e24c0f20cf12310b740d38146dfdba662be9b6b2926712a648e73fe22239486149a404864df50f"), + vk: hex!("99cd13488f1b48f6d57d1f77ce2006487e65f8e7d6f1936cf6f36adf0b602d55"), + }, +]; + +// Test data for the prehash variant of Ed25519. +// First test was taken from the RFC: https://datatracker.ietf.org/doc/html/rfc8032#section-7.3 +// The rest were randomly generated. +const ED25519PH_TEST_DATA: [Ed25519TestData; 10] = [ + // This test is taken from the RFC: https://datatracker.ietf.org/doc/html/rfc8032#section-7.3 + Ed25519TestData { + msg: b"abc", + signature: hex!( + "98a70222f0b8121aa9d30f813d683f809e462b469c7ff87639499bb94e6dae4131f85042463c2a355a2003d062adf5aaa10b8c61e636062aaad11c2a26083406" + ), + vk: hex!( + "ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf" + ), + }, + Ed25519TestData { + msg: &hex!("4023e5edbfde97998cea65ee971c8cb24526596044f1216fa2c0d8c8ec8df95ef237bdd314022a2780dd09b9dcb8ba1df76d6ac7f0d7bf4374ef6405979dad73490d2b363545a2c0f5eddb965705a565f44a371d5cf58004d6834e0271c5e674"), + signature: hex!("5ac8ece6e00341bb1bc10403837f2f59fc0a3cbcd352e9101dccb5af2ad41da9199758ad606679bd2dc4af5a1d89c73c36365ae5455b725c6a2cea8d06399501"), + vk: hex!("6c6193110409068064bc7986acbc3c96449dfe32891e6c2fb3fc33ad7655ba0b") + }, + Ed25519TestData { + msg: &hex!("65b5f8e0fb2f47a16bc0e6777b5a4abed8a8dbcbd0b685257e47ede83a433cc3c5d8755959cdf8caa6990eae48f3759b03593b9bc0d8fc7383a5d8ea7b02de9dd761b410a4"), + signature: hex!("93406bdd8d40925e1a6e316654874444baca364f6700c074e2975e50cd2e708e17d10084574701fe0c91eda4d2d796e26b2aba67c48b3f94fac151e699cfe809"), + vk: hex!("9b606b002b395f9efe41a3f388d37bb81ed52486f8ddc19996176462da5d7b29") + }, + Ed25519TestData { + msg: &hex!("62dba7573793c7fa9908c4feb690c0b61b136d5c744b69b343a61bdc"), + signature: hex!("c4bd72c7935dc4d3784abfcb7f20124429b73925d01ec48673bd37ab26c2cc159089ade51df1eb7ce175abf43afd6c23c7bb39ad2d6476acb3a04ce7339f2e0f"), + vk: hex!("a166560aa0d208bf06c93a2a7f64748e503def9407ca8ee81687d7e6ee21efa7") + }, + Ed25519TestData { + msg: &hex!("40033e866996ddca4ff7a48d557bc7a4ffc4d97274bfae4691976cf0d587a9d5823a38ed5e7314b67b61b5d7536d3d581bd0ce77ca27ebd2ce26ce2e"), + signature: hex!("24ea129ba9abc5b11dc5a690bccbaebe315b8882b029b8c81bd8cc6a5b65e79aa82298b64fc61e03d4081642c90a60ad3955ab484304194d95ecf1ba8dfe760a"), + vk: hex!("bed62265ab0d6ba1be8dfe009fbb9514c6774ead6c34492adaee4893c32d39fc") + }, + Ed25519TestData { + msg: &hex!("0e9e0304a804628305083e4bde6de5d82fe5f5dbb47e232f2cf14439fd36dd59f26b87574614d8f4af6019c5d4d7ec77fee102445faf0c75f635a31234e2135199df5cdf013ff3472346e6f69e8a"), + signature: hex!("dcf2a479abd7ea5211d853279f128e402fb64fc3148780694422e8a572e29ed1557fd89f172c0a3c1b2ab6944297deace095583bff09b302936f64198385490e"), + vk: hex!("47547e699eaf210dcce343b4a2d176607a8a1ceb2a3e912360f40dd1fa3ab216") + }, + Ed25519TestData { + msg: &hex!("563df2cb74d0cfd961ed010958845e6983b1ca7a55761dba35ccebcf17dfd972bfcb908c116a4eacb84235ed"), + signature: hex!("9960cb5936c23969707bf92ab0d51ae941dc2ce2534d818ed1c829dbbe916a93657bc6c7d38ad5f1d07df513f35409d9581cb3c122f0742f41295dc8e0396d0c"), + vk: hex!("4c6a352b76f20b7757ee96083f2cd8b759556551d9e5937a7ec4345d4bf974ec") + }, + Ed25519TestData { + msg: &hex!("de08bd2b3a014ea5a85de35f718fea45e6c23c06f2d23e5bbc22bb998de9d7cb9d21fbb9a55aff4d2a867daaf4a897281c889c9536fb2259014030d8b24d9a04dfc0b74e62847638f1740767e62ccfbfab174749f657a75a3845924ca6a6a6d539"), + signature: hex!("64bbb532b1ebe08554d6569be1133ed9219e0450fd1bf12fd30ce6aa3150aea3054ed3ff35ee7c458eeeee6141a9e5f3bb4d14c345a4435ae331ed3ab0fa350e"), + vk: hex!("e22e10095f6fb59e5714093b78f9889ac3aa394c2a87b858ca6afb8558241a2e") + }, + Ed25519TestData { + msg: &hex!("96d3dd0d9df5160827ad62c96dc04ff1621e50b3c7a2bcba53e6e2c83b6939"), + signature: hex!("5f6d536c7cb27391f7a274fd7f72e258b49f511476e5a0cd562c3f52a050694734c7251434bde23fe14977094a482fe5ff14a237fc24804cab5e595916240401"), + vk: hex!("83a53511df806437a5c3e72c1fd43d448e217da4cc789d817fa2b12d9628128b") + }, + Ed25519TestData { + msg: &hex!("c366ea5e7b41a0382eaa22a5d6101d9fef9a7b234cc54f218108f2695f288c9369723aa7958d0605166b52be9a895133ae77eba9122d8780dd897d3f7c871cc48f57c11fcfa7a5a056c1e1"), + signature: hex!("535e3a9fc9a3948998398f8d79af40ff7cb320e62498df3db3ad7cfbbd5c02eedee95a11f745f5334f6532e36ec5cfca1c2e14e416481e3cc3600ef58dab4901"), + vk: hex!("4028f8f8d4f461fcc41d2c2878eddb7acb9fbf3254e7de511c3f513673ff05d0") + }, +]; + +pub fn main() { + for test_data in ED25519_TEST_DATA { + let vk = VerifyingKey::::from_bytes(&test_data.vk).unwrap(); + assert!(vk.verify(test_data.msg, &test_data.signature).is_ok()); + } + + for test_data in ED25519PH_TEST_DATA { + let vk = VerifyingKey::::from_bytes(&test_data.vk).unwrap(); + assert!(vk + .verify_ph(test_data.msg, None, &test_data.signature) + .is_ok()); + } +} diff --git a/extensions/ecc/tests/programs/examples/edwards_ec.rs b/extensions/ecc/tests/programs/examples/edwards_ec.rs new file mode 100644 index 0000000000..53214fb14a --- /dev/null +++ b/extensions/ecc/tests/programs/examples/edwards_ec.rs @@ -0,0 +1,70 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +use hex_literal::hex; +use openvm_algebra_guest::moduli_macros::moduli_init; +use openvm_ecc_guest::{ + algebra::IntMod, + ed25519::{Ed25519Coord, Ed25519Point}, + edwards::TwistedEdwardsPoint, + te_macros::te_init, + CyclicGroup, Group, +}; + +moduli_init! { + "57896044618658097711785492504343953926634992332820282019728792003956564819949", +} + +te_init! { + Ed25519Point, +} + +openvm::entry!(main); + +pub fn main() { + // Base point of edwards25519 + let mut p1 = Ed25519Point::GENERATOR; + + // random point on edwards25519 + let x2 = Ed25519Coord::from_u32(2); + let y2 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "1A43BF127BDDC4D71FF910403C11DDB5BA2BCDD2815393924657EF111E712631" + )); + let mut p2 = Ed25519Point::from_xy(x2, y2).unwrap(); + + // This is the sum of (x1, y1) and (x2, y2). + let x3 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "636C0B519B2C5B1E0D3BFD213F45AFD5DAEE3CECC9B68CF88615101BC78329E6" + )); + let y3 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "704D8868CB335A7B609D04B9CD619511675691A78861F1DFF7A5EBC389C7EA92" + )); + + // This is 2 * (x1, y1) + let x4 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "56B98CC045559AD2BBC45CAB58D842ECEE264DB9395F6014B772501B62BB7EE8" + )); + let y4 = Ed25519Coord::from_be_bytes_unchecked(&hex!( + "1BCA918096D89C83A15105DF343DC9F7510494407750226DAC0A7620ACE77BEB" + )); + + // Generic add can handle equal or unequal points. + let p3 = &p1 + &p2; + if p3.x() != &x3 || p3.y() != &y3 { + panic!(); + } + let p4 = &p2 + &p2; + if p4.x() != &x4 || p4.y() != &y4 { + panic!(); + } + + // Add assign and double assign + p1 += &p2; + if p1.x() != &x3 || p1.y() != &y3 { + panic!(); + } + p2.double_assign(); + if p2.x() != &x4 || p2.y() != &y4 { + panic!(); + } +} diff --git a/extensions/ecc/tests/programs/openvm_init_decompress_k256.rs b/extensions/ecc/tests/programs/openvm_init_decompress_k256_ed25519.rs similarity index 73% rename from extensions/ecc/tests/programs/openvm_init_decompress_k256.rs rename to extensions/ecc/tests/programs/openvm_init_decompress_k256_ed25519.rs index b6137ae9ee..9c00595f31 100644 --- a/extensions/ecc/tests/programs/openvm_init_decompress_k256.rs +++ b/extensions/ecc/tests/programs/openvm_init_decompress_k256_ed25519.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. -openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089237316195423570985008687907853269984665640564039457584007913129639501", "1000000007", "26959946667150639794667015087019630673557916260026308143510066298881", "26959946667150639794667015087019625940457807714424391721682722368061" } +openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089237316195423570985008687907853269984665640564039457584007913129639501", "1000000007", "26959946667150639794667015087019630673557916260026308143510066298881", "26959946667150639794667015087019625940457807714424391721682722368061", "57896044618658097711785492504343953926634992332820282019728792003956564819949", "7237005577332262213973186563042994240857116359379907606001950938285454250989" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point, CurvePoint5mod8, CurvePoint1mod4 } +openvm_ecc_guest::te_macros::te_init! { Ed25519Point } diff --git a/extensions/ecc/tests/programs/openvm_init_ec_k256.rs b/extensions/ecc/tests/programs/openvm_init_ec_k256.rs index bec9f527e9..f0855c9497 100644 --- a/extensions/ecc/tests/programs/openvm_init_ec_k256.rs +++ b/extensions/ecc/tests/programs/openvm_init_ec_k256.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/extensions/ecc/tests/programs/openvm_init_ec_nonzero_a_p256.rs b/extensions/ecc/tests/programs/openvm_init_ec_nonzero_a_p256.rs index 02f8b5c05d..cd95ac085f 100644 --- a/extensions/ecc/tests/programs/openvm_init_ec_nonzero_a_p256.rs +++ b/extensions/ecc/tests/programs/openvm_init_ec_nonzero_a_p256.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } openvm_ecc_guest::sw_macros::sw_init! { P256Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/extensions/ecc/tests/programs/openvm_init_ec_two_curves_k256_p256.rs b/extensions/ecc/tests/programs/openvm_init_ec_two_curves_k256_p256.rs index 8689190544..624788b82a 100644 --- a/extensions/ecc/tests/programs/openvm_init_ec_two_curves_k256_p256.rs +++ b/extensions/ecc/tests/programs/openvm_init_ec_two_curves_k256_p256.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point, P256Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/extensions/ecc/tests/programs/openvm_init_ecdsa_k256.rs b/extensions/ecc/tests/programs/openvm_init_ecdsa_k256.rs index bec9f527e9..f0855c9497 100644 --- a/extensions/ecc/tests/programs/openvm_init_ecdsa_k256.rs +++ b/extensions/ecc/tests/programs/openvm_init_ecdsa_k256.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/extensions/ecc/tests/programs/openvm_init_ed25519_ed25519.rs b/extensions/ecc/tests/programs/openvm_init_ed25519_ed25519.rs new file mode 100644 index 0000000000..03ed9602f6 --- /dev/null +++ b/extensions/ecc/tests/programs/openvm_init_ed25519_ed25519.rs @@ -0,0 +1,4 @@ +// This file is automatically generated by cargo openvm. Do not rename or edit. +openvm_algebra_guest::moduli_macros::moduli_init! { "57896044618658097711785492504343953926634992332820282019728792003956564819949", "7237005577332262213973186563042994240857116359379907606001950938285454250989" } +openvm_ecc_guest::sw_macros::sw_init! { } +openvm_ecc_guest::te_macros::te_init! { Ed25519Point } diff --git a/extensions/ecc/tests/programs/openvm_init_edwards_ec_ed25519.rs b/extensions/ecc/tests/programs/openvm_init_edwards_ec_ed25519.rs new file mode 100644 index 0000000000..03ed9602f6 --- /dev/null +++ b/extensions/ecc/tests/programs/openvm_init_edwards_ec_ed25519.rs @@ -0,0 +1,4 @@ +// This file is automatically generated by cargo openvm. Do not rename or edit. +openvm_algebra_guest::moduli_macros::moduli_init! { "57896044618658097711785492504343953926634992332820282019728792003956564819949", "7237005577332262213973186563042994240857116359379907606001950938285454250989" } +openvm_ecc_guest::sw_macros::sw_init! { } +openvm_ecc_guest::te_macros::te_init! { Ed25519Point } diff --git a/extensions/ecc/tests/programs/openvm_k256.toml b/extensions/ecc/tests/programs/openvm_k256.toml index 571fdb895c..2fa80a5af3 100644 --- a/extensions/ecc/tests/programs/openvm_k256.toml +++ b/extensions/ecc/tests/programs/openvm_k256.toml @@ -8,9 +8,11 @@ supported_moduli = [ "115792089237316195423570985008687907852837564279074904382605163141518161494337", ] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" + +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "7" diff --git a/extensions/ecc/tests/programs/openvm_k256_keccak.toml b/extensions/ecc/tests/programs/openvm_k256_keccak.toml index c1261ee458..4dc77ccd80 100644 --- a/extensions/ecc/tests/programs/openvm_k256_keccak.toml +++ b/extensions/ecc/tests/programs/openvm_k256_keccak.toml @@ -9,9 +9,11 @@ supported_moduli = [ "115792089237316195423570985008687907852837564279074904382605163141518161494337", ] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "Secp256k1Point" modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663" scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337" + +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "0" b = "7" diff --git a/extensions/ecc/tests/programs/openvm_p256.toml b/extensions/ecc/tests/programs/openvm_p256.toml index 0035cd83da..2cc5bd92c3 100644 --- a/extensions/ecc/tests/programs/openvm_p256.toml +++ b/extensions/ecc/tests/programs/openvm_p256.toml @@ -7,9 +7,11 @@ supported_moduli = [ "115792089210356248762697446949407573529996955224135760342422259061068512044369", ] -[[app_vm_config.ecc.supported_curves]] +[[app_vm_config.ecc.supported_sw_curves]] struct_name = "P256Point" modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951" scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369" + +[app_vm_config.ecc.supported_sw_curves.coeffs] a = "115792089210356248762697446949407573530086143415290314195533631308867097853948" b = "41058363725152142129326129780047268409114441015993725554835256314039467401291" diff --git a/extensions/ecc/tests/src/lib.rs b/extensions/ecc/tests/src/lib.rs index b9ae366d82..7fed6b2610 100644 --- a/extensions/ecc/tests/src/lib.rs +++ b/extensions/ecc/tests/src/lib.rs @@ -13,7 +13,11 @@ mod tests { arch::instructions::exe::VmExe, utils::{air_test, air_test_with_min_segments, test_system_config_with_continuations}, }; - use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, P256_CONFIG, SECP256K1_CONFIG}; + #[cfg(test)] + use openvm_ecc_circuit::TeCurveCoeffs; + use openvm_ecc_circuit::{ + CurveConfig, Rv32EccConfig, SwCurveCoeffs, ED25519_CONFIG, P256_CONFIG, SECP256K1_CONFIG, + }; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, @@ -22,6 +26,7 @@ mod tests { config::{AppConfig, SdkVmConfig}, StdIn, }; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::{openvm_stark_backend, p3_baby_bear::BabyBear}; use openvm_toolchain_tests::{ @@ -36,15 +41,18 @@ mod tests { type F = BabyBear; #[cfg(test)] - fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { - let mut config = Rv32WeierstrassConfig::new(curves); + fn test_rv32ecc_config( + sw_curves: Vec>, + te_curves: Vec>, + ) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, te_curves); config.system = test_system_config_with_continuations(); config } #[test] fn test_ec() -> Result<()> { - let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()], vec![]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec", @@ -66,7 +74,7 @@ mod tests { #[test] fn test_ec_nonzero_a() -> Result<()> { - let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()], vec![]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec_nonzero_a", @@ -89,7 +97,7 @@ mod tests { #[test] fn test_ec_two_curves() -> Result<()> { let config = - test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()]); + test_rv32ecc_config(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()], vec![]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec_two_curves", @@ -111,9 +119,9 @@ mod tests { #[test] fn test_decompress() -> Result<()> { - use halo2curves_axiom::{group::Curve, secp256k1::Secp256k1Affine}; + use halo2curves_axiom::{ed25519::Ed25519Affine, group::Curve, secp256k1::Secp256k1Affine}; - let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone(), + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone(), CurveConfig { struct_name: "CurvePoint5mod8".to_string(), modulus: BigUint::from_str("115792089237316195423570985008687907853269984665640564039457584007913129639501") @@ -121,8 +129,10 @@ mod tests { // unused, set to 10e9 + 7 scalar: BigUint::from_str("1000000007") .unwrap(), - a: BigUint::ZERO, - b: BigUint::from_str("6").unwrap(), + coeffs: SwCurveCoeffs { + a: BigUint::ZERO, + b: BigUint::from_str("6").unwrap(), + }, }, CurveConfig { struct_name: "CurvePoint1mod4".to_string(), @@ -130,19 +140,24 @@ mod tests { .unwrap(), scalar: BigUint::from_radix_be(&hex!("ffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d"), 256) .unwrap(), - a: BigUint::from_radix_be(&hex!("fffffffffffffffffffffffffffffffefffffffffffffffffffffffe"), 256) - .unwrap(), - b: BigUint::from_radix_be(&hex!("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4"), 256) - .unwrap(), + coeffs: SwCurveCoeffs { + a: BigUint::from_radix_be(&hex!("fffffffffffffffffffffffffffffffefffffffffffffffffffffffe"), 256) + .unwrap(), + b: BigUint::from_radix_be(&hex!("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4"), 256) + .unwrap(), + }, }, - ]); + ], + vec![ED25519_CONFIG.clone()], + ); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "decompress", - ["k256"], + ["k256", "ed25519"], &config, )?; + let openvm_exe = VmExe::from_elf( elf, Transpiler::::default() @@ -155,8 +170,7 @@ mod tests { let p = Secp256k1Affine::generator(); let p = (p + p + p).to_affine(); - println!("decompressed: {:?}", p); - + println!("secp256k1 decompressed: {:?}", p); let q_x: [u8; 32] = hex!("0100000000000000000000000000000000000000000000000000000000000000"); let q_y: [u8; 32] = @@ -165,12 +179,24 @@ mod tests { hex!("211D5C11D68032342211C256D3C1034AB99013327FBFB46BBD0C0EB700000000"); let r_y: [u8; 32] = hex!("347E00859981D5446447075AA07543CDE6DF224CFB23F7B5886337BD00000000"); + let s = Ed25519Affine::generator(); + let s = (s + s + s).to_affine(); + + let coords = [ + p.x.to_bytes(), + p.y.to_bytes(), + q_x, + q_y, + r_x, + r_y, + s.x.to_bytes(), + s.y.to_bytes(), + ] + .concat() + .into_iter() + .map(FieldAlgebra::from_canonical_u8) + .collect(); - let coords = [p.x.to_bytes(), p.y.to_bytes(), q_x, q_y, r_x, r_y] - .concat() - .into_iter() - .map(FieldAlgebra::from_canonical_u8) - .collect(); air_test_with_min_segments(config, openvm_exe, vec![coords], 1); Ok(()) } @@ -246,6 +272,28 @@ mod tests { Ok(()) } + #[test] + fn test_edwards_ec() -> Result<()> { + let config = Rv32EccConfig::new(vec![], vec![ED25519_CONFIG.clone()]); + let elf = build_example_program_at_path_with_features::<&str>( + get_programs_dir!(), + "edwards_ec", + ["ed25519"], + &config, + )?; + let openvm_exe = VmExe::from_elf( + elf, + Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Rv32IoTranspilerExtension) + .with_extension(EccTranspilerExtension) + .with_extension(ModularTranspilerExtension), + )?; + air_test(config, openvm_exe); + Ok(()) + } + #[test] #[should_panic] fn test_invalid_setup() { @@ -267,7 +315,30 @@ mod tests { ) .unwrap(); let config = - test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()]); + test_rv32ecc_config(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()], vec![]); air_test(config, openvm_exe); } + + #[test] + fn test_ed25519() -> Result<()> { + let config = Rv32EccConfig::new(vec![], vec![ED25519_CONFIG.clone()]); + let elf = build_example_program_at_path_with_features( + get_programs_dir!(), + "ed25519", + ["ed25519"], + &config, + )?; + let openvm_exe = VmExe::from_elf( + elf, + Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Rv32IoTranspilerExtension) + .with_extension(EccTranspilerExtension) + .with_extension(ModularTranspilerExtension) + .with_extension(Sha2TranspilerExtension), + )?; + air_test(config, openvm_exe); + Ok(()) + } } diff --git a/extensions/ecc/transpiler/src/lib.rs b/extensions/ecc/transpiler/src/lib.rs index 462e95dbdd..469868d3ae 100644 --- a/extensions/ecc/transpiler/src/lib.rs +++ b/extensions/ecc/transpiler/src/lib.rs @@ -1,4 +1,4 @@ -use openvm_ecc_guest::{SwBaseFunct7, OPCODE, SW_FUNCT3}; +use openvm_ecc_guest::{SwBaseFunct7, TeBaseFunct7, SW_FUNCT3, SW_OPCODE, TE_FUNCT3, TE_OPCODE}; use openvm_instructions::{ instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, VmOpcode, }; @@ -15,10 +15,21 @@ use strum::{EnumCount, EnumIter, FromRepr}; #[allow(non_camel_case_types)] #[repr(usize)] pub enum Rv32WeierstrassOpcode { - EC_ADD_NE, - SETUP_EC_ADD_NE, - EC_DOUBLE, - SETUP_EC_DOUBLE, + SW_ADD_NE, + SETUP_SW_ADD_NE, + SW_DOUBLE, + SETUP_SW_DOUBLE, +} + +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x680] +#[allow(non_camel_case_types)] +#[repr(usize)] +pub enum Rv32EdwardsOpcode { + TE_ADD, + SETUP_TE_ADD, } #[derive(Default)] @@ -26,6 +37,67 @@ pub struct EccTranspilerExtension; impl TranspilerExtension for EccTranspilerExtension { fn process_custom(&self, instruction_stream: &[u32]) -> Option> { + self.process_weierstrass_instruction(instruction_stream) + .or(self.process_edwards_instruction(instruction_stream)) + } +} + +impl EccTranspilerExtension { + fn process_edwards_instruction( + &self, + instruction_stream: &[u32], + ) -> Option> { + if instruction_stream.is_empty() { + return None; + } + let instruction_u32 = instruction_stream[0]; + let opcode = (instruction_u32 & 0x7f) as u8; + let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; + + if opcode != TE_OPCODE { + return None; + } + if funct3 != TE_FUNCT3 { + return None; + } + + let instruction = { + // twisted edwards ec + assert!(Rv32EdwardsOpcode::COUNT <= TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS as usize); + let dec_insn = RType::new(instruction_u32); + let base_funct7 = (dec_insn.funct7 as u8) % TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS; + let curve_idx = + ((dec_insn.funct7 as u8) / TeBaseFunct7::TWISTED_EDWARDS_MAX_KINDS) as usize; + let curve_idx_shift = curve_idx * Rv32EdwardsOpcode::COUNT; + + if base_funct7 == TeBaseFunct7::TeSetup as u8 { + let local_opcode = Rv32EdwardsOpcode::SETUP_TE_ADD; + Some(Instruction::new( + VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + curve_idx_shift), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1), + F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2), + F::ONE, // d_as = 1 + F::TWO, // e_as = 2 + F::ZERO, + F::ZERO, + )) + } else { + let global_opcode = match TeBaseFunct7::from_repr(base_funct7) { + Some(TeBaseFunct7::TeAdd) => Rv32EdwardsOpcode::TE_ADD.global_opcode(), + _ => unimplemented!(), + }; + let global_opcode = global_opcode.as_usize() + curve_idx_shift; + Some(from_r_type(global_opcode, 2, &dec_insn, true)) + } + }; + instruction.map(TranspilerOutput::one_to_one) + } + + fn process_weierstrass_instruction( + &self, + instruction_stream: &[u32], + ) -> Option> { if instruction_stream.is_empty() { return None; } @@ -33,7 +105,7 @@ impl TranspilerExtension for EccTranspilerExtension { let opcode = (instruction_u32 & 0x7f) as u8; let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; - if opcode != OPCODE { + if opcode != SW_OPCODE { return None; } if funct3 != SW_FUNCT3 { @@ -52,8 +124,8 @@ impl TranspilerExtension for EccTranspilerExtension { let curve_idx_shift = curve_idx * Rv32WeierstrassOpcode::COUNT; if base_funct7 == SwBaseFunct7::SwSetup as u8 { let local_opcode = match dec_insn.rs2 { - 0 => Rv32WeierstrassOpcode::SETUP_EC_DOUBLE, - _ => Rv32WeierstrassOpcode::SETUP_EC_ADD_NE, + 0 => Rv32WeierstrassOpcode::SETUP_SW_DOUBLE, + _ => Rv32WeierstrassOpcode::SETUP_SW_ADD_NE, }; Some(Instruction::new( VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + curve_idx_shift), @@ -67,18 +139,14 @@ impl TranspilerExtension for EccTranspilerExtension { )) } else { let global_opcode = match SwBaseFunct7::from_repr(base_funct7) { - Some(SwBaseFunct7::SwAddNe) => { - Rv32WeierstrassOpcode::EC_ADD_NE as usize - + Rv32WeierstrassOpcode::CLASS_OFFSET - } + Some(SwBaseFunct7::SwAddNe) => Rv32WeierstrassOpcode::SW_ADD_NE.global_opcode(), Some(SwBaseFunct7::SwDouble) => { assert!(dec_insn.rs2 == 0); - Rv32WeierstrassOpcode::EC_DOUBLE as usize - + Rv32WeierstrassOpcode::CLASS_OFFSET + Rv32WeierstrassOpcode::SW_DOUBLE.global_opcode() } _ => unimplemented!(), }; - let global_opcode = global_opcode + curve_idx_shift; + let global_opcode = global_opcode.as_usize() + curve_idx_shift; Some(from_r_type(global_opcode, 2, &dec_insn, true)) } }; diff --git a/extensions/pairing/circuit/src/config.rs b/extensions/pairing/circuit/src/config.rs index d63bac664e..3958eeddb2 100644 --- a/extensions/pairing/circuit/src/config.rs +++ b/extensions/pairing/circuit/src/config.rs @@ -23,7 +23,7 @@ pub struct Rv32PairingConfig { #[extension] pub fp2: Fp2Extension, #[extension] - pub weierstrass: WeierstrassExtension, + pub ecc: EccExtension, #[extension] pub pairing: PairingExtension, } @@ -48,9 +48,7 @@ impl Rv32PairingConfig { .zip(modulus_primes) .collect(), ), - weierstrass: WeierstrassExtension::new( - curves.iter().map(|c| c.curve_config()).collect(), - ), + ecc: EccExtension::new(curves.iter().map(|c| c.curve_config()).collect(), vec![]), pairing: PairingExtension::new(curves), } } @@ -62,7 +60,7 @@ impl InitFileGenerator for Rv32PairingConfig { "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n{}\n", self.modular.generate_moduli_init(), self.fp2.generate_complex_init(&self.modular), - self.weierstrass.generate_sw_init() + self.ecc.generate_ecc_init() )) } } diff --git a/extensions/pairing/circuit/src/pairing_extension.rs b/extensions/pairing/circuit/src/pairing_extension.rs index f700ca4dc5..b4a8fa62c4 100644 --- a/extensions/pairing/circuit/src/pairing_extension.rs +++ b/extensions/pairing/circuit/src/pairing_extension.rs @@ -8,7 +8,7 @@ use openvm_circuit::{ use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_ecc_circuit::CurveConfig; +use openvm_ecc_circuit::{CurveConfig, SwCurveCoeffs}; use openvm_instructions::PhantomDiscriminant; use openvm_pairing_guest::{ bls12_381::{ @@ -30,21 +30,25 @@ pub enum PairingCurve { } impl PairingCurve { - pub fn curve_config(&self) -> CurveConfig { + pub fn curve_config(&self) -> CurveConfig { match self { PairingCurve::Bn254 => CurveConfig::new( BN254_ECC_STRUCT_NAME.to_string(), BN254_MODULUS.clone(), BN254_ORDER.clone(), - BigUint::zero(), - BigUint::from_u8(3).unwrap(), + SwCurveCoeffs { + a: BigUint::zero(), + b: BigUint::from_u8(3).unwrap(), + }, ), PairingCurve::Bls12_381 => CurveConfig::new( BLS12_381_ECC_STRUCT_NAME.to_string(), BLS12_381_MODULUS.clone(), BLS12_381_ORDER.clone(), - BigUint::zero(), - BigUint::from_u8(4).unwrap(), + SwCurveCoeffs { + a: BigUint::zero(), + b: BigUint::from_u8(4).unwrap(), + }, ), } } diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha2/circuit/Cargo.toml similarity index 80% rename from extensions/sha256/circuit/Cargo.toml rename to extensions/sha2/circuit/Cargo.toml index 95c87b0871..213965c0cb 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha2/circuit/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "openvm-sha256-circuit" +name = "openvm-sha2-circuit" version.workspace = true authors.workspace = true edition.workspace = true -description = "OpenVM circuit extension for sha256" +description = "OpenVM circuit extension for SHA-2" [dependencies] openvm-stark-backend = { workspace = true } @@ -13,16 +13,16 @@ openvm-circuit-primitives-derive = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-circuit = { workspace = true } openvm-instructions = { workspace = true } -openvm-sha256-transpiler = { workspace = true } +openvm-sha2-transpiler = { workspace = true } openvm-rv32im-circuit = { workspace = true } -openvm-sha256-air = { workspace = true } +openvm-sha2-air = { workspace = true } derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true serde.workspace = true sha2 = { version = "0.10", default-features = false } -strum = { workspace = true } +ndarray = { workspace = true, default-features = false } [dev-dependencies] openvm-stark-sdk = { workspace = true } @@ -37,3 +37,6 @@ mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] jemalloc-prof = ["openvm-circuit/jemalloc-prof"] nightly-features = ["openvm-circuit/nightly-features"] + +[package.metadata.cargo-shear] +ignored = ["ndarray"] \ No newline at end of file diff --git a/extensions/sha256/circuit/README.md b/extensions/sha2/circuit/README.md similarity index 56% rename from extensions/sha256/circuit/README.md rename to extensions/sha2/circuit/README.md index 1e794cd35c..de2100b261 100644 --- a/extensions/sha256/circuit/README.md +++ b/extensions/sha2/circuit/README.md @@ -1,28 +1,43 @@ -# SHA256 VM Extension +# SHA-2 VM Extension -This crate contains the circuit for the SHA256 VM extension. +This crate contains circuits for the SHA-2 family of hash functions. +We support SHA-256, SHA-512, and SHA-384. -## SHA-256 Algorithm Summary +## SHA-2 Algorithms Summary -See the [FIPS standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf), in particular, section 6.2 for reference. +The SHA-256, SHA-512, and SHA-384 algorithms are similar in structure. +We will first describe the SHA-256 algorithm, and then describe the differences between the three algorithms. + +See the [FIPS standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for reference. In particular, sections 6.2, 6.4, and 6.5. In short the SHA-256 algorithm works as follows. 1. Pad the message to 512 bits and split it into 512-bit 'blocks'. -2. Initialize a hash state consisting of eight 32-bit words. +2. Initialize a hash state consisting of eight 32-bit words to a specific constant value. 3. For each block, - 1. split the message into 16 32-bit words and produce 48 more 'message schedule' words based on them. - 2. apply 64 'rounds' to update the hash state based on the message schedule. - 3. add the previous block's final hash state to the current hash state (modulo `2^32`). + 1. split the message into 16 32-bit words and produce 48 more words based on them. The 16 message words together with the 48 additional words are called the 'message schedule'. + 2. apply a scrambling function 64 times to the hash state to update it based on the message schedule. We call each update a 'round'. + 3. add the previous block's final hash state to the current hash state (modulo $2^{32}$). 4. The output is the final hash state +The differences with the SHA-512 algorithm are that: +- SHA-512 uses 64-bit words, 1024-bit blocks, performs 80 rounds, and produces a 512-bit output. +- all the arithmetic is done modulo $2^{64}$. +- the initial hash state is different. + +The SHA-384 algorithm is a truncation of the SHA-512 output to 384 bits, and the only difference is that the initial hash state is different. + ## Design Overview -This chip produces an AIR that consists of 17 rows for each block (512 bits) in the message, and no more rows. -The first 16 rows of each block are called 'round rows', and each of them represents four rounds of the SHA-256 algorithm. -Each row constrains updates to the working variables on each round, and it also constrains the message schedule words based on previous rounds. -The final row is called a 'digest row' and it produces a final hash for the block, computed as the sum of the working variables and the previous block's final hash. +We reuse the same AIR code to produce circuits for all three algorithms. +To achieve this, we parameterize the AIR by constants (such as the word size, number of rounds, and block size) that are specific to each algorithm. + +This chip produces an AIR that consists of $R+1$ rows for each block of the message, and no more rows +(for SHA-256, $R = 16$ and for SHA-512 and SHA-384, $R = 20$). +The first $R$ rows of each block are called 'round rows', and each of them constrains four rounds of the hash algorithm. +Each row constrains updates to the working variables on each round, and also constrains the message schedule words based on previous rounds. +The final row of each block is called a 'digest row' and it produces a final hash for the block, computed as the sum of the working variables and the previous block's final hash. -Note that this chip only supports messages of length less than `2^29` bytes. +Note that this chip only supports messages of length less than $2^{29}$ bytes. ### Storing working variables @@ -50,7 +65,7 @@ Since we can reliably constrain values from four rounds ago, we can build up `in The last block of every message should have the `is_last_block` flag set to `1`. Note that `is_last_block` is not constrained to be true for the last block of every message, instead it *defines* what the last block of a message is. -For instance, if we produce an air with 10 blocks and only the last block has `is_last_block = 1` then the constraints will interpret it as a single message of length 10 blocks. +For instance, if we produce a trace with 10 blocks and only the last block has `is_last_block = 1` then the constraints will interpret it as a single message of length 10 blocks. If, however, we set `is_last_block` to true for the 6th block, the trace will be interpreted as hashing two messages, each of length 5 blocks. Note that we do constrain, however, that the very last block of the trace has `is_last_block = 1`. @@ -63,11 +78,11 @@ We use this trick in several places in this chip. ### Block index counter variables -There are two "block index" counter variables in each row of the air named `global_block_idx` and `local_block_idx`. -Both of these variables take on the same value on all 17 rows in a block. +There are two "block index" counter variables in each row named `global_block_idx` and `local_block_idx`. +Both of these variables take on the same value on all $R+1$ rows in a block. The `global_block_idx` is the index of the block in the entire trace. -The very first 17 rows in the trace will have `global_block_idx = 1` and the counter will increment by 1 between blocks. +The very first block in the trace will have `global_block_idx = 1` on each row and the counter will increment by 1 between blocks. The padding rows will all have `global_block_idx = 0`. The `global_block_idx` is used in interaction constraints to constrain the value of `hash` between blocks. @@ -79,15 +94,16 @@ The `local_block_idx` is used to calculate the length of the message processed s ### VM air vs SubAir -The SHA-256 VM extension chip uses the `Sha256Air` SubAir to help constrain the SHA-256 hash. -The VM extension air constrains the correctness of the SHA message padding, while the SubAir adds all other constraints related to the hash algorithm. -The VM extension air also constrains memory reads and writes. +The SHA-2 VM extension chip uses the `Sha2Air` SubAir to help constrain the appropriate SHA-2 hash algorithm. +The SubAir is also parameterized by the specific SHA-2 variant's constants. +The VM extension AIR constrains the correctness of the message padding, while the SubAir adds all other constraints related to the hash algorithm. +The VM extension AIR also constrains memory reads and writes. ### A gotcha about padding rows There are two senses of the word padding used in the context of this chip and this can be confusing. -First, we use padding to refer to the extra bits added to the message that is input to the SHA-256 algorithm in order to make the input's length a multiple of 512 bits. -So, we may use the term 'padding rows' to refer to round rows that correspond to the padded bits of a message (as in `Sha256VmAir::eval_padding_row`). +First, we use padding to refer to the extra bits added to the message that is input to the hash algorithm in order to make the input's length a multiple of the block size. +So, we may use the term 'padding rows' to refer to round rows that correspond to the padded bits of a message (as in `Sha2VmAir::eval_padding_row`). Second, the dummy rows that are added to the trace to make the trace height a power of 2 are also called padding rows (see the `is_padding_row` flag). In the SubAir, padding row probably means dummy row. -In the VM air, it probably refers to SHA-256 padding. \ No newline at end of file +In the VM air, it probably refers to the message padding. \ No newline at end of file diff --git a/extensions/sha256/circuit/src/extension.rs b/extensions/sha2/circuit/src/extension.rs similarity index 57% rename from extensions/sha256/circuit/src/extension.rs rename to extensions/sha2/circuit/src/extension.rs index 77373cbb48..b05e4412e0 100644 --- a/extensions/sha256/circuit/src/extension.rs +++ b/extensions/sha2/circuit/src/extension.rs @@ -16,17 +16,17 @@ use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, }; -use openvm_sha256_transpiler::Rv32Sha256Opcode; +use openvm_sha2_air::{Sha256Config, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; -use strum::IntoEnumIterator; use crate::*; // TODO: this should be decided after e2 execution #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Sha256Rv32Config { +pub struct Sha2Rv32Config { #[system] pub system: SystemConfig, #[extension] @@ -36,43 +36,45 @@ pub struct Sha256Rv32Config { #[extension] pub io: Rv32Io, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, } -impl Default for Sha256Rv32Config { +impl Default for Sha2Rv32Config { fn default() -> Self { Self { system: SystemConfig::default().with_continuations(), rv32i: Rv32I, rv32m: Rv32M::default(), io: Rv32Io, - sha256: Sha256, + sha2: Sha2, } } } // Default implementation uses no init file -impl InitFileGenerator for Sha256Rv32Config {} +impl InitFileGenerator for Sha2Rv32Config {} #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -pub struct Sha256; +pub struct Sha2; #[derive( ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, InsExecutorE1, InsExecutorE2, )] -pub enum Sha256Executor { - Sha256(Sha256VmChip), +pub enum Sha2Executor { + Sha256(Sha2VmChip), + Sha512(Sha2VmChip), + Sha384(Sha2VmChip), } #[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum Sha256Periphery { +pub enum Sha2Periphery { BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), Phantom(PhantomChip), } -impl VmExtension for Sha256 { - type Executor = Sha256Executor; - type Periphery = Sha256Periphery; +impl VmExtension for Sha2 { + type Executor = Sha2Executor; + type Periphery = Sha2Periphery; fn build( &self, @@ -93,24 +95,53 @@ impl VmExtension for Sha256 { chip }; - let sha256_chip = Sha256VmChip::new( - Sha256VmAir::new( + let sha256_chip = Sha2VmChip::::new( + Sha2VmAir::new( builder.system_port(), bitwise_lu_chip.bus(), pointer_max_bits, builder.new_bus_idx(), ), - Sha256VmStep::new( + Sha2VmStep::new( bitwise_lu_chip.clone(), - Rv32Sha256Opcode::CLASS_OFFSET, + Rv32Sha2Opcode::CLASS_OFFSET, pointer_max_bits, ), builder.system_base().memory_controller.helper(), ); - inventory.add_executor( - sha256_chip, - Rv32Sha256Opcode::iter().map(|x| x.global_opcode()), - )?; + inventory.add_executor(sha256_chip, vec![Rv32Sha2Opcode::SHA256.global_opcode()])?; + + let sha512_chip = Sha2VmChip::::new( + Sha2VmAir::new( + builder.system_port(), + bitwise_lu_chip.bus(), + pointer_max_bits, + builder.new_bus_idx(), + ), + Sha2VmStep::new( + bitwise_lu_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + pointer_max_bits, + ), + builder.system_base().memory_controller.helper(), + ); + inventory.add_executor(sha512_chip, vec![Rv32Sha2Opcode::SHA512.global_opcode()])?; + + let sha384_chip = Sha2VmChip::::new( + Sha2VmAir::new( + builder.system_port(), + bitwise_lu_chip.bus(), + pointer_max_bits, + builder.new_bus_idx(), + ), + Sha2VmStep::new( + bitwise_lu_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + pointer_max_bits, + ), + builder.system_base().memory_controller.helper(), + ); + inventory.add_executor(sha384_chip, vec![Rv32Sha2Opcode::SHA384.global_opcode()])?; Ok(inventory) } diff --git a/extensions/sha2/circuit/src/lib.rs b/extensions/sha2/circuit/src/lib.rs new file mode 100644 index 0000000000..cc51aaaf20 --- /dev/null +++ b/extensions/sha2/circuit/src/lib.rs @@ -0,0 +1,5 @@ +mod sha2_chip; +pub use sha2_chip::*; + +mod extension; +pub use extension::*; diff --git a/extensions/sha2/circuit/src/sha2_chip/air.rs b/extensions/sha2/circuit/src/sha2_chip/air.rs new file mode 100644 index 0000000000..600d483e63 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/air.rs @@ -0,0 +1,777 @@ +use std::{cmp::min, convert::TryInto}; + +use openvm_circuit::{ + arch::{ExecutionBridge, SystemPort}, + system::memory::{ + offline_checker::{MemoryBridge, MemoryWriteAuxCols}, + MemoryAddress, + }, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir, +}; +use openvm_instructions::{ + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_sha2_air::{compose, Sha256Config, Sha2Air, Sha2Variant, Sha512Config}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use super::{Sha2ChipConfig, Sha2VmDigestColsRef, Sha2VmRoundColsRef}; + +/// Sha2VmAir does all constraints related to message padding and +/// the Sha2Air subair constrains the actual hash +#[derive(Clone, Debug)] +pub struct Sha2VmAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, + /// Bus to send byte checks to + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + /// Maximum number of bits allowed for an address pointer + /// Must be at least 24 + pub ptr_max_bits: usize, + pub(super) sha_subair: Sha2Air, + pub(super) padding_encoder: Encoder, +} + +impl Sha2VmAir { + pub fn new( + SystemPort { + execution_bus, + program_bus, + memory_bridge, + }: SystemPort, + bitwise_lookup_bus: BitwiseOperationLookupBus, + ptr_max_bits: usize, + self_bus_idx: BusIndex, + ) -> Self { + Self { + execution_bridge: ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lookup_bus, + ptr_max_bits, + sha_subair: Sha2Air::::new(bitwise_lookup_bus, self_bus_idx), + // optimization opportunity: we use fewer encoder cells for sha256 than sha512 or sha384 + padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), + } + } +} + +impl BaseAirWithPublicValues for Sha2VmAir {} +impl PartitionedBaseAir for Sha2VmAir {} +impl BaseAir for Sha2VmAir { + fn width(&self) -> usize { + C::VM_WIDTH + } +} + +impl Air for Sha2VmAir { + fn eval(&self, builder: &mut AB) { + self.eval_padding(builder); + self.eval_transitions(builder); + self.eval_reads(builder); + self.eval_last_row(builder); + + self.sha_subair.eval(builder, C::VM_CONTROL_WIDTH); + } +} + +#[allow(dead_code, non_camel_case_types)] +pub(super) enum PaddingFlags { + /// Not considered for padding - W's are not constrained + NotConsidered, + /// Not padding - W's should be equal to the message + NotPadding, + /// FIRST_PADDING_i: it is the first row with padding and there are i cells of non-padding + FirstPadding0, + FirstPadding1, + FirstPadding2, + FirstPadding3, + FirstPadding4, + FirstPadding5, + FirstPadding6, + FirstPadding7, + FirstPadding8, + FirstPadding9, + FirstPadding10, + FirstPadding11, + FirstPadding12, + FirstPadding13, + FirstPadding14, + FirstPadding15, + FirstPadding16, + FirstPadding17, + FirstPadding18, + FirstPadding19, + FirstPadding20, + FirstPadding21, + FirstPadding22, + FirstPadding23, + FirstPadding24, + FirstPadding25, + FirstPadding26, + FirstPadding27, + FirstPadding28, + FirstPadding29, + FirstPadding30, + FirstPadding31, + /// FIRST_PADDING_i_LastRow: it is the first row with padding and there are i cells of + /// non-padding AND it is the last reading row of the message + /// NOTE: if the Last row has padding it has to be at least: + /// - 9 cells since the last 8 cells are padded with the message length (for SHA-256) + /// - 17 cells since the last 16 cells are padded with the message length (for SHA-512) + FirstPadding0_LastRow, + FirstPadding1_LastRow, + FirstPadding2_LastRow, + FirstPadding3_LastRow, + FirstPadding4_LastRow, + FirstPadding5_LastRow, + FirstPadding6_LastRow, + FirstPadding7_LastRow, + FirstPadding8_LastRow, + FirstPadding9_LastRow, + FirstPadding10_LastRow, + FirstPadding11_LastRow, + FirstPadding12_LastRow, + FirstPadding13_LastRow, + FirstPadding14_LastRow, + FirstPadding15_LastRow, + + /// The entire row is padding AND it is not the first row with padding + /// AND it is the 4th row of the last block of the message + EntirePaddingLastRow, + /// The entire row is padding AND it is not the first row with padding + EntirePadding, +} + +impl PaddingFlags { + /// The number of padding flags (including NotConsidered) + pub const COUNT: usize = EntirePadding as usize + 1; +} + +use PaddingFlags::*; +impl Sha2VmAir { + /// Implement all necessary constraints for the padding + fn eval_padding(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local_cols = Sha2VmRoundColsRef::::from::(&local[..C::VM_ROUND_WIDTH]); + let next_cols = Sha2VmRoundColsRef::::from::(&next[..C::VM_ROUND_WIDTH]); + + // Constrain the sanity of the padding flags + self.padding_encoder + .eval(builder, local_cols.control.pad_flags.as_slice().unwrap()); + + builder.assert_one(self.padding_encoder.contains_flag_range::( + local_cols.control.pad_flags.as_slice().unwrap(), + NotConsidered as usize..=EntirePadding as usize, + )); + + Self::eval_padding_transitions(self, builder, &local_cols, &next_cols); + Self::eval_padding_row(self, builder, &local_cols); + } + + fn eval_padding_transitions( + &self, + builder: &mut AB, + local: &Sha2VmRoundColsRef, + next: &Sha2VmRoundColsRef, + ) { + let next_is_last_row = *next.inner.flags.is_digest_row * *next.inner.flags.is_last_block; + + // Constrain that `padding_occured` is 1 on a suffix of rows in each message, excluding the + // last digest row, and 0 everywhere else. Furthermore, the suffix starts in the + // first 4 rows of some block. + + builder.assert_bool(*local.control.padding_occurred); + // Last round row in the last block has padding_occurred = 1 + // This is the end of the suffix + builder + .when(next_is_last_row.clone()) + .assert_one(*local.control.padding_occurred); + + // Digest row in the last block has padding_occurred = 0 + builder + .when(next_is_last_row.clone()) + .assert_zero(*next.control.padding_occurred); + + // If padding_occurred = 1 in the current row, then padding_occurred = 1 in the next row, + // unless next is the last digest row + builder + .when(*local.control.padding_occurred - next_is_last_row.clone()) + .assert_one(*next.control.padding_occurred); + + // If next row is not first 4 rows of a block, then next.padding_occurred = + // local.padding_occurred. So padding_occurred only changes in the first 4 rows of a + // block. + builder + .when_transition() + .when(not(*next.inner.flags.is_first_4_rows) - next_is_last_row) + .assert_eq( + *next.control.padding_occurred, + *local.control.padding_occurred, + ); + + // Constrain the that the start of the padding is correct + let next_is_first_padding_row = + *next.control.padding_occurred - *local.control.padding_occurred; + // Row index if its between 0..4, else 0 + let next_row_idx = self.sha_subair.row_idx_encoder.flag_with_val::( + next.inner.flags.row_idx.as_slice().unwrap(), + &(0..C::MESSAGE_ROWS).map(|x| (x, x)).collect::>(), + ); + // How many non-padding cells there are in the next row. + // Will be 0 on non-padding rows. + let next_padding_offset = self.padding_encoder.flag_with_val::( + next.control.pad_flags.as_slice().unwrap(), + &(0..C::MAX_FIRST_PADDING + 1) + .map(|i| (FirstPadding0 as usize + i, i)) + .collect::>(), + ) + self.padding_encoder.flag_with_val::( + next.control.pad_flags.as_slice().unwrap(), + &(0..C::MAX_FIRST_PADDING_LAST_ROW + 1) + .map(|i| (FirstPadding0_LastRow as usize + i, i)) + .collect::>(), + ); + + // Will be 0 on last digest row since: + // - padding_occurred = 0 is constrained above + // - next_row_idx = 0 since row_idx is not in 0..4 + // - and next_padding_offset = 0 since `pad_flags = NotConsidered` + let expected_len = *next.inner.flags.local_block_idx + * *next.control.padding_occurred + * AB::Expr::from_canonical_usize(C::BLOCK_U8S) + + next_row_idx * AB::Expr::from_canonical_usize(C::READ_SIZE) + + next_padding_offset; + + // Note: `next_is_first_padding_row` is either -1,0,1 + // If 1, then this constrains the length of message + // If -1, then `next` must be the last digest row and so this constraint will be 0 == 0 + builder.when(next_is_first_padding_row).assert_eq( + expected_len, + *next.control.len * *next.control.padding_occurred, + ); + + // Constrain the padding flags are of correct type (eg is not padding or first padding) + let is_next_first_padding = self.padding_encoder.contains_flag_range::( + next.control.pad_flags.as_slice().unwrap(), + FirstPadding0 as usize..=(FirstPadding15_LastRow as usize), + ); + + let is_next_last_padding = self.padding_encoder.contains_flag_range::( + next.control.pad_flags.as_slice().unwrap(), + FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, + ); + + let is_next_entire_padding = self.padding_encoder.contains_flag_range::( + next.control.pad_flags.as_slice().unwrap(), + EntirePaddingLastRow as usize..=EntirePadding as usize, + ); + + let is_next_not_considered = self.padding_encoder.contains_flag::( + next.control.pad_flags.as_slice().unwrap(), + &[NotConsidered as usize], + ); + + let is_next_not_padding = self.padding_encoder.contains_flag::( + next.control.pad_flags.as_slice().unwrap(), + &[NotPadding as usize], + ); + + let is_next_4th_row = self + .sha_subair + .row_idx_encoder + .contains_flag::(next.inner.flags.row_idx.as_slice().unwrap(), &[3]); + + // `pad_flags` is `NotConsidered` on all rows except the first 4 rows of a block + builder.assert_eq( + not(*next.inner.flags.is_first_4_rows), + is_next_not_considered, + ); + + // `pad_flags` is `EntirePadding` if the previous row is padding + builder.when(*next.inner.flags.is_first_4_rows).assert_eq( + *local.control.padding_occurred * *next.control.padding_occurred, + is_next_entire_padding, + ); + + // `pad_flags` is `FirstPadding*` if current row is padding and the previous row is not + // padding + builder.when(*next.inner.flags.is_first_4_rows).assert_eq( + not(*local.control.padding_occurred) * *next.control.padding_occurred, + is_next_first_padding, + ); + + // `pad_flags` is `NotPadding` if current row is not padding + builder + .when(*next.inner.flags.is_first_4_rows) + .assert_eq(not(*next.control.padding_occurred), is_next_not_padding); + + // `pad_flags` is `*LastRow` on the row that contains the last four words of the message + builder + .when(*next.inner.flags.is_last_block) + .assert_eq(is_next_4th_row, is_next_last_padding); + } + + fn eval_padding_row( + &self, + builder: &mut AB, + local: &Sha2VmRoundColsRef, + ) { + let message = (0..C::READ_SIZE) + .map(|i| { + local.inner.message_schedule.carry_or_buffer[[i / (C::WORD_U8S), i % (C::WORD_U8S)]] + }) + .collect::>(); + + let get_ith_byte = |i: usize| { + let word_idx = i / C::WORD_U8S; + let word = local + .inner + .message_schedule + .w + .row(word_idx) + .mapv(|x| x.into()); + // Need to reverse the byte order to match the endianness of the memory + let byte_idx = C::WORD_U8S - i % C::WORD_U8S - 1; + compose::( + &word.as_slice().unwrap()[byte_idx * 8..(byte_idx + 1) * 8], + 1, + ) + }; + + let is_not_padding = self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[NotPadding as usize], + ); + + // Check the `w`s on case by case basis + for (i, message_byte) in message.iter().enumerate() { + let w = get_ith_byte(i); + let should_be_message = is_not_padding.clone() + + if i < C::MAX_FIRST_PADDING { + self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0 as usize + i + 1 + ..=FirstPadding0 as usize + C::MAX_FIRST_PADDING, + ) + } else { + AB::Expr::ZERO + } + + if i < C::MAX_FIRST_PADDING_LAST_ROW { + self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0_LastRow as usize + i + 1 + ..=FirstPadding0_LastRow as usize + C::MAX_FIRST_PADDING_LAST_ROW, + ) + } else { + AB::Expr::ZERO + }; + + builder + .when(should_be_message) + .assert_eq(w.clone(), *message_byte); + + let should_be_zero = self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[EntirePadding as usize], + ) + + // - 4 because the last 4 bytes are the padded length + if i < C::CELLS_PER_ROW - 4 { + self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[EntirePaddingLastRow as usize], + ) + if i > 0 { + self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0_LastRow as usize + ..=min( + FirstPadding0_LastRow as usize + i - 1, + FirstPadding0_LastRow as usize + C::MAX_FIRST_PADDING_LAST_ROW, + ), + ) + } else { + AB::Expr::ZERO + } + } else { + AB::Expr::ZERO + } + if i > 0 { + self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0 as usize..=FirstPadding0 as usize + i - 1, + ) + } else { + AB::Expr::ZERO + }; + builder.when(should_be_zero).assert_zero(w.clone()); + + // Assumes bit-length of message is a multiple of 8 (message is bytes) + // This is true because the message is given as &[u8] + let should_be_128 = self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[FirstPadding0 as usize + i], + ) + if i < 8 { + self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[FirstPadding0_LastRow as usize + i], + ) + } else { + AB::Expr::ZERO + }; + + builder + .when(should_be_128) + .assert_eq(AB::Expr::from_canonical_u32(1 << 7), w); + + // should be len is handled outside of the loop + } + let appended_len = compose::( + &[ + get_ith_byte(C::CELLS_PER_ROW - 1), + get_ith_byte(C::CELLS_PER_ROW - 2), + get_ith_byte(C::CELLS_PER_ROW - 3), + get_ith_byte(C::CELLS_PER_ROW - 4), + ], + RV32_CELL_BITS, + ); + + let actual_len = *local.control.len; + + let is_last_padding_row = self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, + ); + + builder.when(is_last_padding_row.clone()).assert_eq( + appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(), // bit to byte conversion + actual_len, + ); + + // We constrain that the appended length is in bytes + builder.when(is_last_padding_row.clone()).assert_zero( + local.inner.message_schedule.w[[3, 0]] + + local.inner.message_schedule.w[[3, 1]] + + local.inner.message_schedule.w[[3, 2]], + ); + + // We can't support messages longer than 2^29 bytes because the length has to fit in a + // field element. So, constrain that the first few bytes of the length are 0 (so only the + // last 4 bytes of the length can be nonzero). Thus, the bit-length is < 2^32 so the message + // is < 2^29 bytes. + // For SHA-256, assert bytes 8..12 are 0, because the message length is 8 bytes, and each + // row has 16 bytes. + // For SHA-512 and SHA-384, assert bytes 16..28 are 0, because the + // message length is 16 bytes and each row has 32 bytes. + for i in C::CELLS_PER_ROW - C::MESSAGE_LENGTH_BITS / 8..C::CELLS_PER_ROW - 4 { + builder + .when(is_last_padding_row.clone()) + .assert_zero(get_ith_byte(i)); + } + } + /// Implement constraints on `len`, `read_ptr` and `cur_timestamp` + fn eval_transitions(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local_cols = Sha2VmRoundColsRef::::from::(&local[..C::VM_ROUND_WIDTH]); + let next_cols = Sha2VmRoundColsRef::::from::(&next[..C::VM_ROUND_WIDTH]); + + let is_last_row = + *local_cols.inner.flags.is_last_block * *local_cols.inner.flags.is_digest_row; + // Len should be the same for the entire message + builder + .when_transition() + .when(not::(is_last_row.clone())) + .assert_eq(*next_cols.control.len, *local_cols.control.len); + + // Read ptr should increment by [C::READ_SIZE] for the first 4 rows and stay the same + // otherwise + let read_ptr_delta = + *local_cols.inner.flags.is_first_4_rows * AB::Expr::from_canonical_usize(C::READ_SIZE); + builder + .when_transition() + .when(not::(is_last_row.clone())) + .assert_eq( + *next_cols.control.read_ptr, + *local_cols.control.read_ptr + read_ptr_delta, + ); + + // Timestamp should increment by 1 for the first 4 rows and stay the same otherwise + let timestamp_delta = *local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE; + builder + .when_transition() + .when(not::(is_last_row.clone())) + .assert_eq( + *next_cols.control.cur_timestamp, + *local_cols.control.cur_timestamp + timestamp_delta, + ); + } + + /// Implement the reads for the first 4 rows of a block + fn eval_reads(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local_cols = Sha2VmRoundColsRef::::from::(&local[..C::VM_ROUND_WIDTH]); + + let message: Vec = (0..C::READ_SIZE) + .map(|i| { + local_cols.inner.message_schedule.carry_or_buffer + [[i / (C::WORD_U16S * 2), i % (C::WORD_U16S * 2)]] + }) + .collect(); + + match C::VARIANT { + Sha2Variant::Sha256 => { + let message: [AB::Var; Sha256Config::READ_SIZE] = + message.try_into().unwrap_or_else(|_| { + panic!("message is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + *local_cols.control.read_ptr, + ), + message, + *local_cols.control.cur_timestamp, + local_cols.read_aux, + ) + .eval(builder, *local_cols.inner.flags.is_first_4_rows); + } + // Sha512 and Sha384 have the same read size so we put them together + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let message: [AB::Var; Sha512Config::READ_SIZE] = + message.try_into().unwrap_or_else(|_| { + panic!("message is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + *local_cols.control.read_ptr, + ), + message, + *local_cols.control.cur_timestamp, + local_cols.read_aux, + ) + .eval(builder, *local_cols.inner.flags.is_first_4_rows); + } + } + } + /// Implement the constraints for the last row of a message + fn eval_last_row(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local_cols = Sha2VmDigestColsRef::::from::(&local[..C::VM_DIGEST_WIDTH]); + + let timestamp: AB::Var = local_cols.from_state.timestamp; + let mut timestamp_delta: usize = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) + }; + + let is_last_row = + *local_cols.inner.flags.is_last_block * *local_cols.inner.flags.is_digest_row; + + let dst_ptr: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.dst_ptr.to_vec().try_into().unwrap_or_else(|_| { + panic!("dst_ptr is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rd_ptr, + ), + dst_ptr, + timestamp_pp(), + &local_cols.register_reads_aux[0], + ) + .eval(builder, is_last_row.clone()); + + let src_ptr: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.src_ptr.to_vec().try_into().unwrap_or_else(|_| { + panic!("src_ptr is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rs1_ptr, + ), + src_ptr, + timestamp_pp(), + &local_cols.register_reads_aux[1], + ) + .eval(builder, is_last_row.clone()); + + let len_data: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.len_data.to_vec().try_into().unwrap_or_else(|_| { + panic!("len_data is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rs2_ptr, + ), + len_data, + timestamp_pp(), + &local_cols.register_reads_aux[2], + ) + .eval(builder, is_last_row.clone()); + // range check that the memory pointers don't overflow + // Note: no need to range check the length since we read from memory step by step and + // the memory bus will catch any memory accesses beyond ptr_max_bits + let shift = AB::Expr::from_canonical_usize( + 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), + ); + // This only works if self.ptr_max_bits >= 24 which is typically the case + self.bitwise_lookup_bus + .send_range( + // It is fine to shift like this since we already know that dst_ptr and src_ptr + // have [RV32_CELL_BITS] bits + local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), + local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), + ) + .eval(builder, is_last_row.clone()); + + // the number of reads that happened to read the entire message: we do 4 reads per block + let time_delta = (*local_cols.inner.flags.local_block_idx + AB::Expr::ONE) + * AB::Expr::from_canonical_usize(4); + // Every time we read the message we increment the read pointer by C::READ_SIZE + let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(C::READ_SIZE); + + let result: Vec = (0..C::HASH_SIZE) + .map(|i| { + // The limbs are written in big endian order to the memory so need to be reversed + local_cols.inner.final_hash[[i / C::WORD_U8S, C::WORD_U8S - i % C::WORD_U8S - 1]] + }) + .collect(); + + let dst_ptr_val = compose::( + local_cols.dst_ptr.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + + match C::VARIANT { + Sha2Variant::Sha256 => { + debug_assert_eq!(C::NUM_WRITES, 1); + debug_assert_eq!(local_cols.writes_aux_base.len(), 1); + debug_assert_eq!(local_cols.writes_aux_prev_data.nrows(), 1); + let prev_data: [AB::Var; Sha256Config::HASH_SIZE] = local_cols + .writes_aux_prev_data + .row(0) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("writes_aux_prev_data is not the correct size"); + }); + // Note: revisit in the future to do 2 block writes of 16 cells instead of 1 block + // write of 32 cells. This could be beneficial as the output is often an input for + // another hash + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val, + ), + result.try_into().unwrap_or_else(|_| { + panic!("result is not the correct size"); + }), + timestamp_pp() + time_delta.clone(), + &MemoryWriteAuxCols::from_base(local_cols.writes_aux_base[0], prev_data), + ) + .eval(builder, is_last_row.clone()); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + debug_assert_eq!(C::NUM_WRITES, 2); + debug_assert_eq!(local_cols.writes_aux_base.len(), 2); + debug_assert_eq!(local_cols.writes_aux_prev_data.nrows(), 2); + + // For Sha384, set the last 16 cells to 0 + let mut truncated_result: Vec = + result.iter().map(|x| (*x).into()).collect(); + for x in truncated_result.iter_mut().skip(C::DIGEST_SIZE) { + *x = AB::Expr::ZERO; + } + + // write the digest in two halves because we only support writes up to 32 bytes + for i in 0..Sha512Config::NUM_WRITES { + let prev_data: [AB::Var; Sha512Config::WRITE_SIZE] = local_cols + .writes_aux_prev_data + .row(i) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("writes_aux_prev_data is not the correct size"); + }); + + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val.clone() + + AB::Expr::from_canonical_usize(i * Sha512Config::WRITE_SIZE), + ), + truncated_result + [i * Sha512Config::WRITE_SIZE..(i + 1) * Sha512Config::WRITE_SIZE] + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("result is not the correct size"); + }), + timestamp_pp() + time_delta.clone(), + &MemoryWriteAuxCols::from_base( + local_cols.writes_aux_base[i], + prev_data, + ), + ) + .eval(builder, is_last_row.clone()); + } + } + } + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(C::OPCODE.global_opcode().as_usize()), + [ + >::into(*local_cols.rd_ptr), + >::into(*local_cols.rs1_ptr), + >::into(*local_cols.rs2_ptr), + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + ], + *local_cols.from_state, + AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(), + ) + .eval(builder, is_last_row.clone()); + + // Assert that we read the correct length of the message + let len_val = compose::( + local_cols.len_data.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + builder + .when(is_last_row.clone()) + .assert_eq(*local_cols.control.len, len_val); + // Assert that we started reading from the correct pointer initially + let src_val = compose::( + local_cols.src_ptr.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + builder + .when(is_last_row.clone()) + .assert_eq(*local_cols.control.read_ptr, src_val + read_ptr_delta); + // Assert that we started reading from the correct timestamp + builder.when(is_last_row.clone()).assert_eq( + *local_cols.control.cur_timestamp, + local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta, + ); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chip/columns.rs b/extensions/sha2/circuit/src/sha2_chip/columns.rs new file mode 100644 index 0000000000..20a2080860 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/columns.rs @@ -0,0 +1,106 @@ +//! WARNING: the order of fields in the structs is important, do not change it + +use openvm_circuit::{ + arch::ExecutionState, + system::memory::offline_checker::{MemoryBaseAuxCols, MemoryReadAuxCols}, +}; +use openvm_circuit_primitives_derive::ColsRef; +use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; +use openvm_sha2_air::{ + ShaDigestCols, ShaDigestColsRef, ShaDigestColsRefMut, ShaRoundCols, ShaRoundColsRef, + ShaRoundColsRefMut, +}; + +use super::SHA_REGISTER_READS; +use crate::Sha2ChipConfig; + +/// the first C::ROUND_ROWS rows of every SHA block will be of type ShaVmRoundCols and the last row +/// will be of type ShaVmDigestCols +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2VmRoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub control: Sha2VmControlCols, + pub inner: ShaRoundCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, + #[aligned_borrow] + pub read_aux: MemoryReadAuxCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2VmDigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, + const NUM_WRITES: usize, + const WRITE_SIZE: usize, +> { + pub control: Sha2VmControlCols, + pub inner: ShaDigestCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + HASH_WORDS, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, + #[aligned_borrow] + pub from_state: ExecutionState, + /// It is counter intuitive, but we will constrain the register reads on the very last row of + /// every message + pub rd_ptr: T, + pub rs1_ptr: T, + pub rs2_ptr: T, + pub dst_ptr: [T; RV32_REGISTER_NUM_LIMBS], + pub src_ptr: [T; RV32_REGISTER_NUM_LIMBS], + pub len_data: [T; RV32_REGISTER_NUM_LIMBS], + #[aligned_borrow] + pub register_reads_aux: [MemoryReadAuxCols; SHA_REGISTER_READS], + // We store the fields of MemoryWriteAuxCols here because the length of prev_data depends on + // the sha variant + #[aligned_borrow] + pub writes_aux_base: [MemoryBaseAuxCols; NUM_WRITES], + pub writes_aux_prev_data: [[T; WRITE_SIZE]; NUM_WRITES], +} + +/// These are the columns that are used on both round and digest rows +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2VmControlCols { + /// Note: We will use the buffer in `inner.message_schedule` as the message data + /// This is the length of the entire message in bytes + pub len: T, + /// Need to keep timestamp and read_ptr since block reads don't have the necessary information + pub cur_timestamp: T, + pub read_ptr: T, + /// Padding flags which will be used to encode the the number of non-padding cells in the + /// current row + pub pad_flags: [T; 9], + /// A boolean flag that indicates whether a padding already occurred + pub padding_occurred: T, +} diff --git a/extensions/sha2/circuit/src/sha2_chip/config.rs b/extensions/sha2/circuit/src/sha2_chip/config.rs new file mode 100644 index 0000000000..7dfed9610a --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/config.rs @@ -0,0 +1,99 @@ +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_sha2_air::{Sha256Config, Sha2Config, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; + +use super::{Sha2VmControlColsRef, Sha2VmDigestColsRef, Sha2VmRoundColsRef}; + +pub trait Sha2ChipConfig: Sha2Config { + // Name of the opcode + const OPCODE_NAME: &'static str; + /// Width of the ShaVmControlCols + const VM_CONTROL_WIDTH: usize = Sha2VmControlColsRef::::width::(); + /// Width of the ShaVmRoundCols + const VM_ROUND_WIDTH: usize = Sha2VmRoundColsRef::::width::(); + /// Width of the ShaVmDigestCols + const VM_DIGEST_WIDTH: usize = Sha2VmDigestColsRef::::width::(); + /// Width of the ShaVmCols + const VM_WIDTH: usize = if Self::VM_ROUND_WIDTH > Self::VM_DIGEST_WIDTH { + Self::VM_ROUND_WIDTH + } else { + Self::VM_DIGEST_WIDTH + }; + /// Number of bits to use when padding the message length. Given by the SHA-2 spec. + const MESSAGE_LENGTH_BITS: usize; + /// Maximum i such that `FirstPadding_i` is a valid padding flag + const MAX_FIRST_PADDING: usize = Self::CELLS_PER_ROW - 1; + /// Maximum i such that `FirstPadding_i_LastRow` is a valid padding flag + const MAX_FIRST_PADDING_LAST_ROW: usize = + Self::CELLS_PER_ROW - Self::MESSAGE_LENGTH_BITS / 8 - 1; + /// OpenVM Opcode for the instruction + const OPCODE: Rv32Sha2Opcode; + + // ==== Constants for register/memory adapter ==== + /// Number of rv32 cells read in a block + const BLOCK_CELLS: usize = Self::BLOCK_BITS / RV32_CELL_BITS; + /// Number of rows we will do a read on for each block + const NUM_READ_ROWS: usize = Self::MESSAGE_ROWS; + + /// Number of cells to read in a single memory access + const READ_SIZE: usize = Self::WORD_U8S * Self::ROUNDS_PER_ROW; + /// Number of cells in the digest before truncation (Sha384 truncates the digest) + const HASH_SIZE: usize = Self::WORD_U8S * Self::HASH_WORDS; + /// Number of cells in the digest after truncation + const DIGEST_SIZE: usize; + + /// Number of parts to write the hash in + const NUM_WRITES: usize = Self::HASH_SIZE / Self::WRITE_SIZE; + /// Size of each write. Must divide Self::HASH_SIZE + const WRITE_SIZE: usize; +} + +/// Register reads to get dst, src, len +pub const SHA_REGISTER_READS: usize = 3; + +impl Sha2ChipConfig for Sha256Config { + const OPCODE_NAME: &'static str = "SHA256"; + const MESSAGE_LENGTH_BITS: usize = 64; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA256; + // no truncation + const DIGEST_SIZE: usize = Self::HASH_SIZE; +} + +impl Sha2ChipConfig for Sha512Config { + const OPCODE_NAME: &'static str = "SHA512"; + const MESSAGE_LENGTH_BITS: usize = 128; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA512; + // no truncation + const DIGEST_SIZE: usize = Self::HASH_SIZE; +} + +impl Sha2ChipConfig for Sha384Config { + const OPCODE_NAME: &'static str = "SHA384"; + const MESSAGE_LENGTH_BITS: usize = 128; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA384; + // Sha384 truncates the output to 48 cells + const DIGEST_SIZE: usize = 48; +} + +// We use the same write size for all variants to simplify tracegen record storage. +// In particular, each memory write aux record will have the same size, which is useful for +// defining Sha2VmRecordHeader in a repr(C) way. +pub const SHA_WRITE_SIZE: usize = 32; + +pub const MAX_SHA_NUM_WRITES: usize = if Sha256Config::NUM_WRITES > Sha512Config::NUM_WRITES { + if Sha256Config::NUM_WRITES > Sha384Config::NUM_WRITES { + Sha256Config::NUM_WRITES + } else { + Sha384Config::NUM_WRITES + } +} else if Sha512Config::NUM_WRITES > Sha384Config::NUM_WRITES { + Sha512Config::NUM_WRITES +} else { + Sha384Config::NUM_WRITES +}; + +/// Maximum message length that this chip supports in bytes +pub const SHA_MAX_MESSAGE_LEN: usize = 1 << 29; diff --git a/extensions/sha2/circuit/src/sha2_chip/mod.rs b/extensions/sha2/circuit/src/sha2_chip/mod.rs new file mode 100644 index 0000000000..7525ec8435 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/mod.rs @@ -0,0 +1,284 @@ +//! Sha256 hasher. Handles full sha256 hashing with padding. +//! variable length inputs read from VM memory. +use std::{ + borrow::{Borrow, BorrowMut}, + iter, +}; + +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, Result, StepExecutorE1, StepExecutorE2, VmSegmentState, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, +}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_sha2_air::{Sha256Config, Sha2StepHelper, Sha2Variant, Sha384Config, Sha512Config}; +use openvm_stark_backend::p3_field::PrimeField32; +use sha2::{Digest, Sha256, Sha384, Sha512}; + +mod air; +mod columns; +mod config; +mod trace; +mod utils; + +pub use air::*; +pub use columns::*; +pub use config::*; +pub use utils::get_sha2_num_blocks; + +#[cfg(test)] +mod tests; + +pub type Sha2VmChip = NewVmChipWrapper, Sha2VmStep, MatrixRecordArena>; + +pub struct Sha2VmStep { + pub inner: Sha2StepHelper, + pub padding_encoder: Encoder, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, + pub pointer_max_bits: usize, +} + +impl Sha2VmStep { + pub fn new( + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + offset: usize, + pointer_max_bits: usize, + ) -> Self { + Self { + inner: Sha2StepHelper::::new(), + padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), + bitwise_lookup_chip, + offset, + pointer_max_bits, + } + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct Sha2PreCompute { + a: u8, + b: u8, + c: u8, +} + +impl StepExecutorE1 for Sha2VmStep { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut Sha2PreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _, C>) + } +} +impl StepExecutorE2 for Sha2VmStep { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _, C>) + } +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + C: Sha2ChipConfig, + const IS_E1: bool, +>( + pre_compute: &Sha2PreCompute, + vm_state: &mut VmSegmentState, +) -> u32 { + let dst = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32); + let src = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let len = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let dst_u32 = u32::from_le_bytes(dst); + let src_u32 = u32::from_le_bytes(src); + let len_u32 = u32::from_le_bytes(len); + + let (output, height) = if IS_E1 { + // SAFETY: RV32_MEMORY_AS is memory address space of type u8 + let message = vm_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize); + let output = sha2_solve::(message); + (output, 0) + } else { + let num_blocks = get_sha2_num_blocks::(len_u32); + let mut message = Vec::with_capacity(len_u32 as usize); + for block_idx in 0..num_blocks as usize { + // Reads happen on the first 4 rows of each block + for row in 0..C::NUM_READ_ROWS { + let read_idx = block_idx * C::NUM_READ_ROWS + row; + match C::VARIANT { + Sha2Variant::Sha256 => { + let row_input: [u8; Sha256Config::READ_SIZE] = vm_state + .vm_read(RV32_MEMORY_AS, src_u32 + (read_idx * C::READ_SIZE) as u32); + message.extend_from_slice(&row_input); + } + Sha2Variant::Sha512 => { + let row_input: [u8; Sha512Config::READ_SIZE] = vm_state + .vm_read(RV32_MEMORY_AS, src_u32 + (read_idx * C::READ_SIZE) as u32); + message.extend_from_slice(&row_input); + } + Sha2Variant::Sha384 => { + let row_input: [u8; Sha384Config::READ_SIZE] = vm_state + .vm_read(RV32_MEMORY_AS, src_u32 + (read_idx * C::READ_SIZE) as u32); + message.extend_from_slice(&row_input); + } + } + } + } + let output = sha2_solve::(&message[..len_u32 as usize]); + let height = num_blocks * C::ROWS_PER_BLOCK as u32; + (output, height) + }; + match C::VARIANT { + Sha2Variant::Sha256 => { + let output: [u8; Sha256Config::WRITE_SIZE] = output.try_into().unwrap(); + vm_state.vm_write(RV32_MEMORY_AS, dst_u32, &output); + } + Sha2Variant::Sha512 => { + for i in 0..C::NUM_WRITES { + let output: [u8; Sha512Config::WRITE_SIZE] = output + [i * Sha512Config::WRITE_SIZE..(i + 1) * Sha512Config::WRITE_SIZE] + .try_into() + .unwrap(); + vm_state.vm_write( + RV32_MEMORY_AS, + dst_u32 + (i * Sha512Config::WRITE_SIZE) as u32, + &output, + ); + } + } + Sha2Variant::Sha384 => { + // Pad the output with zeros to 64 bytes + let output = output + .into_iter() + .chain(iter::repeat(0).take(16)) + .collect::>(); + for i in 0..C::NUM_WRITES { + let output: [u8; Sha384Config::WRITE_SIZE] = output + [i * Sha384Config::WRITE_SIZE..(i + 1) * Sha384Config::WRITE_SIZE] + .try_into() + .unwrap(); + vm_state.vm_write( + RV32_MEMORY_AS, + dst_u32 + (i * Sha384Config::WRITE_SIZE) as u32, + &output, + ); + } + } + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + height +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &Sha2PreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +impl Sha2VmStep { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut Sha2PreCompute, + ) -> Result<()> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + *data = Sha2PreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + assert_eq!(&C::OPCODE.global_opcode(), opcode); + Ok(()) + } +} + +pub fn sha2_solve(input_message: &[u8]) -> Vec { + match C::VARIANT { + Sha2Variant::Sha256 => { + let mut hasher = Sha256::new(); + hasher.update(input_message); + let mut output = vec![0u8; C::DIGEST_SIZE]; + output.copy_from_slice(hasher.finalize().as_ref()); + output + } + Sha2Variant::Sha512 => { + let mut hasher = Sha512::new(); + hasher.update(input_message); + let mut output = vec![0u8; C::DIGEST_SIZE]; + output.copy_from_slice(hasher.finalize().as_ref()); + output + } + Sha2Variant::Sha384 => { + let mut hasher = Sha384::new(); + hasher.update(input_message); + let mut output = vec![0u8; C::DIGEST_SIZE]; + output.copy_from_slice(hasher.finalize().as_ref()); + output + } + } +} diff --git a/extensions/sha256/circuit/src/sha256_chip/tests.rs b/extensions/sha2/circuit/src/sha2_chip/tests.rs similarity index 51% rename from extensions/sha256/circuit/src/sha256_chip/tests.rs rename to extensions/sha2/circuit/src/sha2_chip/tests.rs index 9ddf6e6298..9cba30b34b 100644 --- a/extensions/sha256/circuit/src/sha256_chip/tests.rs +++ b/extensions/sha2/circuit/src/sha2_chip/tests.rs @@ -11,38 +11,40 @@ use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_sha256_transpiler::Rv32Sha256Opcode::{self, *}; +use openvm_sha2_air::{Sha256Config, Sha2Variant, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{Sha256VmAir, Sha256VmChip, Sha256VmStep}; +use super::{Sha2ChipConfig, Sha2VmAir, Sha2VmChip, Sha2VmStep}; use crate::{ - sha256_chip::trace::Sha256VmRecordLayout, sha256_solve, Sha256VmDigestCols, Sha256VmRoundCols, + sha2_chip::trace::Sha2VmRecordLayout, sha2_solve, Sha2VmDigestColsRef, Sha2VmRoundColsRef, }; type F = BabyBear; const SELF_BUS_IDX: BusIndex = 28; -const MAX_INS_CAPACITY: usize = 4096; +const MAX_INS_CAPACITY: usize = 8192; +type Sha2VmChipDense = NewVmChipWrapper, Sha2VmStep, DenseRecordArena>; -fn create_test_chips( +fn create_test_chips( tester: &mut VmChipTestBuilder, ) -> ( - Sha256VmChip, + Sha2VmChip, SharedBitwiseOperationLookupChip, ) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Sha256VmChip::new( - Sha256VmAir::new( + let mut chip = Sha2VmChip::::new( + Sha2VmAir::new( tester.system_port(), bitwise_bus, tester.address_bits(), SELF_BUS_IDX, ), - Sha256VmStep::new( + Sha2VmStep::new( bitwise_chip.clone(), - Rv32Sha256Opcode::CLASS_OFFSET, + Rv32Sha2Opcode::CLASS_OFFSET, tester.address_bits(), ), tester.memory_helper(), @@ -52,11 +54,11 @@ fn create_test_chips( (chip, bitwise_chip) } -fn set_and_execute>( +fn set_and_execute, C: Sha2ChipConfig>( tester: &mut VmChipTestBuilder, chip: &mut E, rng: &mut StdRng, - opcode: Rv32Sha256Opcode, + opcode: Rv32Sha2Opcode, message: Option<&[u8]>, len: Option, ) { @@ -70,7 +72,7 @@ fn set_and_execute>( let rs2 = gen_pointer(rng, 4); let max_mem_ptr: u32 = 1 << tester.address_bits(); - let dst_ptr = rng.gen_range(0..max_mem_ptr); + let dst_ptr = rng.gen_range(0..max_mem_ptr - C::DIGEST_SIZE as u32); let dst_ptr = dst_ptr ^ (dst_ptr & 3); tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); let src_ptr = rng.gen_range(0..(max_mem_ptr - len as u32)); @@ -92,11 +94,35 @@ fn set_and_execute>( &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), ); - let output = sha256_solve(message); - assert_eq!( - output.map(F::from_canonical_u8), - tester.read::<32>(2, dst_ptr as usize) - ); + let output = sha2_solve::(message); + match C::VARIANT { + Sha2Variant::Sha256 => { + assert_eq!( + output + .into_iter() + .map(F::from_canonical_u8) + .collect::>(), + tester.read::<{ Sha256Config::DIGEST_SIZE }>(2, dst_ptr as usize) + ); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let mut output = output; + output.extend(std::iter::repeat(0u8).take(C::HASH_SIZE)); + let output = output + .into_iter() + .map(F::from_canonical_u8) + .collect::>(); + for i in 0..C::NUM_WRITES { + assert_eq!( + output[i * C::WRITE_SIZE..(i + 1) * C::WRITE_SIZE], + tester.read::<{ Sha512Config::WRITE_SIZE }>( + 2, + dst_ptr as usize + i * C::WRITE_SIZE + ) + ); + } + } + } } /////////////////////////////////////////////////////////////////////////////////////// @@ -105,51 +131,80 @@ fn set_and_execute>( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_sha256_test() { +fn rand_sha_test() { setup_tracing(); let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let (mut chip, bitwise_chip) = create_test_chips(&mut tester); + let (mut chip, bitwise_chip) = create_test_chips::(&mut tester); let num_ops: usize = 10; for _ in 0..num_ops { - set_and_execute(&mut tester, &mut chip, &mut rng, SHA256, None, None); + set_and_execute::<_, C>(&mut tester, &mut chip, &mut rng, C::OPCODE, None, None); } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } +#[test] +fn rand_sha256_test() { + rand_sha_test::(); +} + +#[test] +fn rand_sha512_test() { + rand_sha_test::(); +} + +#[test] +fn rand_sha384_test() { + rand_sha_test::(); +} + /////////////////////////////////////////////////////////////////////////////////////// /// SANITY TESTS /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { +fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let (mut chip, _) = create_test_chips(&mut tester); + let (mut chip, _) = create_test_chips::(&mut tester); println!( - "Sha256VmDigestCols::width(): {}", - Sha256VmDigestCols::::width() + "Sha2VmDigestColsRef::::width::(): {}", + Sha2VmDigestColsRef::::width::() ); println!( - "Sha256VmRoundCols::width(): {}", - Sha256VmRoundCols::::width() + "Sha2VmRoundColsRef::::width::(): {}", + Sha2VmRoundColsRef::::width::() ); + let num_tests: usize = 1; for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, SHA256, None, None); + set_and_execute::<_, C>(&mut tester, &mut chip, &mut rng, C::OPCODE, None, None); } } +#[test] +fn sha256_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha512_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha384_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + #[test] fn sha256_solve_sanity_check() { let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; - let output = sha256_solve(input); + let output = sha2_solve::(input); let expected: [u8; 32] = [ 99, 196, 61, 185, 226, 212, 131, 80, 154, 248, 97, 108, 157, 55, 200, 226, 160, 73, 207, 46, 245, 169, 94, 255, 42, 136, 193, 15, 40, 133, 173, 22, @@ -157,6 +212,32 @@ fn sha256_solve_sanity_check() { assert_eq!(output, expected); } +#[test] +fn sha512_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = sha2_solve::(input); + // verified manually against the sha512 command line tool + let expected: [u8; 64] = [ + 0, 8, 195, 142, 70, 71, 97, 208, 132, 132, 243, 53, 179, 186, 8, 162, 71, 75, 126, 21, 130, + 203, 245, 126, 207, 65, 119, 60, 64, 79, 200, 2, 194, 17, 189, 137, 164, 213, 107, 197, + 152, 11, 242, 165, 146, 80, 96, 105, 249, 27, 139, 14, 244, 21, 118, 31, 94, 87, 32, 145, + 149, 98, 235, 75, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha384_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = sha2_solve::(input); + let expected: [u8; 48] = [ + 134, 227, 167, 229, 35, 110, 115, 174, 10, 27, 197, 116, 56, 144, 150, 36, 152, 190, 212, + 120, 26, 243, 125, 4, 2, 60, 164, 195, 218, 219, 255, 143, 240, 75, 158, 126, 102, 105, 8, + 202, 142, 240, 230, 161, 162, 152, 111, 71, + ]; + assert_eq!(output, expected); +} + /////////////////////////////////////////////////////////////////////////////////////// /// DENSE TESTS /// @@ -165,22 +246,22 @@ fn sha256_solve_sanity_check() { /// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that /// all the constraints pass. /////////////////////////////////////////////////////////////////////////////////////// -type Sha256VmChipDense = NewVmChipWrapper; - -fn create_test_chip_dense(tester: &mut VmChipTestBuilder) -> Sha256VmChipDense { +fn create_test_chip_dense( + tester: &mut VmChipTestBuilder, +) -> Sha2VmChipDense { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Sha256VmChipDense::new( - Sha256VmAir::new( + let mut chip = Sha2VmChipDense::::new( + Sha2VmAir::::new( tester.system_port(), bitwise_chip.bus(), tester.address_bits(), SELF_BUS_IDX, ), - Sha256VmStep::new( + Sha2VmStep::::new( bitwise_chip.clone(), - Rv32Sha256Opcode::CLASS_OFFSET, + Rv32Sha2Opcode::CLASS_OFFSET, tester.address_bits(), ), tester.memory_helper(), @@ -190,23 +271,29 @@ fn create_test_chip_dense(tester: &mut VmChipTestBuilder) -> Sha256VmChipDens chip } -#[test] -fn dense_record_arena_test() { +fn dense_record_arena_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let (mut sparse_chip, bitwise_chip) = create_test_chips(&mut tester); + let (mut sparse_chip, bitwise_chip) = create_test_chips::(&mut tester); { - let mut dense_chip = create_test_chip_dense(&mut tester); + let mut dense_chip = create_test_chip_dense::(&mut tester); let num_ops: usize = 10; for _ in 0..num_ops { - set_and_execute(&mut tester, &mut dense_chip, &mut rng, SHA256, None, None); + set_and_execute::<_, C>( + &mut tester, + &mut dense_chip, + &mut rng, + C::OPCODE, + None, + None, + ); } let mut record_interpreter = dense_chip .arena - .get_record_seeker::<_, Sha256VmRecordLayout>(); + .get_record_seeker::<_, Sha2VmRecordLayout>(); record_interpreter.transfer_to_matrix_arena(&mut sparse_chip.arena); } @@ -217,3 +304,18 @@ fn dense_record_arena_test() { .finalize(); tester.simple_test().expect("Verification failed"); } + +#[test] +fn sha256_dense_record_arena_test() { + dense_record_arena_test::(); +} + +#[test] +fn sha512_dense_record_arena_test() { + dense_record_arena_test::(); +} + +#[test] +fn sha384_dense_record_arena_test() { + dense_record_arena_test::(); +} diff --git a/extensions/sha2/circuit/src/sha2_chip/trace.rs b/extensions/sha2/circuit/src/sha2_chip/trace.rs new file mode 100644 index 0000000000..b8a59413e5 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/trace.rs @@ -0,0 +1,719 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + cmp::min, + iter, + marker::PhantomData, +}; + +use openvm_circuit::{ + arch::{ + get_record_from_slice, CustomBorrow, MultiRowLayout, MultiRowMetadata, RecordArena, Result, + SizedRecord, TraceFiller, TraceStep, VmStateMut, + }, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; +use openvm_sha2_air::{ + be_limbs_into_word, get_flag_pt_array, Sha256Config, Sha2StepHelper, Sha384Config, Sha512Config, +}; +use openvm_stark_backend::{ + p3_field::PrimeField32, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::*, +}; + +use super::{ + Sha2ChipConfig, Sha2Variant, Sha2VmDigestColsRefMut, Sha2VmRoundColsRefMut, Sha2VmStep, +}; +use crate::{ + get_sha2_num_blocks, sha2_chip::PaddingFlags, sha2_solve, Sha2VmControlColsRefMut, + MAX_SHA_NUM_WRITES, SHA_MAX_MESSAGE_LEN, SHA_REGISTER_READS, SHA_WRITE_SIZE, +}; + +#[derive(Clone, Copy)] +pub struct Sha2VmMetadata { + pub num_blocks: u32, + _phantom: PhantomData, +} + +impl MultiRowMetadata for Sha2VmMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_blocks as usize * C::ROWS_PER_BLOCK + } +} + +pub(crate) type Sha2VmRecordLayout = MultiRowLayout>; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Sha2VmRecordHeader { + pub from_pc: u32, + pub timestamp: u32, + pub rd_ptr: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub dst_ptr: u32, + pub src_ptr: u32, + pub len: u32, + + pub register_reads_aux: [MemoryReadAuxRecord; SHA_REGISTER_READS], + // Note: MAX_SHA_NUM_WRITES = 2 because SHA-256 uses 1 write, while SHA-512 and SHA-384 use 2 + // writes. We just use the same array for all variants to simplify record storage. + pub writes_aux: [MemoryWriteBytesAuxRecord; MAX_SHA_NUM_WRITES], +} + +pub struct Sha2VmRecordMut<'a> { + pub inner: &'a mut Sha2VmRecordHeader, + // Having a continuous slice of the input is useful for fast hashing in `execute` + pub input: &'a mut [u8], + pub read_aux: &'a mut [MemoryReadAuxRecord], +} + +/// Custom borrowing that splits the buffer into a fixed `Sha2VmRecord` header +/// followed by a slice of `u8`'s of length `C::BLOCK_CELLS * num_blocks` where `num_blocks` is +/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length +/// `C::NUM_READ_ROWS * num_blocks`. Uses `align_to_mut()` to make sure the slice is properly +/// aligned to `MemoryReadAuxRecord`. Has debug assertions that check the size and alignment of the +/// slices. +impl<'a, C: Sha2ChipConfig> CustomBorrow<'a, Sha2VmRecordMut<'a>, Sha2VmRecordLayout> + for [u8] +{ + fn custom_borrow(&'a mut self, layout: Sha2VmRecordLayout) -> Sha2VmRecordMut<'a> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let header: &mut Sha2VmRecordHeader = header_buf.borrow_mut(); + + // Using `split_at_mut_unchecked` for perf reasons + // input is a slice of `u8`'s of length `C::BLOCK_CELLS * num_blocks`, so the alignment + // is always satisfied + let (input, rest) = unsafe { + rest.split_at_mut_unchecked((layout.metadata.num_blocks as usize) * C::BLOCK_CELLS) + }; + + // Using `align_to_mut` to make sure the returned slice is properly aligned to + // `MemoryReadAuxRecord` Additionally, Rust's subslice operation (a few lines below) + // will verify that the buffer has enough capacity + let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; + Sha2VmRecordMut { + inner: header, + input, + read_aux: &mut read_aux_buf[..(layout.metadata.num_blocks as usize) * C::NUM_READ_ROWS], + } + } + + unsafe fn extract_layout(&self) -> Sha2VmRecordLayout { + let header: &Sha2VmRecordHeader = self.borrow(); + + Sha2VmRecordLayout { + metadata: Sha2VmMetadata { + num_blocks: get_sha2_num_blocks::(header.len), + _phantom: PhantomData::, + }, + } + } +} + +impl SizedRecord> for Sha2VmRecordMut<'_> { + fn size(layout: &Sha2VmRecordLayout) -> usize { + let mut total_len = size_of::(); + total_len += layout.metadata.num_blocks as usize * C::BLOCK_CELLS; + // Align the pointer to the alignment of `MemoryReadAuxRecord` + total_len = total_len.next_multiple_of(align_of::()); + total_len += layout.metadata.num_blocks as usize + * C::NUM_READ_ROWS + * size_of::(); + total_len + } + + fn alignment(_layout: &Sha2VmRecordLayout) -> usize { + align_of::() + } +} + +impl TraceStep for Sha2VmStep { + type RecordLayout = Sha2VmRecordLayout; + type RecordMut<'a> = Sha2VmRecordMut<'a>; + + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", C::OPCODE) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + debug_assert_eq!(*opcode, C::OPCODE.global_opcode()); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + // Reading the length first to allocate a record of correct size + let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()); + + let num_blocks = get_sha2_num_blocks::(len); + let record = arena.alloc(MultiRowLayout { + metadata: Sha2VmMetadata { + num_blocks, + _phantom: PhantomData::, + }, + }); + + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp(); + record.inner.rd_ptr = a.as_canonical_u32(); + record.inner.rs1_ptr = b.as_canonical_u32(); + record.inner.rs2_ptr = c.as_canonical_u32(); + + record.inner.dst_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rd_ptr, + &mut record.inner.register_reads_aux[0].prev_timestamp, + )); + record.inner.src_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs1_ptr, + &mut record.inner.register_reads_aux[1].prev_timestamp, + )); + record.inner.len = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs2_ptr, + &mut record.inner.register_reads_aux[2].prev_timestamp, + )); + + // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used + debug_assert!( + record.inner.src_ptr as usize + num_blocks as usize * C::BLOCK_CELLS + <= (1 << self.pointer_max_bits) + ); + debug_assert!( + record.inner.dst_ptr as usize + C::WRITE_SIZE <= (1 << self.pointer_max_bits) + ); + // We don't support messages longer than 2^29 bytes + debug_assert!(record.inner.len < SHA_MAX_MESSAGE_LEN as u32); + + for block_idx in 0..num_blocks as usize { + // Reads happen on the first 4 rows of each block + for row in 0..C::NUM_READ_ROWS { + let read_idx = block_idx * C::NUM_READ_ROWS + row; + match C::VARIANT { + Sha2Variant::Sha256 => { + let row_input: [u8; Sha256Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + Sha2Variant::Sha512 => { + let row_input: [u8; Sha512Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + Sha2Variant::Sha384 => { + let row_input: [u8; Sha384Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + } + } + } + + let mut output = sha2_solve::(&record.input[..len as usize]); + match C::VARIANT { + Sha2Variant::Sha256 => { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr, + output.try_into().unwrap(), + &mut record.inner.writes_aux[0].prev_timestamp, + &mut record.inner.writes_aux[0].prev_data, + ); + } + Sha2Variant::Sha512 => { + debug_assert!(output.len() % Sha512Config::WRITE_SIZE == 0); + output + .chunks_exact(Sha512Config::WRITE_SIZE) + .enumerate() + .for_each(|(i, chunk)| { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (i * Sha512Config::WRITE_SIZE) as u32, + chunk.try_into().unwrap(), + &mut record.inner.writes_aux[i].prev_timestamp, + &mut record.inner.writes_aux[i].prev_data, + ); + }); + } + Sha2Variant::Sha384 => { + // output is a truncated 48-byte digest, so we will append 16 bytes of zeros + output.extend(iter::repeat(0).take(16)); + debug_assert!(output.len() % Sha384Config::WRITE_SIZE == 0); + output + .chunks_exact(Sha384Config::WRITE_SIZE) + .enumerate() + .for_each(|(i, chunk)| { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (i * Sha384Config::WRITE_SIZE) as u32, + chunk.try_into().unwrap(), + &mut record.inner.writes_aux[i].prev_timestamp, + &mut record.inner.writes_aux[i].prev_data, + ); + }); + } + } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for Sha2VmStep { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace_matrix: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + + let mut chunks = Vec::with_capacity(trace_matrix.height() / C::ROWS_PER_BLOCK); + let mut sizes = Vec::with_capacity(trace_matrix.height() / C::ROWS_PER_BLOCK); + let mut trace = &mut trace_matrix.values[..]; + let mut num_blocks_so_far = 0; + + // First pass over the trace to get the number of blocks for each instruction + // and divide the matrix into chunks of needed sizes + loop { + if num_blocks_so_far * C::ROWS_PER_BLOCK >= rows_used { + // Push all the padding rows as a single chunk and break + chunks.push(trace); + sizes.push((0, num_blocks_so_far)); + break; + } else { + let record: &Sha2VmRecordHeader = unsafe { get_record_from_slice(&mut trace, ()) }; + let num_blocks = get_sha2_num_blocks::(record.len) as usize; + let (chunk, rest) = + trace.split_at_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK * num_blocks); + chunks.push(chunk); + sizes.push((num_blocks, num_blocks_so_far)); + num_blocks_so_far += num_blocks; + trace = rest; + } + } + + // During the first pass we will fill out most of the matrix + // But there are some cells that can't be generated by the first pass so we will do a second + // pass over the matrix later + chunks.par_iter_mut().zip(sizes.par_iter()).for_each( + |(slice, (num_blocks, global_block_offset))| { + if global_block_offset * C::ROWS_PER_BLOCK >= rows_used { + // Fill in the invalid rows + slice.par_chunks_mut(C::VM_WIDTH).for_each(|row| { + // Need to get rid of the accidental garbage data that might overflow the + // F's prime field. Unfortunately, there is no good way around this + unsafe { + std::ptr::write_bytes( + row.as_mut_ptr() as *mut u8, + 0, + C::VM_WIDTH * size_of::(), + ); + } + let cols = Sha2VmRoundColsRefMut::::from::( + row[..C::VM_ROUND_WIDTH].borrow_mut(), + ); + self.inner.generate_default_row(cols.inner); + }); + return; + } + + let record: Sha2VmRecordMut = unsafe { + get_record_from_slice( + slice, + Sha2VmRecordLayout { + metadata: Sha2VmMetadata { + num_blocks: *num_blocks as u32, + _phantom: PhantomData::, + }, + }, + ) + }; + + let mut input: Vec = Vec::with_capacity(C::BLOCK_CELLS * num_blocks); + input.extend_from_slice(record.input); + let mut padded_input = input.clone(); + let len = record.inner.len as usize; + let padded_input_len = padded_input.len(); + padded_input[len] = 1 << (RV32_CELL_BITS - 1); + padded_input[len + 1..padded_input_len - 4].fill(0); + padded_input[padded_input_len - 4..] + .copy_from_slice(&((len as u32) << 3).to_be_bytes()); + + let mut prev_hashes = Vec::with_capacity(*num_blocks); + prev_hashes.push(C::get_h().to_vec()); + for i in 0..*num_blocks - 1 { + prev_hashes.push(Sha2StepHelper::::get_block_hash( + &prev_hashes[i], + padded_input[i * C::BLOCK_CELLS..(i + 1) * C::BLOCK_CELLS].into(), + )); + } + // Copy the read aux records and input to another place to safely fill in the trace + // matrix without overwriting the record + let mut read_aux_records = Vec::with_capacity(C::NUM_READ_ROWS * num_blocks); + read_aux_records.extend_from_slice(record.read_aux); + let vm_record = record.inner.clone(); + + slice + .par_chunks_exact_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK) + .enumerate() + .for_each(|(block_idx, block_slice)| { + // Need to get rid of the accidental garbage data that might overflow the + // F's prime field. Unfortunately, there is no good way around this + unsafe { + std::ptr::write_bytes( + block_slice.as_mut_ptr() as *mut u8, + 0, + C::ROWS_PER_BLOCK * C::VM_WIDTH * size_of::(), + ); + } + self.fill_block_trace::( + block_slice, + &vm_record, + &read_aux_records + [block_idx * C::NUM_READ_ROWS..(block_idx + 1) * C::NUM_READ_ROWS], + &input[block_idx * C::BLOCK_CELLS..(block_idx + 1) * C::BLOCK_CELLS], + &padded_input + [block_idx * C::BLOCK_CELLS..(block_idx + 1) * C::BLOCK_CELLS], + block_idx == *num_blocks - 1, + *global_block_offset + block_idx, + block_idx, + prev_hashes[block_idx].as_slice(), + mem_helper, + ); + }); + }, + ); + + // Do a second pass over the trace to fill in the missing values + // Note, we need to skip the very first row + trace_matrix.values[C::VM_WIDTH..] + .par_chunks_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK) + .take(rows_used / C::ROWS_PER_BLOCK) + .for_each(|chunk| { + self.inner + .generate_missing_cells(chunk, C::VM_WIDTH, C::VM_CONTROL_WIDTH); + }); + } +} + +impl Sha2VmStep { + #[allow(clippy::too_many_arguments)] + fn fill_block_trace( + &self, + block_slice: &mut [F], + record: &Sha2VmRecordHeader, + read_aux_records: &[MemoryReadAuxRecord], + input: &[u8], + padded_input: &[u8], + is_last_block: bool, + global_block_idx: usize, + local_block_idx: usize, + prev_hash: &[C::Word], + mem_helper: &MemoryAuxColsFactory, + ) { + debug_assert_eq!(input.len(), C::BLOCK_CELLS); + debug_assert_eq!(padded_input.len(), C::BLOCK_CELLS); + debug_assert_eq!(read_aux_records.len(), C::NUM_READ_ROWS); + debug_assert_eq!(prev_hash.len(), C::HASH_WORDS); + + let padded_input = (0..C::BLOCK_WORDS) + .map(|i| { + be_limbs_into_word::( + &padded_input[i * C::WORD_U8S..(i + 1) * C::WORD_U8S] + .iter() + .map(|x| *x as u32) + .collect::>(), + ) + }) + .collect::>(); + + let block_start_timestamp = + record.timestamp + (SHA_REGISTER_READS + C::NUM_READ_ROWS * local_block_idx) as u32; + + let read_cells = (C::BLOCK_CELLS * local_block_idx) as u32; + let block_start_read_ptr = record.src_ptr + read_cells; + + let message_left = if record.len <= read_cells { + 0 + } else { + (record.len - read_cells) as usize + }; + + // -1 means that padding occurred before the start of the block + // C::ROWS_PER_BLOCK + 1 means that no padding occurred on this block + let first_padding_row = if record.len < read_cells { + -1 + } else if message_left < C::BLOCK_CELLS { + (message_left / C::READ_SIZE) as i32 + } else { + (C::ROWS_PER_BLOCK + 1) as i32 + }; + + // Fill in the VM columns first because the inner `carry_or_buffer` needs to be filled in + block_slice + .par_chunks_exact_mut(C::VM_WIDTH) + .enumerate() + .for_each(|(row_idx, row_slice)| { + // Handle round rows and digest row separately + if row_idx == C::ROWS_PER_BLOCK - 1 { + // This is a digest row + let mut digest_cols = Sha2VmDigestColsRefMut::::from::( + row_slice[..C::VM_DIGEST_WIDTH].borrow_mut(), + ); + digest_cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); + digest_cols.from_state.pc = F::from_canonical_u32(record.from_pc); + *digest_cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + *digest_cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + *digest_cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); + digest_cols + .dst_ptr + .iter_mut() + .zip(record.dst_ptr.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + digest_cols + .src_ptr + .iter_mut() + .zip(record.src_ptr.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + digest_cols + .len_data + .iter_mut() + .zip(record.len.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + if is_last_block { + digest_cols + .register_reads_aux + .iter_mut() + .zip(record.register_reads_aux.iter()) + .enumerate() + .for_each(|(idx, (cols_read, record_read))| { + mem_helper.fill( + record_read.prev_timestamp, + record.timestamp + idx as u32, + cols_read.as_mut(), + ); + }); + for i in 0..C::NUM_WRITES { + digest_cols + .writes_aux_prev_data + .row_mut(i) + .iter_mut() + .zip(record.writes_aux[i].prev_data.map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + + // In the last block we do `C::NUM_READ_ROWS` reads and then write the + // result thus the timestamp of the write is + // `block_start_timestamp + C::NUM_READ_ROWS` + mem_helper.fill( + record.writes_aux[i].prev_timestamp, + block_start_timestamp + C::NUM_READ_ROWS as u32 + i as u32, + &mut digest_cols.writes_aux_base[i], + ); + } + // Need to range check the destination and source pointers + let msl_rshift: u32 = + ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; + let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS + - self.pointer_max_bits) + as u32; + self.bitwise_lookup_chip.request_range( + (record.dst_ptr >> msl_rshift) << msl_lshift, + (record.src_ptr >> msl_rshift) << msl_lshift, + ); + } else { + // Filling in zeros to make sure the accidental garbage data doesn't + // overflow the prime + digest_cols.register_reads_aux.iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); + for i in 0..C::NUM_WRITES { + digest_cols.writes_aux_prev_data.row_mut(i).fill(F::ZERO); + mem_helper.fill_zero(&mut digest_cols.writes_aux_base[i]); + } + } + *digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); + *digest_cols.inner.flags.is_digest_row = F::from_bool(true); + } else { + // This is a round row + let mut round_cols = Sha2VmRoundColsRefMut::::from::( + row_slice[..C::VM_ROUND_WIDTH].borrow_mut(), + ); + // Take care of the first 4 round rows (aka read rows) + if row_idx < C::NUM_READ_ROWS { + round_cols + .inner + .message_schedule + .carry_or_buffer + .iter_mut() + .zip(input[row_idx * C::READ_SIZE..(row_idx + 1) * C::READ_SIZE].iter()) + .for_each(|(cell, data)| { + *cell = F::from_canonical_u8(*data); + }); + mem_helper.fill( + read_aux_records[row_idx].prev_timestamp, + block_start_timestamp + row_idx as u32, + round_cols.read_aux.as_mut(), + ); + } else { + mem_helper.fill_zero(round_cols.read_aux.as_mut()); + } + } + // Fill in the control cols, doesn't matter if it is a round or digest row + let mut control_cols = Sha2VmControlColsRefMut::::from::( + row_slice[..C::VM_CONTROL_WIDTH].borrow_mut(), + ); + *control_cols.len = F::from_canonical_u32(record.len); + // Only the first `SHA256_NUM_READ_ROWS` rows increment the timestamp and read ptr + *control_cols.cur_timestamp = F::from_canonical_u32( + block_start_timestamp + min(row_idx, C::NUM_READ_ROWS) as u32, + ); + *control_cols.read_ptr = F::from_canonical_u32( + block_start_read_ptr + (C::READ_SIZE * min(row_idx, C::NUM_READ_ROWS)) as u32, + ); + + // Fill in the padding flags + if row_idx < C::NUM_READ_ROWS { + #[allow(clippy::comparison_chain)] + if (row_idx as i32) < first_padding_row { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotPadding as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } else if row_idx as i32 == first_padding_row { + let len = message_left - row_idx * C::READ_SIZE; + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + if row_idx == 3 && is_last_block { + PaddingFlags::FirstPadding0_LastRow + } else { + PaddingFlags::FirstPadding0 + } as usize + + len, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } else { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + if row_idx == 3 && is_last_block { + PaddingFlags::EntirePaddingLastRow + } else { + PaddingFlags::EntirePadding + } as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } else { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotConsidered as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + if is_last_block && row_idx == C::ROWS_PER_BLOCK - 1 { + // If last digest row, then we set padding_occurred = 0 + *control_cols.padding_occurred = F::ZERO; + } else { + *control_cols.padding_occurred = + F::from_bool((row_idx as i32) >= first_padding_row); + } + }); + + // Fill in the inner trace when the `carry_or_buffer` is filled in + self.inner.generate_block_trace::( + block_slice, + C::VM_WIDTH, + C::VM_CONTROL_WIDTH, + &padded_input, + self.bitwise_lookup_chip.clone(), + prev_hash, + is_last_block, + global_block_idx as u32 + 1, // global block index is 1-indexed + local_block_idx as u32, + ); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chip/utils.rs b/extensions/sha2/circuit/src/sha2_chip/utils.rs new file mode 100644 index 0000000000..d3c78345ad --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/utils.rs @@ -0,0 +1,8 @@ +use crate::Sha2ChipConfig; + +/// Returns the number of blocks required to hash a message of length `len` +pub fn get_sha2_num_blocks(len: u32) -> u32 { + // need to pad with one 1 bit, 64 bits for the message length and then pad until the length + // is divisible by [C::BLOCK_BITS] + ((len << 3) as usize + 1 + C::MESSAGE_LENGTH_BITS).div_ceil(C::BLOCK_BITS) as u32 +} diff --git a/extensions/sha256/guest/Cargo.toml b/extensions/sha2/guest/Cargo.toml similarity index 69% rename from extensions/sha256/guest/Cargo.toml rename to extensions/sha2/guest/Cargo.toml index e9d28292b8..1c6503002e 100644 --- a/extensions/sha256/guest/Cargo.toml +++ b/extensions/sha2/guest/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "openvm-sha256-guest" +name = "openvm-sha2-guest" version.workspace = true authors.workspace = true edition.workspace = true -description = "Guest extension for Sha256" +description = "Guest extension for SHA-2" [dependencies] openvm-platform = { workspace = true } diff --git a/extensions/sha2/guest/src/lib.rs b/extensions/sha2/guest/src/lib.rs new file mode 100644 index 0000000000..567f60d4da --- /dev/null +++ b/extensions/sha2/guest/src/lib.rs @@ -0,0 +1,193 @@ +#![no_std] + +#[cfg(target_os = "zkvm")] +use openvm_platform::alloc::AlignedBuf; + +/// This is custom-0 defined in RISC-V spec document +pub const OPCODE: u8 = 0x0b; +pub const SHA2_FUNCT3: u8 = 0b100; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(u8)] +pub enum Sha2BaseFunct7 { + Sha256 = 0x1, + Sha512 = 0x2, + Sha384 = 0x3, +} + +/// zkvm native implementation of sha256 +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// +/// [`sha2-256`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +#[cfg(target_os = "zkvm")] +#[inline(always)] +#[no_mangle] +pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + // The minimum alignment required for the input and output buffers + const MIN_ALIGN: usize = 4; + // The preferred alignment for the input buffer, since the input is read in chunks of 16 bytes + const INPUT_ALIGN: usize = 16; + // The preferred alignment for the output buffer, since the output is written in chunks of 32 + // bytes + const OUTPUT_ALIGN: usize = 32; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); + __native_sha256(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_sha256(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); + __native_sha256(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_sha256(bytes, len, output); + } + }; + } +} + +/// zkvm native implementation of sha512 +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 64-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 64-bytes long. +/// +/// [`sha2-512`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +#[cfg(target_os = "zkvm")] +#[inline(always)] +#[no_mangle] +pub extern "C" fn zkvm_sha512_impl(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + // The minimum alignment required for the input and output buffers + const MIN_ALIGN: usize = 4; + // The preferred alignment for the input buffer, since the input is read in chunks of 32 bytes + const INPUT_ALIGN: usize = 32; + // The preferred alignment for the output buffer, since the output is written in chunks of 32 + // bytes + const OUTPUT_ALIGN: usize = 32; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(64, OUTPUT_ALIGN); + __native_sha512(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 64); + } else { + __native_sha512(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(64, OUTPUT_ALIGN); + __native_sha512(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 64); + } else { + __native_sha512(bytes, len, output); + } + }; + } +} + +/// zkvm native implementation of sha384 +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 48-byte hash followed by 16-bytes of zeros. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 64-bytes long. +/// +/// [`sha2-512`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +#[cfg(target_os = "zkvm")] +#[inline(always)] +#[no_mangle] +pub extern "C" fn zkvm_sha384_impl(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + // The minimum alignment required for the input and output buffers + const MIN_ALIGN: usize = 4; + // The preferred alignment for the input buffer, since the input is read in chunks of 32 bytes + const INPUT_ALIGN: usize = 32; + // The preferred alignment for the output buffer, since the output is written in chunks of 32 + // bytes + const OUTPUT_ALIGN: usize = 32; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(64, OUTPUT_ALIGN); + __native_sha384(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 64); + } else { + __native_sha384(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(64, OUTPUT_ALIGN); + __native_sha384(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 64); + } else { + __native_sha384(bytes, len, output); + } + }; + } +} + +/// sha256 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha256(bytes: *const u8, len: usize, output: *mut u8) { + openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA2_FUNCT3, funct7 = Sha2BaseFunct7::Sha256 as u8, rd = In output, rs1 = In bytes, rs2 = In len); +} + +/// sha512 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 64-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 64-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha512(bytes: *const u8, len: usize, output: *mut u8) { + openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA2_FUNCT3, funct7 = Sha2BaseFunct7::Sha512 as u8, rd = In output, rs1 = In bytes, rs2 = In len); +} + +/// sha384 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 48-byte hash followed by 16-bytes of zeros. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 64-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha384(bytes: *const u8, len: usize, output: *mut u8) { + openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA2_FUNCT3, funct7 = Sha2BaseFunct7::Sha384 as u8, rd = In output, rs1 = In bytes, rs2 = In len); +} diff --git a/extensions/sha256/transpiler/Cargo.toml b/extensions/sha2/transpiler/Cargo.toml similarity index 73% rename from extensions/sha256/transpiler/Cargo.toml rename to extensions/sha2/transpiler/Cargo.toml index 933859f3a8..9eff76a3db 100644 --- a/extensions/sha256/transpiler/Cargo.toml +++ b/extensions/sha2/transpiler/Cargo.toml @@ -1,15 +1,15 @@ [package] -name = "openvm-sha256-transpiler" +name = "openvm-sha2-transpiler" version.workspace = true authors.workspace = true edition.workspace = true -description = "Transpiler extension for sha256" +description = "Transpiler extension for SHA-2" [dependencies] openvm-stark-backend = { workspace = true } openvm-instructions = { workspace = true } openvm-transpiler = { workspace = true } rrs-lib = { workspace = true } -openvm-sha256-guest = { workspace = true } +openvm-sha2-guest = { workspace = true } openvm-instructions-derive = { workspace = true } strum = { workspace = true } diff --git a/extensions/sha2/transpiler/src/lib.rs b/extensions/sha2/transpiler/src/lib.rs new file mode 100644 index 0000000000..89249ee026 --- /dev/null +++ b/extensions/sha2/transpiler/src/lib.rs @@ -0,0 +1,65 @@ +use openvm_instructions::{riscv::RV32_MEMORY_AS, LocalOpcode}; +use openvm_instructions_derive::LocalOpcode; +use openvm_sha2_guest::{Sha2BaseFunct7, OPCODE, SHA2_FUNCT3}; +use openvm_stark_backend::p3_field::PrimeField32; +use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput}; +use rrs_lib::instruction_formats::RType; +use strum::{EnumCount, EnumIter, FromRepr}; + +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x320] +#[repr(usize)] +pub enum Rv32Sha2Opcode { + SHA256, + SHA512, + SHA384, +} + +#[derive(Default)] +pub struct Sha2TranspilerExtension; + +impl TranspilerExtension for Sha2TranspilerExtension { + fn process_custom(&self, instruction_stream: &[u32]) -> Option> { + if instruction_stream.is_empty() { + return None; + } + let instruction_u32 = instruction_stream[0]; + let opcode = (instruction_u32 & 0x7f) as u8; + let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; + + if (opcode, funct3) != (OPCODE, SHA2_FUNCT3) { + return None; + } + let dec_insn = RType::new(instruction_u32); + + if dec_insn.funct7 == Sha2BaseFunct7::Sha256 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA256.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else if dec_insn.funct7 == Sha2BaseFunct7::Sha512 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA512.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else if dec_insn.funct7 == Sha2BaseFunct7::Sha384 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA384.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else { + None + } + } +} diff --git a/extensions/sha256/circuit/src/lib.rs b/extensions/sha256/circuit/src/lib.rs deleted file mode 100644 index fe0844f902..0000000000 --- a/extensions/sha256/circuit/src/lib.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod sha256_chip; -pub use sha256_chip::*; - -mod extension; -pub use extension::*; diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs deleted file mode 100644 index a191157ff0..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ /dev/null @@ -1,621 +0,0 @@ -use std::{array, borrow::Borrow, cmp::min}; - -use openvm_circuit::{ - arch::{ExecutionBridge, SystemPort}, - system::memory::{offline_checker::MemoryBridge, MemoryAddress}, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir, -}; -use openvm_instructions::{ - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, - LocalOpcode, -}; -use openvm_sha256_air::{ - compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder}, - p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra}, - p3_matrix::Matrix, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, -}; - -use super::{ - Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH, - SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE, -}; - -/// Sha256VmAir does all constraints related to message padding and -/// the Sha256Air subair constrains the actual hash -#[derive(Clone, Debug)] -pub struct Sha256VmAir { - pub execution_bridge: ExecutionBridge, - pub memory_bridge: MemoryBridge, - /// Bus to send byte checks to - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - /// Maximum number of bits allowed for an address pointer - /// Must be at least 24 - pub ptr_max_bits: usize, - pub(super) sha256_subair: Sha256Air, - pub(super) padding_encoder: Encoder, -} - -impl Sha256VmAir { - pub fn new( - SystemPort { - execution_bus, - program_bus, - memory_bridge, - }: SystemPort, - bitwise_lookup_bus: BitwiseOperationLookupBus, - ptr_max_bits: usize, - self_bus_idx: BusIndex, - ) -> Self { - Self { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_bus, - ptr_max_bits, - sha256_subair: Sha256Air::new(bitwise_lookup_bus, self_bus_idx), - padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), - } - } -} - -impl BaseAirWithPublicValues for Sha256VmAir {} -impl PartitionedBaseAir for Sha256VmAir {} -impl BaseAir for Sha256VmAir { - fn width(&self) -> usize { - SHA256VM_WIDTH - } -} - -impl Air for Sha256VmAir { - fn eval(&self, builder: &mut AB) { - self.eval_padding(builder); - self.eval_transitions(builder); - self.eval_reads(builder); - self.eval_last_row(builder); - - self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH); - } -} - -#[allow(dead_code, non_camel_case_types)] -pub(super) enum PaddingFlags { - /// Not considered for padding - W's are not constrained - NotConsidered, - /// Not padding - W's should be equal to the message - NotPadding, - /// FIRST_PADDING_i: it is the first row with padding and there are i cells of non-padding - FirstPadding0, - FirstPadding1, - FirstPadding2, - FirstPadding3, - FirstPadding4, - FirstPadding5, - FirstPadding6, - FirstPadding7, - FirstPadding8, - FirstPadding9, - FirstPadding10, - FirstPadding11, - FirstPadding12, - FirstPadding13, - FirstPadding14, - FirstPadding15, - /// FIRST_PADDING_i_LastRow: it is the first row with padding and there are i cells of - /// non-padding AND it is the last reading row of the message - /// NOTE: if the Last row has padding it has to be at least 9 cells since the last 8 cells are - /// padded with the message length - FirstPadding0_LastRow, - FirstPadding1_LastRow, - FirstPadding2_LastRow, - FirstPadding3_LastRow, - FirstPadding4_LastRow, - FirstPadding5_LastRow, - FirstPadding6_LastRow, - FirstPadding7_LastRow, - /// The entire row is padding AND it is not the first row with padding - /// AND it is the 4th row of the last block of the message - EntirePaddingLastRow, - /// The entire row is padding AND it is not the first row with padding - EntirePadding, -} - -impl PaddingFlags { - /// The number of padding flags (including NotConsidered) - pub const COUNT: usize = EntirePadding as usize + 1; -} - -use PaddingFlags::*; -impl Sha256VmAir { - /// Implement all necessary constraints for the padding - fn eval_padding(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - let next_cols: &Sha256VmRoundCols = next[..SHA256VM_ROUND_WIDTH].borrow(); - - // Constrain the sanity of the padding flags - self.padding_encoder - .eval(builder, &local_cols.control.pad_flags); - - builder.assert_one(self.padding_encoder.contains_flag_range::( - &local_cols.control.pad_flags, - NotConsidered as usize..=EntirePadding as usize, - )); - - Self::eval_padding_transitions(self, builder, local_cols, next_cols); - Self::eval_padding_row(self, builder, local_cols); - } - - fn eval_padding_transitions( - &self, - builder: &mut AB, - local: &Sha256VmRoundCols, - next: &Sha256VmRoundCols, - ) { - let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block; - - // Constrain that `padding_occured` is 1 on a suffix of rows in each message, excluding the - // last digest row, and 0 everywhere else. Furthermore, the suffix starts in the - // first 4 rows of some block. - - builder.assert_bool(local.control.padding_occurred); - // Last round row in the last block has padding_occurred = 1 - // This is the end of the suffix - builder - .when(next_is_last_row.clone()) - .assert_one(local.control.padding_occurred); - - // Digest row in the last block has padding_occurred = 0 - builder - .when(next_is_last_row.clone()) - .assert_zero(next.control.padding_occurred); - - // If padding_occurred = 1 in the current row, then padding_occurred = 1 in the next row, - // unless next is the last digest row - builder - .when(local.control.padding_occurred - next_is_last_row.clone()) - .assert_one(next.control.padding_occurred); - - // If next row is not first 4 rows of a block, then next.padding_occurred = - // local.padding_occurred. So padding_occurred only changes in the first 4 rows of a - // block. - builder - .when_transition() - .when(not(next.inner.flags.is_first_4_rows) - next_is_last_row) - .assert_eq( - next.control.padding_occurred, - local.control.padding_occurred, - ); - - // Constrain the that the start of the padding is correct - let next_is_first_padding_row = - next.control.padding_occurred - local.control.padding_occurred; - // Row index if its between 0..4, else 0 - let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::( - &next.inner.flags.row_idx, - &(0..4).map(|x| (x, x)).collect::>(), - ); - // How many non-padding cells there are in the next row. - // Will be 0 on non-padding rows. - let next_padding_offset = self.padding_encoder.flag_with_val::( - &next.control.pad_flags, - &(0..16) - .map(|i| (FirstPadding0 as usize + i, i)) - .collect::>(), - ) + self.padding_encoder.flag_with_val::( - &next.control.pad_flags, - &(0..8) - .map(|i| (FirstPadding0_LastRow as usize + i, i)) - .collect::>(), - ); - - // Will be 0 on last digest row since: - // - padding_occurred = 0 is constrained above - // - next_row_idx = 0 since row_idx is not in 0..4 - // - and next_padding_offset = 0 since `pad_flags = NotConsidered` - let expected_len = next.inner.flags.local_block_idx - * next.control.padding_occurred - * AB::Expr::from_canonical_usize(SHA256_BLOCK_U8S) - + next_row_idx * AB::Expr::from_canonical_usize(SHA256_READ_SIZE) - + next_padding_offset; - - // Note: `next_is_first_padding_row` is either -1,0,1 - // If 1, then this constrains the length of message - // If -1, then `next` must be the last digest row and so this constraint will be 0 == 0 - builder.when(next_is_first_padding_row).assert_eq( - expected_len, - next.control.len * next.control.padding_occurred, - ); - - // Constrain the padding flags are of correct type (eg is not padding or first padding) - let is_next_first_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - FirstPadding0 as usize..=FirstPadding7_LastRow as usize, - ); - - let is_next_last_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, - ); - - let is_next_entire_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - EntirePaddingLastRow as usize..=EntirePadding as usize, - ); - - let is_next_not_considered = self - .padding_encoder - .contains_flag::(&next.control.pad_flags, &[NotConsidered as usize]); - - let is_next_not_padding = self - .padding_encoder - .contains_flag::(&next.control.pad_flags, &[NotPadding as usize]); - - let is_next_4th_row = self - .sha256_subair - .row_idx_encoder - .contains_flag::(&next.inner.flags.row_idx, &[3]); - - // `pad_flags` is `NotConsidered` on all rows except the first 4 rows of a block - builder.assert_eq( - not(next.inner.flags.is_first_4_rows), - is_next_not_considered, - ); - - // `pad_flags` is `EntirePadding` if the previous row is padding - builder.when(next.inner.flags.is_first_4_rows).assert_eq( - local.control.padding_occurred * next.control.padding_occurred, - is_next_entire_padding, - ); - - // `pad_flags` is `FirstPadding*` if current row is padding and the previous row is not - // padding - builder.when(next.inner.flags.is_first_4_rows).assert_eq( - not(local.control.padding_occurred) * next.control.padding_occurred, - is_next_first_padding, - ); - - // `pad_flags` is `NotPadding` if current row is not padding - builder - .when(next.inner.flags.is_first_4_rows) - .assert_eq(not(next.control.padding_occurred), is_next_not_padding); - - // `pad_flags` is `*LastRow` on the row that contains the last four words of the message - builder - .when(next.inner.flags.is_last_block) - .assert_eq(is_next_4th_row, is_next_last_padding); - } - - fn eval_padding_row( - &self, - builder: &mut AB, - local: &Sha256VmRoundCols, - ) { - let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| { - local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)] - [i % (SHA256_WORD_U8S)] - }); - - let get_ith_byte = |i: usize| { - let word_idx = i / SHA256_ROUNDS_PER_ROW; - let word = local.inner.message_schedule.w[word_idx].map(|x| x.into()); - // Need to reverse the byte order to match the endianness of the memory - let byte_idx = 4 - i % 4 - 1; - compose::(&word[byte_idx * 8..(byte_idx + 1) * 8], 1) - }; - - let is_not_padding = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[NotPadding as usize]); - - // Check the `w`s on case by case basis - for (i, message_byte) in message.iter().enumerate() { - let w = get_ith_byte(i); - let should_be_message = is_not_padding.clone() - + if i < 15 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0 as usize + i + 1..=FirstPadding15 as usize, - ) - } else { - AB::Expr::ZERO - } - + if i < 7 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize, - ) - } else { - AB::Expr::ZERO - }; - builder - .when(should_be_message) - .assert_eq(w.clone(), *message_byte); - - let should_be_zero = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[EntirePadding as usize]) - + if i < 12 { - self.padding_encoder.contains_flag::( - &local.control.pad_flags, - &[EntirePaddingLastRow as usize], - ) + if i > 0 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize - ..=min( - FirstPadding0_LastRow as usize + i - 1, - FirstPadding7_LastRow as usize, - ), - ) - } else { - AB::Expr::ZERO - } - } else { - AB::Expr::ZERO - } - + if i > 0 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0 as usize..=FirstPadding0 as usize + i - 1, - ) - } else { - AB::Expr::ZERO - }; - builder.when(should_be_zero).assert_zero(w.clone()); - - // Assumes bit-length of message is a multiple of 8 (message is bytes) - // This is true because the message is given as &[u8] - let should_be_128 = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[FirstPadding0 as usize + i]) - + if i < 8 { - self.padding_encoder.contains_flag::( - &local.control.pad_flags, - &[FirstPadding0_LastRow as usize + i], - ) - } else { - AB::Expr::ZERO - }; - - builder - .when(should_be_128) - .assert_eq(AB::Expr::from_canonical_u32(1 << 7), w); - - // should be len is handled outside of the loop - } - let appended_len = compose::( - &[ - get_ith_byte(15), - get_ith_byte(14), - get_ith_byte(13), - get_ith_byte(12), - ], - RV32_CELL_BITS, - ); - - let actual_len = local.control.len; - - let is_last_padding_row = self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, - ); - - builder.when(is_last_padding_row.clone()).assert_eq( - appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(), // bit to byte conversion - actual_len, - ); - - // We constrain that the appended length is in bytes - builder.when(is_last_padding_row.clone()).assert_zero( - local.inner.message_schedule.w[3][0] - + local.inner.message_schedule.w[3][1] - + local.inner.message_schedule.w[3][2], - ); - - // We can't support messages longer than 2^30 bytes because the length has to fit in a field - // element. So, constrain that the first 4 bytes of the length are 0. - // Thus, the bit-length is < 2^32 so the message is < 2^29 bytes. - for i in 8..12 { - builder - .when(is_last_padding_row.clone()) - .assert_zero(get_ith_byte(i)); - } - } - /// Implement constraints on `len`, `read_ptr` and `cur_timestamp` - fn eval_transitions(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - let next_cols: &Sha256VmRoundCols = next[..SHA256VM_ROUND_WIDTH].borrow(); - - let is_last_row = - local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row; - - // Len should be the same for the entire message - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq(next_cols.control.len, local_cols.control.len); - - // Read ptr should increment by [SHA256_READ_SIZE] for the first 4 rows and stay the same - // otherwise - let read_ptr_delta = local_cols.inner.flags.is_first_4_rows - * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq( - next_cols.control.read_ptr, - local_cols.control.read_ptr + read_ptr_delta, - ); - - // Timestamp should increment by 1 for the first 4 rows and stay the same otherwise - let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE; - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq( - next_cols.control.cur_timestamp, - local_cols.control.cur_timestamp + timestamp_delta, - ); - } - - /// Implement the reads for the first 4 rows of a block - fn eval_reads(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - - let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| { - local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)] - [i % (SHA256_WORD_U16S * 2)] - }); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - local_cols.control.read_ptr, - ), - message, - local_cols.control.cur_timestamp, - &local_cols.read_aux, - ) - .eval(builder, local_cols.inner.flags.is_first_4_rows); - } - /// Implement the constraints for the last row of a message - fn eval_last_row(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_cols: &Sha256VmDigestCols = local[..SHA256VM_DIGEST_WIDTH].borrow(); - - let timestamp: AB::Var = local_cols.from_state.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) - }; - - let is_last_row = - local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row; - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rd_ptr, - ), - local_cols.dst_ptr, - timestamp_pp(), - &local_cols.register_reads_aux[0], - ) - .eval(builder, is_last_row.clone()); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rs1_ptr, - ), - local_cols.src_ptr, - timestamp_pp(), - &local_cols.register_reads_aux[1], - ) - .eval(builder, is_last_row.clone()); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rs2_ptr, - ), - local_cols.len_data, - timestamp_pp(), - &local_cols.register_reads_aux[2], - ) - .eval(builder, is_last_row.clone()); - - // range check that the memory pointers don't overflow - // Note: no need to range check the length since we read from memory step by step and - // the memory bus will catch any memory accesses beyond ptr_max_bits - let shift = AB::Expr::from_canonical_usize( - 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), - ); - // This only works if self.ptr_max_bits >= 24 which is typically the case - self.bitwise_lookup_bus - .send_range( - // It is fine to shift like this since we already know that dst_ptr and src_ptr - // have [RV32_CELL_BITS] bits - local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), - local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), - ) - .eval(builder, is_last_row.clone()); - - // the number of reads that happened to read the entire message: we do 4 reads per block - let time_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE) - * AB::Expr::from_canonical_usize(4); - // Every time we read the message we increment the read pointer by SHA256_READ_SIZE - let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); - - let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| { - // The limbs are written in big endian order to the memory so need to be reversed - local_cols.inner.final_hash[i / SHA256_WORD_U8S] - [SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1] - }); - - let dst_ptr_val = - compose::(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS); - - // Note: revisit in the future to do 2 block writes of 16 cells instead of 1 block write of - // 32 cells This could be beneficial as the output is often an input for - // another hash - self.memory_bridge - .write( - MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), dst_ptr_val), - result, - timestamp_pp() + time_delta.clone(), - &local_cols.writes_aux, - ) - .eval(builder, is_last_row.clone()); - - self.execution_bridge - .execute_and_increment_pc( - AB::Expr::from_canonical_usize(Rv32Sha256Opcode::SHA256.global_opcode().as_usize()), - [ - local_cols.rd_ptr.into(), - local_cols.rs1_ptr.into(), - local_cols.rs2_ptr.into(), - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - ], - local_cols.from_state, - AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(), - ) - .eval(builder, is_last_row.clone()); - - // Assert that we read the correct length of the message - let len_val = compose::(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS); - builder - .when(is_last_row.clone()) - .assert_eq(local_cols.control.len, len_val); - // Assert that we started reading from the correct pointer initially - let src_val = compose::(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS); - builder - .when(is_last_row.clone()) - .assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta); - // Assert that we started reading from the correct timestamp - builder.when(is_last_row.clone()).assert_eq( - local_cols.control.cur_timestamp, - local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta, - ); - } -} diff --git a/extensions/sha256/circuit/src/sha256_chip/columns.rs b/extensions/sha256/circuit/src/sha256_chip/columns.rs deleted file mode 100644 index 38c13a0f73..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/columns.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! WARNING: the order of fields in the structs is important, do not change it - -use openvm_circuit::{ - arch::ExecutionState, - system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; -use openvm_circuit_primitives::AlignedBorrow; -use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; -use openvm_sha256_air::{Sha256DigestCols, Sha256RoundCols}; - -use super::{SHA256_REGISTER_READS, SHA256_WRITE_SIZE}; - -/// the first 16 rows of every SHA256 block will be of type Sha256VmRoundCols and the last row will -/// be of type Sha256VmDigestCols -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmRoundCols { - pub control: Sha256VmControlCols, - pub inner: Sha256RoundCols, - pub read_aux: MemoryReadAuxCols, -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmDigestCols { - pub control: Sha256VmControlCols, - pub inner: Sha256DigestCols, - - pub from_state: ExecutionState, - /// It is counter intuitive, but we will constrain the register reads on the very last row of - /// every message - pub rd_ptr: T, - pub rs1_ptr: T, - pub rs2_ptr: T, - pub dst_ptr: [T; RV32_REGISTER_NUM_LIMBS], - pub src_ptr: [T; RV32_REGISTER_NUM_LIMBS], - pub len_data: [T; RV32_REGISTER_NUM_LIMBS], - pub register_reads_aux: [MemoryReadAuxCols; SHA256_REGISTER_READS], - pub writes_aux: MemoryWriteAuxCols, -} - -/// These are the columns that are used on both round and digest rows -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmControlCols { - /// Note: We will use the buffer in `inner.message_schedule` as the message data - /// This is the length of the entire message in bytes - pub len: T, - /// Need to keep timestamp and read_ptr since block reads don't have the necessary information - pub cur_timestamp: T, - pub read_ptr: T, - /// Padding flags which will be used to encode the the number of non-padding cells in the - /// current row - pub pad_flags: [T; 6], - /// A boolean flag that indicates whether a padding already occurred - pub padding_occurred: T, -} - -/// Width of the Sha256VmControlCols -pub const SHA256VM_CONTROL_WIDTH: usize = Sha256VmControlCols::::width(); -/// Width of the Sha256VmRoundCols -pub const SHA256VM_ROUND_WIDTH: usize = Sha256VmRoundCols::::width(); -/// Width of the Sha256VmDigestCols -pub const SHA256VM_DIGEST_WIDTH: usize = Sha256VmDigestCols::::width(); -/// Width of the Sha256Cols -pub const SHA256VM_WIDTH: usize = if SHA256VM_ROUND_WIDTH > SHA256VM_DIGEST_WIDTH { - SHA256VM_ROUND_WIDTH -} else { - SHA256VM_DIGEST_WIDTH -}; diff --git a/extensions/sha256/circuit/src/sha256_chip/mod.rs b/extensions/sha256/circuit/src/sha256_chip/mod.rs deleted file mode 100644 index eb09967b22..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/mod.rs +++ /dev/null @@ -1,225 +0,0 @@ -//! Sha256 hasher. Handles full sha256 hashing with padding. -//! variable length inputs read from VM memory. - -use std::borrow::{Borrow, BorrowMut}; - -use openvm_circuit::arch::{ - execution_mode::E1ExecutionCtx, E2PreCompute, MatrixRecordArena, NewVmChipWrapper, Result, - StepExecutorE1, StepExecutorE2, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, -}; -use openvm_instructions::{ - instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, - LocalOpcode, -}; -use openvm_sha256_air::{ - get_sha256_num_blocks, Sha256StepHelper, SHA256_BLOCK_BITS, SHA256_ROWS_PER_BLOCK, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::p3_field::PrimeField32; -use sha2::{Digest, Sha256}; - -mod air; -mod columns; -mod trace; - -pub use air::*; -pub use columns::*; -use openvm_circuit::arch::{ - execution_mode::E2ExecutionCtx, ExecuteFunc, ExecutionError::InvalidInstruction, VmSegmentState, -}; -use openvm_circuit_primitives_derive::AlignedBytesBorrow; - -#[cfg(test)] -mod tests; - -// ==== Constants for register/memory adapter ==== -/// Register reads to get dst, src, len -const SHA256_REGISTER_READS: usize = 3; -/// Number of cells to read in a single memory access -const SHA256_READ_SIZE: usize = 16; -/// Number of cells to write in a single memory access -const SHA256_WRITE_SIZE: usize = 32; -/// Number of rv32 cells read in a SHA256 block -pub const SHA256_BLOCK_CELLS: usize = SHA256_BLOCK_BITS / RV32_CELL_BITS; -/// Number of rows we will do a read on for each SHA256 block -pub const SHA256_NUM_READ_ROWS: usize = SHA256_BLOCK_CELLS / SHA256_READ_SIZE; -/// Maximum message length that this chip supports in bytes -pub const SHA256_MAX_MESSAGE_LEN: usize = 1 << 29; - -pub type Sha256VmChip = NewVmChipWrapper>; - -pub struct Sha256VmStep { - pub inner: Sha256StepHelper, - pub padding_encoder: Encoder, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - pub offset: usize, - pub pointer_max_bits: usize, -} - -impl Sha256VmStep { - pub fn new( - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - pointer_max_bits: usize, - ) -> Self { - Self { - inner: Sha256StepHelper::new(), - padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), - bitwise_lookup_chip, - offset, - pointer_max_bits, - } - } -} - -#[derive(AlignedBytesBorrow, Clone)] -#[repr(C)] -struct ShaPreCompute { - a: u8, - b: u8, - c: u8, -} - -impl StepExecutorE1 for Sha256VmStep { - fn pre_compute_size(&self) -> usize { - size_of::() - } - - fn pre_compute_e1( - &self, - pc: u32, - inst: &Instruction, - data: &mut [u8], - ) -> Result> - where - Ctx: E1ExecutionCtx, - { - let data: &mut ShaPreCompute = data.borrow_mut(); - self.pre_compute_impl(pc, inst, data)?; - Ok(execute_e1_impl::<_, _>) - } -} -impl StepExecutorE2 for Sha256VmStep { - fn e2_pre_compute_size(&self) -> usize { - size_of::>() - } - - fn pre_compute_e2( - &self, - chip_idx: usize, - pc: u32, - inst: &Instruction, - data: &mut [u8], - ) -> Result> - where - Ctx: E2ExecutionCtx, - { - let data: &mut E2PreCompute = data.borrow_mut(); - data.chip_idx = chip_idx as u32; - self.pre_compute_impl(pc, inst, &mut data.data)?; - Ok(execute_e2_impl::<_, _>) - } -} - -unsafe fn execute_e12_impl( - pre_compute: &ShaPreCompute, - vm_state: &mut VmSegmentState, -) -> u32 { - let dst = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32); - let src = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); - let len = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); - let dst_u32 = u32::from_le_bytes(dst); - let src_u32 = u32::from_le_bytes(src); - let len_u32 = u32::from_le_bytes(len); - - let (output, height) = if IS_E1 { - // SAFETY: RV32_MEMORY_AS is memory address space of type u8 - let message = vm_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize); - let output = sha256_solve(message); - (output, 0) - } else { - let num_blocks = get_sha256_num_blocks(len_u32); - let mut message = Vec::with_capacity(len_u32 as usize); - for block_idx in 0..num_blocks as usize { - // Reads happen on the first 4 rows of each block - for row in 0..SHA256_NUM_READ_ROWS { - let read_idx = block_idx * SHA256_NUM_READ_ROWS + row; - let row_input: [u8; SHA256_READ_SIZE] = vm_state.vm_read( - RV32_MEMORY_AS, - src_u32 + (read_idx * SHA256_READ_SIZE) as u32, - ); - message.extend_from_slice(&row_input); - } - } - let output = sha256_solve(&message[..len_u32 as usize]); - let height = num_blocks * SHA256_ROWS_PER_BLOCK as u32; - (output, height) - }; - vm_state.vm_write(RV32_MEMORY_AS, dst_u32, &output); - - vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); - vm_state.instret += 1; - - height -} - -unsafe fn execute_e1_impl( - pre_compute: &[u8], - vm_state: &mut VmSegmentState, -) { - let pre_compute: &ShaPreCompute = pre_compute.borrow(); - execute_e12_impl::(pre_compute, vm_state); -} -unsafe fn execute_e2_impl( - pre_compute: &[u8], - vm_state: &mut VmSegmentState, -) { - let pre_compute: &E2PreCompute = pre_compute.borrow(); - let height = execute_e12_impl::(&pre_compute.data, vm_state); - vm_state - .ctx - .on_height_change(pre_compute.chip_idx as usize, height); -} - -impl Sha256VmStep { - fn pre_compute_impl( - &self, - pc: u32, - inst: &Instruction, - data: &mut ShaPreCompute, - ) -> Result<()> { - let Instruction { - opcode, - a, - b, - c, - d, - e, - .. - } = inst; - let e_u32 = e.as_canonical_u32(); - if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { - return Err(InvalidInstruction(pc)); - } - *data = ShaPreCompute { - a: a.as_canonical_u32() as u8, - b: b.as_canonical_u32() as u8, - c: c.as_canonical_u32() as u8, - }; - assert_eq!(&Rv32Sha256Opcode::SHA256.global_opcode(), opcode); - Ok(()) - } -} - -pub fn sha256_solve(input_message: &[u8]) -> [u8; SHA256_WRITE_SIZE] { - let mut hasher = Sha256::new(); - hasher.update(input_message); - let mut output = [0u8; SHA256_WRITE_SIZE]; - output.copy_from_slice(hasher.finalize().as_ref()); - output -} diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs deleted file mode 100644 index 4b0f4bb85d..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ /dev/null @@ -1,600 +0,0 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, - cmp::min, -}; - -use openvm_circuit::{ - arch::{ - get_record_from_slice, CustomBorrow, MultiRowLayout, MultiRowMetadata, RecordArena, Result, - SizedRecord, TraceFiller, TraceStep, VmStateMut, - }, - system::memory::{ - offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, - online::TracingMemory, - MemoryAuxColsFactory, - }, -}; -use openvm_circuit_primitives::AlignedBytesBorrow; -use openvm_instructions::{ - instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, - LocalOpcode, -}; -use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; -use openvm_sha256_air::{ - get_flag_pt_array, get_sha256_num_blocks, Sha256StepHelper, SHA256_BLOCK_BITS, SHA256_H, - SHA256_ROWS_PER_BLOCK, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - p3_field::PrimeField32, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, -}; - -use super::{ - Sha256VmDigestCols, Sha256VmRoundCols, Sha256VmStep, SHA256VM_CONTROL_WIDTH, - SHA256VM_DIGEST_WIDTH, -}; -use crate::{ - sha256_chip::{PaddingFlags, SHA256_READ_SIZE, SHA256_REGISTER_READS, SHA256_WRITE_SIZE}, - sha256_solve, Sha256VmControlCols, SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_BLOCK_CELLS, - SHA256_MAX_MESSAGE_LEN, SHA256_NUM_READ_ROWS, -}; - -#[derive(Clone, Copy)] -pub struct Sha256VmMetadata { - pub num_blocks: u32, -} - -impl MultiRowMetadata for Sha256VmMetadata { - #[inline(always)] - fn get_num_rows(&self) -> usize { - self.num_blocks as usize * SHA256_ROWS_PER_BLOCK - } -} - -pub(crate) type Sha256VmRecordLayout = MultiRowLayout; - -#[repr(C)] -#[derive(AlignedBytesBorrow, Debug, Clone)] -pub struct Sha256VmRecordHeader { - pub from_pc: u32, - pub timestamp: u32, - pub rd_ptr: u32, - pub rs1_ptr: u32, - pub rs2_ptr: u32, - pub dst_ptr: u32, - pub src_ptr: u32, - pub len: u32, - - pub register_reads_aux: [MemoryReadAuxRecord; SHA256_REGISTER_READS], - pub write_aux: MemoryWriteBytesAuxRecord, -} - -pub struct Sha256VmRecordMut<'a> { - pub inner: &'a mut Sha256VmRecordHeader, - // Having a continuous slice of the input is useful for fast hashing in `execute` - pub input: &'a mut [u8], - pub read_aux: &'a mut [MemoryReadAuxRecord], -} - -/// Custom borrowing that splits the buffer into a fixed `Sha256VmRecord` header -/// followed by a slice of `u8`'s of length `SHA256_BLOCK_CELLS * num_blocks` where `num_blocks` is -/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length -/// `SHA256_NUM_READ_ROWS * num_blocks`. Uses `align_to_mut()` to make sure the slice is properly -/// aligned to `MemoryReadAuxRecord`. Has debug assertions that check the size and alignment of the -/// slices. -impl<'a> CustomBorrow<'a, Sha256VmRecordMut<'a>, Sha256VmRecordLayout> for [u8] { - fn custom_borrow(&'a mut self, layout: Sha256VmRecordLayout) -> Sha256VmRecordMut<'a> { - let (header_buf, rest) = - unsafe { self.split_at_mut_unchecked(size_of::()) }; - - // Using `split_at_mut_unchecked` for perf reasons - // input is a slice of `u8`'s of length `SHA256_BLOCK_CELLS * num_blocks`, so the alignment - // is always satisfied - let (input, rest) = unsafe { - rest.split_at_mut_unchecked((layout.metadata.num_blocks as usize) * SHA256_BLOCK_CELLS) - }; - - // Using `align_to_mut` to make sure the returned slice is properly aligned to - // `MemoryReadAuxRecord` Additionally, Rust's subslice operation (a few lines below) - // will verify that the buffer has enough capacity - let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; - Sha256VmRecordMut { - inner: header_buf.borrow_mut(), - input, - read_aux: &mut read_aux_buf - [..(layout.metadata.num_blocks as usize) * SHA256_NUM_READ_ROWS], - } - } - - unsafe fn extract_layout(&self) -> Sha256VmRecordLayout { - let header: &Sha256VmRecordHeader = self.borrow(); - Sha256VmRecordLayout { - metadata: Sha256VmMetadata { - num_blocks: get_sha256_num_blocks(header.len), - }, - } - } -} - -impl SizedRecord for Sha256VmRecordMut<'_> { - fn size(layout: &Sha256VmRecordLayout) -> usize { - let mut total_len = size_of::(); - total_len += layout.metadata.num_blocks as usize * SHA256_BLOCK_CELLS; - // Align the pointer to the alignment of `MemoryReadAuxRecord` - total_len = total_len.next_multiple_of(align_of::()); - total_len += layout.metadata.num_blocks as usize - * SHA256_NUM_READ_ROWS - * size_of::(); - total_len - } - - fn alignment(_layout: &Sha256VmRecordLayout) -> usize { - align_of::() - } -} - -impl TraceStep for Sha256VmStep { - type RecordLayout = Sha256VmRecordLayout; - type RecordMut<'a> = Sha256VmRecordMut<'a>; - - fn get_opcode_name(&self, _: usize) -> String { - format!("{:?}", Rv32Sha256Opcode::SHA256) - } - - fn execute<'buf, RA>( - &mut self, - state: VmStateMut, CTX>, - instruction: &Instruction, - arena: &'buf mut RA, - ) -> Result<()> - where - RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, - { - let Instruction { - opcode, - a, - b, - c, - d, - e, - .. - } = instruction; - debug_assert_eq!(*opcode, Rv32Sha256Opcode::SHA256.global_opcode()); - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - - // Reading the length first to allocate a record of correct size - let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()); - - let num_blocks = get_sha256_num_blocks(len); - let record = arena.alloc(MultiRowLayout { - metadata: Sha256VmMetadata { num_blocks }, - }); - - record.inner.from_pc = *state.pc; - record.inner.timestamp = state.memory.timestamp(); - record.inner.rd_ptr = a.as_canonical_u32(); - record.inner.rs1_ptr = b.as_canonical_u32(); - record.inner.rs2_ptr = c.as_canonical_u32(); - - record.inner.dst_ptr = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rd_ptr, - &mut record.inner.register_reads_aux[0].prev_timestamp, - )); - record.inner.src_ptr = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rs1_ptr, - &mut record.inner.register_reads_aux[1].prev_timestamp, - )); - record.inner.len = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rs2_ptr, - &mut record.inner.register_reads_aux[2].prev_timestamp, - )); - - // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used - debug_assert!( - record.inner.src_ptr as usize + num_blocks as usize * SHA256_BLOCK_CELLS - <= (1 << self.pointer_max_bits) - ); - debug_assert!( - record.inner.dst_ptr as usize + SHA256_WRITE_SIZE <= (1 << self.pointer_max_bits) - ); - // We don't support messages longer than 2^29 bytes - debug_assert!(record.inner.len < SHA256_MAX_MESSAGE_LEN as u32); - - for block_idx in 0..num_blocks as usize { - // Reads happen on the first 4 rows of each block - for row in 0..SHA256_NUM_READ_ROWS { - let read_idx = block_idx * SHA256_NUM_READ_ROWS + row; - let row_input: [u8; SHA256_READ_SIZE] = tracing_read( - state.memory, - RV32_MEMORY_AS, - record.inner.src_ptr + (read_idx * SHA256_READ_SIZE) as u32, - &mut record.read_aux[read_idx].prev_timestamp, - ); - record.input[read_idx * SHA256_READ_SIZE..(read_idx + 1) * SHA256_READ_SIZE] - .copy_from_slice(&row_input); - } - } - - let output = sha256_solve(&record.input[..len as usize]); - tracing_write( - state.memory, - RV32_MEMORY_AS, - record.inner.dst_ptr, - output, - &mut record.inner.write_aux.prev_timestamp, - &mut record.inner.write_aux.prev_data, - ); - - *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - - Ok(()) - } -} - -impl TraceFiller for Sha256VmStep { - fn fill_trace( - &self, - mem_helper: &MemoryAuxColsFactory, - trace_matrix: &mut RowMajorMatrix, - rows_used: usize, - ) { - if rows_used == 0 { - return; - } - - let mut chunks = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK); - let mut sizes = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK); - let mut trace = &mut trace_matrix.values[..]; - let mut num_blocks_so_far = 0; - - // First pass over the trace to get the number of blocks for each instruction - // and divide the matrix into chunks of needed sizes - loop { - if num_blocks_so_far * SHA256_ROWS_PER_BLOCK >= rows_used { - // Push all the padding rows as a single chunk and break - chunks.push(trace); - sizes.push((0, num_blocks_so_far)); - break; - } else { - let record: &Sha256VmRecordHeader = - unsafe { get_record_from_slice(&mut trace, ()) }; - let num_blocks = ((record.len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS); - let (chunk, rest) = - trace.split_at_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK * num_blocks); - chunks.push(chunk); - sizes.push((num_blocks, num_blocks_so_far)); - num_blocks_so_far += num_blocks; - trace = rest; - } - } - - // During the first pass we will fill out most of the matrix - // But there are some cells that can't be generated by the first pass so we will do a second - // pass over the matrix later - chunks.par_iter_mut().zip(sizes.par_iter()).for_each( - |(slice, (num_blocks, global_block_offset))| { - if global_block_offset * SHA256_ROWS_PER_BLOCK >= rows_used { - // Fill in the invalid rows - slice.par_chunks_mut(SHA256VM_WIDTH).for_each(|row| { - // Need to get rid of the accidental garbage data that might overflow the - // F's prime field. Unfortunately, there is no good way around this - unsafe { - std::ptr::write_bytes( - row.as_mut_ptr() as *mut u8, - 0, - SHA256VM_WIDTH * size_of::(), - ); - } - let cols: &mut Sha256VmRoundCols = - row[..SHA256VM_ROUND_WIDTH].borrow_mut(); - self.inner.generate_default_row(&mut cols.inner); - }); - return; - } - - let record: Sha256VmRecordMut = unsafe { - get_record_from_slice( - slice, - Sha256VmRecordLayout { - metadata: Sha256VmMetadata { - num_blocks: *num_blocks as u32, - }, - }, - ) - }; - - let mut input: Vec = Vec::with_capacity(SHA256_BLOCK_CELLS * num_blocks); - input.extend_from_slice(record.input); - let mut padded_input = input.clone(); - let len = record.inner.len as usize; - let padded_input_len = padded_input.len(); - padded_input[len] = 1 << (RV32_CELL_BITS - 1); - padded_input[len + 1..padded_input_len - 4].fill(0); - padded_input[padded_input_len - 4..] - .copy_from_slice(&((len as u32) << 3).to_be_bytes()); - - let mut prev_hashes = Vec::with_capacity(*num_blocks); - prev_hashes.push(SHA256_H); - for i in 0..*num_blocks - 1 { - prev_hashes.push(Sha256StepHelper::get_block_hash( - &prev_hashes[i], - padded_input[i * SHA256_BLOCK_CELLS..(i + 1) * SHA256_BLOCK_CELLS] - .try_into() - .unwrap(), - )); - } - // Copy the read aux records and input to another place to safely fill in the trace - // matrix without overwriting the record - let mut read_aux_records = Vec::with_capacity(SHA256_NUM_READ_ROWS * num_blocks); - read_aux_records.extend_from_slice(record.read_aux); - let vm_record = record.inner.clone(); - - slice - .par_chunks_exact_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK) - .enumerate() - .for_each(|(block_idx, block_slice)| { - // Need to get rid of the accidental garbage data that might overflow the - // F's prime field. Unfortunately, there is no good way around this - unsafe { - std::ptr::write_bytes( - block_slice.as_mut_ptr() as *mut u8, - 0, - SHA256_ROWS_PER_BLOCK * SHA256VM_WIDTH * size_of::(), - ); - } - self.fill_block_trace::( - block_slice, - &vm_record, - &read_aux_records[block_idx * SHA256_NUM_READ_ROWS - ..(block_idx + 1) * SHA256_NUM_READ_ROWS], - &input[block_idx * SHA256_BLOCK_CELLS - ..(block_idx + 1) * SHA256_BLOCK_CELLS], - &padded_input[block_idx * SHA256_BLOCK_CELLS - ..(block_idx + 1) * SHA256_BLOCK_CELLS], - block_idx == *num_blocks - 1, - *global_block_offset + block_idx, - block_idx, - prev_hashes[block_idx], - mem_helper, - ); - }); - }, - ); - - // Do a second pass over the trace to fill in the missing values - // Note, we need to skip the very first row - trace_matrix.values[SHA256VM_WIDTH..] - .par_chunks_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK) - .take(rows_used / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - self.inner - .generate_missing_cells(chunk, SHA256VM_WIDTH, SHA256VM_CONTROL_WIDTH); - }); - } -} - -impl Sha256VmStep { - #[allow(clippy::too_many_arguments)] - fn fill_block_trace( - &self, - block_slice: &mut [F], - record: &Sha256VmRecordHeader, - read_aux_records: &[MemoryReadAuxRecord], - input: &[u8], - padded_input: &[u8], - is_last_block: bool, - global_block_idx: usize, - local_block_idx: usize, - prev_hash: [u32; 8], - mem_helper: &MemoryAuxColsFactory, - ) { - debug_assert_eq!(input.len(), SHA256_BLOCK_CELLS); - debug_assert_eq!(padded_input.len(), SHA256_BLOCK_CELLS); - debug_assert_eq!(read_aux_records.len(), SHA256_NUM_READ_ROWS); - - let padded_input = array::from_fn(|i| { - u32::from_be_bytes(padded_input[i * 4..(i + 1) * 4].try_into().unwrap()) - }); - - let block_start_timestamp = record.timestamp - + (SHA256_REGISTER_READS + SHA256_NUM_READ_ROWS * local_block_idx) as u32; - - let read_cells = (SHA256_BLOCK_CELLS * local_block_idx) as u32; - let block_start_read_ptr = record.src_ptr + read_cells; - - let message_left = if record.len <= read_cells { - 0 - } else { - (record.len - read_cells) as usize - }; - - // -1 means that padding occurred before the start of the block - // 18 means that no padding occurred on this block - let first_padding_row = if record.len < read_cells { - -1 - } else if message_left < SHA256_BLOCK_CELLS { - (message_left / SHA256_READ_SIZE) as i32 - } else { - 18 - }; - - // Fill in the VM columns first because the inner `carry_or_buffer` needs to be filled in - block_slice - .par_chunks_exact_mut(SHA256VM_WIDTH) - .enumerate() - .for_each(|(row_idx, row_slice)| { - // Handle round rows and digest row separately - if row_idx == SHA256_ROWS_PER_BLOCK - 1 { - // This is a digest row - let digest_cols: &mut Sha256VmDigestCols = - row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut(); - digest_cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); - digest_cols.from_state.pc = F::from_canonical_u32(record.from_pc); - digest_cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); - digest_cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); - digest_cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); - digest_cols.dst_ptr = record.dst_ptr.to_le_bytes().map(F::from_canonical_u8); - digest_cols.src_ptr = record.src_ptr.to_le_bytes().map(F::from_canonical_u8); - digest_cols.len_data = record.len.to_le_bytes().map(F::from_canonical_u8); - if is_last_block { - digest_cols - .register_reads_aux - .iter_mut() - .zip(record.register_reads_aux.iter()) - .enumerate() - .for_each(|(idx, (cols_read, record_read))| { - mem_helper.fill( - record_read.prev_timestamp, - record.timestamp + idx as u32, - cols_read.as_mut(), - ); - }); - digest_cols - .writes_aux - .set_prev_data(record.write_aux.prev_data.map(F::from_canonical_u8)); - // In the last block we do `SHA256_NUM_READ_ROWS` reads and then write the - // result thus the timestamp of the write is - // `block_start_timestamp + SHA256_NUM_READ_ROWS` - mem_helper.fill( - record.write_aux.prev_timestamp, - block_start_timestamp + SHA256_NUM_READ_ROWS as u32, - digest_cols.writes_aux.as_mut(), - ); - // Need to range check the destination and source pointers - let msl_rshift: u32 = - ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; - let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - - self.pointer_max_bits) - as u32; - self.bitwise_lookup_chip.request_range( - (record.dst_ptr >> msl_rshift) << msl_lshift, - (record.src_ptr >> msl_rshift) << msl_lshift, - ); - } else { - // Filling in zeros to make sure the accidental garbage data doesn't - // overflow the prime - digest_cols.register_reads_aux.iter_mut().for_each(|aux| { - mem_helper.fill_zero(aux.as_mut()); - }); - digest_cols - .writes_aux - .set_prev_data([F::ZERO; SHA256_WRITE_SIZE]); - mem_helper.fill_zero(digest_cols.writes_aux.as_mut()); - } - digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); - digest_cols.inner.flags.is_digest_row = F::from_bool(true); - } else { - // This is a round row - let round_cols: &mut Sha256VmRoundCols = - row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); - // Take care of the first 4 round rows (aka read rows) - if row_idx < SHA256_NUM_READ_ROWS { - round_cols - .inner - .message_schedule - .carry_or_buffer - .as_flattened_mut() - .iter_mut() - .zip( - input[row_idx * SHA256_READ_SIZE..(row_idx + 1) * SHA256_READ_SIZE] - .iter(), - ) - .for_each(|(cell, data)| { - *cell = F::from_canonical_u8(*data); - }); - mem_helper.fill( - read_aux_records[row_idx].prev_timestamp, - block_start_timestamp + row_idx as u32, - round_cols.read_aux.as_mut(), - ); - } else { - mem_helper.fill_zero(round_cols.read_aux.as_mut()); - } - } - // Fill in the control cols, doesn't matter if it is a round or digest row - let control_cols: &mut Sha256VmControlCols = - row_slice[..SHA256VM_CONTROL_WIDTH].borrow_mut(); - control_cols.len = F::from_canonical_u32(record.len); - // Only the first `SHA256_NUM_READ_ROWS` rows increment the timestamp and read ptr - control_cols.cur_timestamp = F::from_canonical_u32( - block_start_timestamp + min(row_idx, SHA256_NUM_READ_ROWS) as u32, - ); - control_cols.read_ptr = F::from_canonical_u32( - block_start_read_ptr - + (SHA256_READ_SIZE * min(row_idx, SHA256_NUM_READ_ROWS)) as u32, - ); - - // Fill in the padding flags - if row_idx < SHA256_NUM_READ_ROWS { - #[allow(clippy::comparison_chain)] - if (row_idx as i32) < first_padding_row { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - PaddingFlags::NotPadding as usize, - ) - .map(F::from_canonical_u32); - } else if row_idx as i32 == first_padding_row { - let len = message_left - row_idx * SHA256_READ_SIZE; - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - if row_idx == 3 && is_last_block { - PaddingFlags::FirstPadding0_LastRow - } else { - PaddingFlags::FirstPadding0 - } as usize - + len, - ) - .map(F::from_canonical_u32); - } else { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - if row_idx == 3 && is_last_block { - PaddingFlags::EntirePaddingLastRow - } else { - PaddingFlags::EntirePadding - } as usize, - ) - .map(F::from_canonical_u32); - } - } else { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - PaddingFlags::NotConsidered as usize, - ) - .map(F::from_canonical_u32); - } - if is_last_block && row_idx == SHA256_ROWS_PER_BLOCK - 1 { - // If last digest row, then we set padding_occurred = 0 - control_cols.padding_occurred = F::ZERO; - } else { - control_cols.padding_occurred = - F::from_bool((row_idx as i32) >= first_padding_row); - } - }); - - // Fill in the inner trace when the `buffer_or_carry` is filled in - self.inner.generate_block_trace::( - block_slice, - SHA256VM_WIDTH, - SHA256VM_CONTROL_WIDTH, - &padded_input, - self.bitwise_lookup_chip.as_ref(), - &prev_hash, - is_last_block, - global_block_idx as u32 + 1, // global block index is 1-indexed - local_block_idx as u32, - ); - } -} diff --git a/extensions/sha256/guest/src/lib.rs b/extensions/sha256/guest/src/lib.rs deleted file mode 100644 index 8f7c072f4a..0000000000 --- a/extensions/sha256/guest/src/lib.rs +++ /dev/null @@ -1,69 +0,0 @@ -#![no_std] - -#[cfg(target_os = "zkvm")] -use openvm_platform::alloc::AlignedBuf; - -/// This is custom-0 defined in RISC-V spec document -pub const OPCODE: u8 = 0x0b; -pub const SHA256_FUNCT3: u8 = 0b100; -pub const SHA256_FUNCT7: u8 = 0x1; - -/// Native hook for sha256 -/// -/// # Safety -/// -/// The VM accepts the preimage by pointer and length, and writes the -/// 32-byte hash. -/// - `bytes` must point to an input buffer at least `len` long. -/// - `output` must point to a buffer that is at least 32-bytes long. -/// -/// [`sha2`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf -#[cfg(target_os = "zkvm")] -#[inline(always)] -#[no_mangle] -pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { - // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or - // `output` are not aligned to 4 bytes. - // The minimum alignment required for the input and output buffers - const MIN_ALIGN: usize = 4; - // The preferred alignment for the input buffer, since the input is read in chunks of 16 bytes - const INPUT_ALIGN: usize = 16; - // The preferred alignment for the output buffer, since the output is written in chunks of 32 - // bytes - const OUTPUT_ALIGN: usize = 32; - unsafe { - if bytes as usize % MIN_ALIGN != 0 { - let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); - if output as usize % MIN_ALIGN != 0 { - let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); - __native_sha256(aligned_buff.ptr, len, aligned_out.ptr); - core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); - } else { - __native_sha256(aligned_buff.ptr, len, output); - } - } else { - if output as usize % MIN_ALIGN != 0 { - let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); - __native_sha256(bytes, len, aligned_out.ptr); - core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); - } else { - __native_sha256(bytes, len, output); - } - }; - } -} - -/// sha256 intrinsic binding -/// -/// # Safety -/// -/// The VM accepts the preimage by pointer and length, and writes the -/// 32-byte hash. -/// - `bytes` must point to an input buffer at least `len` long. -/// - `output` must point to a buffer that is at least 32-bytes long. -/// - `bytes` and `output` must be 4-byte aligned. -#[cfg(target_os = "zkvm")] -#[inline(always)] -fn __native_sha256(bytes: *const u8, len: usize, output: *mut u8) { - openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA256_FUNCT3, funct7 = SHA256_FUNCT7, rd = In output, rs1 = In bytes, rs2 = In len); -} diff --git a/extensions/sha256/transpiler/src/lib.rs b/extensions/sha256/transpiler/src/lib.rs deleted file mode 100644 index 6b13efe055..0000000000 --- a/extensions/sha256/transpiler/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -use openvm_instructions::{riscv::RV32_MEMORY_AS, LocalOpcode}; -use openvm_instructions_derive::LocalOpcode; -use openvm_sha256_guest::{OPCODE, SHA256_FUNCT3, SHA256_FUNCT7}; -use openvm_stark_backend::p3_field::PrimeField32; -use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput}; -use rrs_lib::instruction_formats::RType; -use strum::{EnumCount, EnumIter, FromRepr}; - -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x320] -#[repr(usize)] -pub enum Rv32Sha256Opcode { - SHA256, -} - -#[derive(Default)] -pub struct Sha256TranspilerExtension; - -impl TranspilerExtension for Sha256TranspilerExtension { - fn process_custom(&self, instruction_stream: &[u32]) -> Option> { - if instruction_stream.is_empty() { - return None; - } - let instruction_u32 = instruction_stream[0]; - let opcode = (instruction_u32 & 0x7f) as u8; - let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; - - if (opcode, funct3) != (OPCODE, SHA256_FUNCT3) { - return None; - } - let dec_insn = RType::new(instruction_u32); - - if dec_insn.funct7 != SHA256_FUNCT7 as u32 { - return None; - } - let instruction = from_r_type( - Rv32Sha256Opcode::SHA256.global_opcode().as_usize(), - RV32_MEMORY_AS as usize, - &dec_insn, - true, - ); - Some(TranspilerOutput::one_to_one(instruction)) - } -} diff --git a/guest-libs/k256/Cargo.toml b/guest-libs/k256/Cargo.toml index 8bd07badf0..19f15b8be4 100644 --- a/guest-libs/k256/Cargo.toml +++ b/guest-libs/k256/Cargo.toml @@ -36,8 +36,8 @@ openvm-algebra-circuit.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true diff --git a/guest-libs/k256/src/internal.rs b/guest-libs/k256/src/internal.rs index b8f8857dc9..868bce2cd5 100644 --- a/guest-libs/k256/src/internal.rs +++ b/guest-libs/k256/src/internal.rs @@ -4,8 +4,8 @@ use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, Group, + weierstrass::{CachedMulTable, WeierstrassPoint}, + CyclicGroup, Group, IntrinsicCurve, }; use openvm_ecc_sw_macros::sw_declare; diff --git a/guest-libs/k256/src/point.rs b/guest-libs/k256/src/point.rs index b854ef582b..5e66303284 100644 --- a/guest-libs/k256/src/point.rs +++ b/guest-libs/k256/src/point.rs @@ -14,10 +14,7 @@ use elliptic_curve::{ FieldBytesEncoding, }; use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{ - weierstrass::{IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, -}; +use openvm_ecc_guest::{weierstrass::WeierstrassPoint, CyclicGroup, IntrinsicCurve}; use crate::{ internal::{Secp256k1Coord, Secp256k1Point, Secp256k1Scalar}, @@ -181,7 +178,7 @@ impl MulByGenerator for Secp256k1Point {} impl DecompressPoint for Secp256k1Point { /// Note that this is not constant time fn decompress(x_bytes: &FieldBytes, y_is_odd: Choice) -> CtOption { - use openvm_ecc_guest::weierstrass::FromCompressed; + use openvm_ecc_guest::FromCompressed; let x = Secp256k1Coord::from_be_bytes_unchecked(x_bytes.as_slice()); let rec_id = y_is_odd.unwrap_u8(); diff --git a/guest-libs/k256/tests/lib.rs b/guest-libs/k256/tests/lib.rs index 8af3a65477..f9cf08e6ed 100644 --- a/guest-libs/k256/tests/lib.rs +++ b/guest-libs/k256/tests/lib.rs @@ -6,12 +6,14 @@ mod guest_tests { arch::instructions::exe::VmExe, utils::{air_test, test_system_config_with_continuations}, }; - use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, SECP256K1_CONFIG}; + #[cfg(test)] + use openvm_ecc_circuit::SwCurveCoeffs; + use openvm_ecc_circuit::{CurveConfig, Rv32EccConfig, SECP256K1_CONFIG}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -19,15 +21,15 @@ mod guest_tests { type F = BabyBear; #[cfg(test)] - fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { - let mut config = Rv32WeierstrassConfig::new(curves); + fn test_rv32ecc_config(sw_curves: Vec>) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, vec![]); config.system = test_system_config_with_continuations(); config } #[test] fn test_add() -> Result<()> { - let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "add", &config)?; let openvm_exe = VmExe::from_elf( @@ -45,7 +47,7 @@ mod guest_tests { #[test] fn test_mul() -> Result<()> { - let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "mul", &config)?; let openvm_exe = VmExe::from_elf( @@ -63,7 +65,7 @@ mod guest_tests { #[test] fn test_linear_combination() -> Result<()> { - let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "linear_combination", @@ -93,14 +95,13 @@ mod guest_tests { utils::test_system_config_with_continuations, }; use openvm_ecc_circuit::{ - CurveConfig, WeierstrassExtension, WeierstrassExtensionExecutor, - WeierstrassExtensionPeriphery, + CurveConfig, EccExtension, EccExtensionExecutor, EccExtensionPeriphery, SwCurveCoeffs, }; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; + use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; @@ -117,13 +118,13 @@ mod guest_tests { #[extension] pub modular: ModularExtension, #[extension] - pub weierstrass: WeierstrassExtension, + pub ecc: EccExtension, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, } impl EcdsaConfig { - pub fn new(curves: Vec) -> Self { + pub fn new(curves: Vec>) -> Self { let primes: Vec<_> = curves .iter() .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) @@ -134,8 +135,8 @@ mod guest_tests { mul: Default::default(), io: Default::default(), modular: ModularExtension::new(primes), - weierstrass: WeierstrassExtension::new(curves), - sha256: Default::default(), + ecc: EccExtension::new(curves, vec![]), + sha2: Default::default(), } } } @@ -145,7 +146,7 @@ mod guest_tests { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", self.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() + self.ecc.generate_ecc_init() )) } } @@ -165,7 +166,7 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(config, openvm_exe); Ok(()) @@ -173,7 +174,7 @@ mod guest_tests { #[test] fn test_scalar_sqrt() -> Result<()> { - let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "scalar_sqrt", diff --git a/guest-libs/k256/tests/programs/openvm_init_add.rs b/guest-libs/k256/tests/programs/openvm_init_add.rs index bec9f527e9..f0855c9497 100644 --- a/guest-libs/k256/tests/programs/openvm_init_add.rs +++ b/guest-libs/k256/tests/programs/openvm_init_add.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/guest-libs/k256/tests/programs/openvm_init_ecdsa.rs b/guest-libs/k256/tests/programs/openvm_init_ecdsa.rs index bec9f527e9..f0855c9497 100644 --- a/guest-libs/k256/tests/programs/openvm_init_ecdsa.rs +++ b/guest-libs/k256/tests/programs/openvm_init_ecdsa.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/guest-libs/k256/tests/programs/openvm_init_linear_combination.rs b/guest-libs/k256/tests/programs/openvm_init_linear_combination.rs index bec9f527e9..f0855c9497 100644 --- a/guest-libs/k256/tests/programs/openvm_init_linear_combination.rs +++ b/guest-libs/k256/tests/programs/openvm_init_linear_combination.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/guest-libs/k256/tests/programs/openvm_init_mul.rs b/guest-libs/k256/tests/programs/openvm_init_mul.rs index bec9f527e9..f0855c9497 100644 --- a/guest-libs/k256/tests/programs/openvm_init_mul.rs +++ b/guest-libs/k256/tests/programs/openvm_init_mul.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/guest-libs/k256/tests/programs/openvm_init_scalar_sqrt.rs b/guest-libs/k256/tests/programs/openvm_init_scalar_sqrt.rs index bec9f527e9..f0855c9497 100644 --- a/guest-libs/k256/tests/programs/openvm_init_scalar_sqrt.rs +++ b/guest-libs/k256/tests/programs/openvm_init_scalar_sqrt.rs @@ -1,3 +1,4 @@ // This file is automatically generated by cargo openvm. Do not rename or edit. openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" } openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point } +openvm_ecc_guest::te_macros::te_init! { } diff --git a/guest-libs/p256/Cargo.toml b/guest-libs/p256/Cargo.toml index 4a382a7aed..dad42296e3 100644 --- a/guest-libs/p256/Cargo.toml +++ b/guest-libs/p256/Cargo.toml @@ -33,8 +33,8 @@ openvm-algebra-circuit.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true diff --git a/guest-libs/p256/src/internal.rs b/guest-libs/p256/src/internal.rs index b98c401c8c..7db8f868c6 100644 --- a/guest-libs/p256/src/internal.rs +++ b/guest-libs/p256/src/internal.rs @@ -4,8 +4,8 @@ use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, Group, + weierstrass::{CachedMulTable, WeierstrassPoint}, + CyclicGroup, Group, IntrinsicCurve, }; use openvm_ecc_sw_macros::sw_declare; diff --git a/guest-libs/p256/src/point.rs b/guest-libs/p256/src/point.rs index ee87396c74..3d4030d807 100644 --- a/guest-libs/p256/src/point.rs +++ b/guest-libs/p256/src/point.rs @@ -14,10 +14,7 @@ use elliptic_curve::{ FieldBytesEncoding, }; use openvm_algebra_guest::IntMod; -use openvm_ecc_guest::{ - weierstrass::{IntrinsicCurve, WeierstrassPoint}, - CyclicGroup, -}; +use openvm_ecc_guest::{weierstrass::WeierstrassPoint, CyclicGroup, IntrinsicCurve}; use crate::{ internal::{P256Coord, P256Point, P256Scalar}, @@ -177,7 +174,7 @@ impl MulByGenerator for P256Point {} impl DecompressPoint for P256Point { /// Note that this is not constant time fn decompress(x_bytes: &FieldBytes, y_is_odd: Choice) -> CtOption { - use openvm_ecc_guest::weierstrass::FromCompressed; + use openvm_ecc_guest::FromCompressed; let x = P256Coord::from_be_bytes_unchecked(x_bytes.as_slice()); let rec_id = y_is_odd.unwrap_u8(); diff --git a/guest-libs/p256/tests/lib.rs b/guest-libs/p256/tests/lib.rs index 6be5424512..9b51362653 100644 --- a/guest-libs/p256/tests/lib.rs +++ b/guest-libs/p256/tests/lib.rs @@ -6,12 +6,14 @@ mod guest_tests { arch::instructions::exe::VmExe, utils::{air_test, test_system_config_with_continuations}, }; - use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, P256_CONFIG}; + #[cfg(test)] + use openvm_ecc_circuit::SwCurveCoeffs; + use openvm_ecc_circuit::{CurveConfig, Rv32EccConfig, P256_CONFIG}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -19,15 +21,15 @@ mod guest_tests { type F = BabyBear; #[cfg(test)] - fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { - let mut config = Rv32WeierstrassConfig::new(curves); + fn test_rv32ecc_config(sw_curves: Vec>) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, vec![]); config.system = test_system_config_with_continuations(); config } #[test] fn test_add() -> Result<()> { - let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "add", &config)?; let openvm_exe = VmExe::from_elf( @@ -45,7 +47,7 @@ mod guest_tests { #[test] fn test_mul() -> Result<()> { - let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "mul", &config)?; let openvm_exe = VmExe::from_elf( @@ -63,7 +65,7 @@ mod guest_tests { #[test] fn test_linear_combination() -> Result<()> { - let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "linear_combination", @@ -93,14 +95,13 @@ mod guest_tests { utils::test_system_config_with_continuations, }; use openvm_ecc_circuit::{ - CurveConfig, WeierstrassExtension, WeierstrassExtensionExecutor, - WeierstrassExtensionPeriphery, + CurveConfig, EccExtension, EccExtensionExecutor, EccExtensionPeriphery, SwCurveCoeffs, }; use openvm_rv32im_circuit::{ Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; + use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; @@ -117,13 +118,13 @@ mod guest_tests { #[extension] pub modular: ModularExtension, #[extension] - pub weierstrass: WeierstrassExtension, + pub ecc: EccExtension, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, } impl EcdsaConfig { - pub fn new(curves: Vec) -> Self { + pub fn new(curves: Vec>) -> Self { let primes: Vec<_> = curves .iter() .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) @@ -134,8 +135,8 @@ mod guest_tests { mul: Default::default(), io: Default::default(), modular: ModularExtension::new(primes), - weierstrass: WeierstrassExtension::new(curves), - sha256: Default::default(), + ecc: EccExtension::new(curves, vec![]), + sha2: Default::default(), } } } @@ -145,7 +146,7 @@ mod guest_tests { Some(format!( "// This file is automatically generated by cargo openvm. Do not rename or edit.\n{}\n{}\n", self.modular.generate_moduli_init(), - self.weierstrass.generate_sw_init() + self.ecc.generate_ecc_init() )) } } @@ -165,7 +166,7 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(config, openvm_exe); Ok(()) @@ -173,7 +174,7 @@ mod guest_tests { #[test] fn test_scalar_sqrt() -> Result<()> { - let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); + let config = test_rv32ecc_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "scalar_sqrt", diff --git a/guest-libs/pairing/src/bls12_381/mod.rs b/guest-libs/pairing/src/bls12_381/mod.rs index 0a7c150e1c..d3557ba61a 100644 --- a/guest-libs/pairing/src/bls12_381/mod.rs +++ b/guest-libs/pairing/src/bls12_381/mod.rs @@ -4,7 +4,7 @@ use core::ops::Neg; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; -use openvm_ecc_guest::{weierstrass::IntrinsicCurve, CyclicGroup, Group}; +use openvm_ecc_guest::{CyclicGroup, Group, IntrinsicCurve}; mod fp12; mod fp2; diff --git a/guest-libs/pairing/src/bn254/mod.rs b/guest-libs/pairing/src/bn254/mod.rs index 8384b8b3e8..a8d3f99f68 100644 --- a/guest-libs/pairing/src/bn254/mod.rs +++ b/guest-libs/pairing/src/bn254/mod.rs @@ -5,10 +5,7 @@ use core::ops::{Add, Neg}; use hex_literal::hex; use openvm_algebra_guest::IntMod; use openvm_algebra_moduli_macros::moduli_declare; -use openvm_ecc_guest::{ - weierstrass::{CachedMulTable, IntrinsicCurve}, - CyclicGroup, Group, -}; +use openvm_ecc_guest::{weierstrass::CachedMulTable, CyclicGroup, Group, IntrinsicCurve}; use openvm_ecc_sw_macros::sw_declare; use openvm_pairing_guest::pairing::PairingIntrinsics; diff --git a/guest-libs/pairing/tests/lib.rs b/guest-libs/pairing/tests/lib.rs index 1d738e8701..4d1deaf9b3 100644 --- a/guest-libs/pairing/tests/lib.rs +++ b/guest-libs/pairing/tests/lib.rs @@ -14,7 +14,7 @@ mod bn254 { use openvm_circuit::utils::{ air_test, air_test_impl, air_test_with_min_segments, test_system_config_with_continuations, }; - use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, WeierstrassExtension}; + use openvm_ecc_circuit::{CurveConfig, EccExtension, Rv32EccConfig, SwCurveCoeffs}; use openvm_ecc_guest::{ algebra::{field::FieldExtension, IntMod}, AffinePoint, @@ -53,14 +53,14 @@ mod bn254 { io: Default::default(), modular: ModularExtension::new(primes.to_vec()), fp2: Fp2Extension::new(primes_with_names), - weierstrass: WeierstrassExtension::new(vec![]), + ecc: EccExtension::new(vec![], vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bn254]), } } #[cfg(test)] - fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { - let mut config = Rv32WeierstrassConfig::new(curves); + fn test_rv32ecc_config(sw_curves: Vec>) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, vec![]); config.system = test_system_config_with_continuations(); config } @@ -68,7 +68,7 @@ mod bn254 { #[test] fn test_bn_ec() -> Result<()> { let curve = PairingCurve::Bn254.curve_config(); - let config = test_rv32weierstrass_config(vec![curve]); + let config = test_rv32ecc_config(vec![curve]); let elf = build_example_program_at_path_with_features( get_programs_dir!("tests/programs"), "bn_ec", @@ -471,7 +471,7 @@ mod bls12_381 { test_system_config_with_continuations, }, }; - use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, WeierstrassExtension}; + use openvm_ecc_circuit::{CurveConfig, EccExtension, Rv32EccConfig, SwCurveCoeffs}; use openvm_ecc_guest::{ algebra::{field::FieldExtension, IntMod}, AffinePoint, @@ -512,14 +512,14 @@ mod bls12_381 { io: Default::default(), modular: ModularExtension::new(primes.to_vec()), fp2: Fp2Extension::new(primes_with_names), - weierstrass: WeierstrassExtension::new(vec![]), + ecc: EccExtension::new(vec![], vec![]), pairing: PairingExtension::new(vec![PairingCurve::Bls12_381]), } } #[cfg(test)] - fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { - let mut config = Rv32WeierstrassConfig::new(curves); + fn test_rv32ecc_config(sw_curves: Vec>) -> Rv32EccConfig { + let mut config = Rv32EccConfig::new(sw_curves, vec![]); config.system = test_system_config_with_continuations(); config } @@ -530,10 +530,12 @@ mod bls12_381 { struct_name: BLS12_381_ECC_STRUCT_NAME.to_string(), modulus: BLS12_381_MODULUS.clone(), scalar: BLS12_381_ORDER.clone(), - a: BigUint::ZERO, - b: BigUint::from_u8(4).unwrap(), + coeffs: SwCurveCoeffs { + a: BigUint::ZERO, + b: BigUint::from_u8(4).unwrap(), + }, }; - let config = test_rv32weierstrass_config(vec![curve]); + let config = test_rv32ecc_config(vec![curve]); let elf = build_example_program_at_path_with_features( get_programs_dir!("tests/programs"), "bls_ec", diff --git a/guest-libs/sha2/Cargo.toml b/guest-libs/sha2/Cargo.toml index f8bf7b545e..9e13e85ce8 100644 --- a/guest-libs/sha2/Cargo.toml +++ b/guest-libs/sha2/Cargo.toml @@ -10,15 +10,15 @@ repository.workspace = true license.workspace = true [dependencies] -openvm-sha256-guest = { workspace = true } +openvm-sha2-guest = { workspace = true } [dev-dependencies] openvm-instructions = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils", "parallel"] } openvm-transpiler = { workspace = true } -openvm-sha256-transpiler = { workspace = true } -openvm-sha256-circuit = { workspace = true } +openvm-sha2-transpiler = { workspace = true } +openvm-sha2-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } diff --git a/guest-libs/sha2/src/lib.rs b/guest-libs/sha2/src/lib.rs index 43d90ba822..dfeddf70a1 100644 --- a/guest-libs/sha2/src/lib.rs +++ b/guest-libs/sha2/src/lib.rs @@ -8,6 +8,22 @@ pub fn sha256(input: &[u8]) -> [u8; 32] { output } +/// The sha512 cryptographic hash function. +#[inline(always)] +pub fn sha512(input: &[u8]) -> [u8; 64] { + let mut output = [0u8; 64]; + set_sha512(input, &mut output); + output +} + +/// The sha384 cryptographic hash function. +#[inline(always)] +pub fn sha384(input: &[u8]) -> [u8; 48] { + let mut output = [0u8; 48]; + set_sha384(input, &mut output); + output +} + /// Sets `output` to the sha256 hash of `input`. pub fn set_sha256(input: &[u8], output: &mut [u8; 32]) { #[cfg(not(target_os = "zkvm"))] @@ -19,10 +35,51 @@ pub fn set_sha256(input: &[u8], output: &mut [u8; 32]) { } #[cfg(target_os = "zkvm")] { - openvm_sha256_guest::zkvm_sha256_impl( + openvm_sha2_guest::zkvm_sha256_impl( input.as_ptr(), input.len(), output.as_mut_ptr() as *mut u8, ); } } + +/// Sets `output` to the sha512 hash of `input`. +pub fn set_sha512(input: &[u8], output: &mut [u8; 64]) { + #[cfg(not(target_os = "zkvm"))] + { + use sha2::{Digest, Sha512}; + let mut hasher = Sha512::new(); + hasher.update(input); + output.copy_from_slice(hasher.finalize().as_ref()); + } + #[cfg(target_os = "zkvm")] + { + openvm_sha2_guest::zkvm_sha512_impl( + input.as_ptr(), + input.len(), + output.as_mut_ptr() as *mut u8, + ); + } +} + +/// Sets the first 48 bytes of `output` to the sha384 hash of `input`. +/// Sets the last 16 bytes to zeros. +pub fn set_sha384(input: &[u8], output: &mut [u8; 48]) { + #[cfg(not(target_os = "zkvm"))] + { + use sha2::{Digest, Sha384}; + let mut hasher = Sha384::new(); + hasher.update(input); + output.copy_from_slice(hasher.finalize().as_ref()); + } + #[cfg(target_os = "zkvm")] + { + let mut output_64: [u8; 64] = [0; 64]; + openvm_sha2_guest::zkvm_sha384_impl( + input.as_ptr(), + input.len(), + output_64.as_mut_ptr() as *mut u8, + ); + output.copy_from_slice(&output_64[..48]); + } +} diff --git a/guest-libs/sha2/tests/lib.rs b/guest-libs/sha2/tests/lib.rs index 9ebab5ac02..3a8c92194c 100644 --- a/guest-libs/sha2/tests/lib.rs +++ b/guest-libs/sha2/tests/lib.rs @@ -6,8 +6,8 @@ mod tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_circuit::Sha256Rv32Config; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_circuit::Sha2Rv32Config; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -15,17 +15,17 @@ mod tests { type F = BabyBear; #[test] - fn test_sha256() -> Result<()> { - let config = Sha256Rv32Config::default(); + fn test_sha2() -> Result<()> { + let config = Sha2Rv32Config::default(); let elf = - build_example_program_at_path(get_programs_dir!("tests/programs"), "sha", &config)?; + build_example_program_at_path(get_programs_dir!("tests/programs"), "sha2", &config)?; let openvm_exe = VmExe::from_elf( elf, Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(config, openvm_exe); Ok(()) diff --git a/guest-libs/sha2/tests/programs/Cargo.toml b/guest-libs/sha2/tests/programs/Cargo.toml index df13f8dfc7..c197564ec0 100644 --- a/guest-libs/sha2/tests/programs/Cargo.toml +++ b/guest-libs/sha2/tests/programs/Cargo.toml @@ -8,12 +8,12 @@ edition = "2021" openvm = { path = "../../../../crates/toolchain/openvm" } openvm-platform = { path = "../../../../crates/toolchain/platform" } openvm-sha2 = { path = "../../" } - hex = { version = "0.4.3", default-features = false, features = ["alloc"] } serde = { version = "1.0", default-features = false, features = [ "alloc", "derive", ] } +hex-literal = { version = "1.0.0" } [features] default = [] diff --git a/guest-libs/sha2/tests/programs/examples/sha.rs b/guest-libs/sha2/tests/programs/examples/sha.rs deleted file mode 100644 index ebfd50cbee..0000000000 --- a/guest-libs/sha2/tests/programs/examples/sha.rs +++ /dev/null @@ -1,29 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_main)] -#![cfg_attr(not(feature = "std"), no_std)] - -extern crate alloc; - -use alloc::vec::Vec; -use core::hint::black_box; - -use hex::FromHex; -use openvm_sha2::sha256; - -openvm::entry!(main); - -pub fn main() { - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ("98c1c0bdb7d5fea9a88859f06c6c439f", "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05"), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - ("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - ]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = sha256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} diff --git a/guest-libs/sha2/tests/programs/examples/sha2.rs b/guest-libs/sha2/tests/programs/examples/sha2.rs new file mode 100644 index 0000000000..7f28152b42 --- /dev/null +++ b/guest-libs/sha2/tests/programs/examples/sha2.rs @@ -0,0 +1,85 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +extern crate alloc; + +use alloc::vec::Vec; +use core::hint::black_box; + +use hex::FromHex; +use openvm_sha2::{sha256, sha384, sha512}; + +openvm::entry!(main); + +struct ShaTestVector { + input: &'static str, + expected_output_sha256: &'static str, + expected_output_sha512: &'static str, + expected_output_sha384: &'static str, +} + +pub fn main() { + let test_vectors = [ + ShaTestVector { + input: "", + expected_output_sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expected_output_sha512: "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", + expected_output_sha384: "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", + }, + ShaTestVector { + input: "98c1c0bdb7d5fea9a88859f06c6c439f", + expected_output_sha256: "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", + expected_output_sha512: "eb576959c531f116842c0cc915a29c8f71d7a285c894c349b83469002ef093d51f9f14ce4248488bff143025e47ed27c12badb9cd43779cb147408eea062d583", + expected_output_sha384: "63e3061aab01f335ea3a4e617b9d14af9b63a5240229164ee962f6d5335ff25f0f0bf8e46723e83c41b9d17413b6a3c7", + }, + ShaTestVector { + input: "5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", + expected_output_sha256: "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7", + expected_output_sha512: "a20d5fb14814d045a7d2861e80d2b688f1cd1daaba69e6bb1cc5233f514141ea4623b3373af702e78e3ec5dc8c1b716a37a9a2f5fbc9493b9df7043f5e99a8da", + expected_output_sha384: "eac4b72b0540486bc088834860873338e31e9e4062532bf509191ef63b9298c67db5654a28fe6f07e4cc6ff466d1be24", + }, + ShaTestVector { + input: "9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", + expected_output_sha256: "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23", + expected_output_sha512: "8d215ee6dc26757c210db0dd00c1c6ed16cc34dbd4bb0fa10c1edb6b62d5ab16aea88c881001b173d270676daf2d6381b5eab8711fa2f5589c477c1d4b84774f", + expected_output_sha384: "904a90010d772a904a35572fdd4bdf1dd253742e47872c8a18e2255f66fa889e44781e65487a043f435daa53c496a53e", + } + ]; + + for ( + i, + ShaTestVector { + input, + expected_output_sha256, + expected_output_sha512, + expected_output_sha384, + }, + ) in test_vectors.iter().enumerate() + { + let input = Vec::from_hex(input).unwrap(); + let expected_output_sha256 = Vec::from_hex(expected_output_sha256).unwrap(); + let output = sha256(black_box(&input)); + if output != *expected_output_sha256 { + panic!( + "sha256 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha256, output + ); + } + let expected_output_sha512 = Vec::from_hex(expected_output_sha512).unwrap(); + let output = sha512(black_box(&input)); + if output != *expected_output_sha512 { + panic!( + "sha512 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha512, output + ); + } + let expected_output_sha384 = Vec::from_hex(expected_output_sha384).unwrap(); + let output = sha384(black_box(&input)); + if output != *expected_output_sha384 { + panic!( + "sha384 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha384, output + ); + } + } +}