diff --git a/.config/nextest.toml b/.config/nextest.toml new file mode 100644 index 0000000000..db7e0b6a84 --- /dev/null +++ b/.config/nextest.toml @@ -0,0 +1,13 @@ +# Nextest configuration for OpenVM project + +# Define test groups with different weights +[[profile.default.overrides]] +# Match all tests with "persistent" in their name +filter = 'test(~persistent)' +# Give these tests 5x the default weight because they use more memory +threads-required = 16 + +# custom profile for heavy tests +[profile.heavy] +# Run fewer tests in parallel for heavy workloads +test-threads = 2 diff --git a/.github/workflows/benchmark-call.yml b/.github/workflows/benchmark-call.yml index 737e1c81ed..f2d1228cc1 100644 --- a/.github/workflows/benchmark-call.yml +++ b/.github/workflows/benchmark-call.yml @@ -250,9 +250,15 @@ jobs: fi ########################################################################## - # Update s3 for latest main metrics upon a push event # + # Update s3 for latest branch metrics upon a push event # ########################################################################## - - name: Update latest main result in s3 - if: github.event_name == 'push' && github.ref == 'refs/heads/main' + - name: Update latest branch result in s3 + if: github.event_name == 'push' run: | - s5cmd cp $METRIC_PATH "${{ env.S3_METRICS_PATH }}/main-${METRIC_NAME}.json" + if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then + # for backwards compatibility + REF_HASH="main" + else + REF_HASH=$(echo "${{ github.ref }}" | sha256sum | cut -d' ' -f1) + fi + s5cmd cp $METRIC_PATH "${{ env.S3_METRICS_PATH }}/${REF_HASH}-${METRIC_NAME}.json" diff --git a/.github/workflows/benchmarks-execute.yml b/.github/workflows/benchmarks-execute.yml index 741ccdb0f1..337df60d1a 100644 --- a/.github/workflows/benchmarks-execute.yml +++ b/.github/workflows/benchmarks-execute.yml @@ -1,8 +1,9 @@ -name: "benchmarks-execute" +name: "Execution benchmarks" on: push: - branches: ["main"] + # TODO(ayush): remove after feat/new-execution is merged + branches: ["main", "feat/new-execution"] pull_request: types: [opened, synchronize, reopened, labeled] branches: ["**"] @@ -18,95 +19,83 @@ on: - ".github/workflows/benchmarks-execute.yml" workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always jobs: - execute-benchmarks: + codspeed-walltime-benchmarks: + name: Run codspeed walltime benchmarks runs-on: - runs-on=${{ github.run_id }} - - runner=8cpu-linux-x64 + - family=m5a.xlarge # 2.5Ghz clock speed + - image=ubuntu24-full-x64 + - extras=s3-cache + + env: + CODSPEED_RUNNER_MODE: walltime + steps: + - uses: runs-on/action@v1 - uses: actions/checkout@v4 - - - name: Set up Rust - uses: actions-rs/toolchain@v1 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 with: - profile: minimal - toolchain: stable - override: true + cache-on-failure: true - - name: Run execution benchmarks - working-directory: benchmarks/execute - run: cargo run | tee benchmark_output.log + - name: Install cargo-binstall + uses: cargo-bins/cargo-binstall@main + - name: Install codspeed + run: cargo binstall --no-confirm --force cargo-codspeed - - name: Parse benchmark results + - name: Build benchmarks working-directory: benchmarks/execute - run: | - # Determine if running in GitHub Actions environment - if [ -n "$GITHUB_STEP_SUMMARY" ]; then - SUMMARY_FILE="$GITHUB_STEP_SUMMARY" - echo "### Benchmark Results Summary" >> "$SUMMARY_FILE" - else - SUMMARY_FILE="benchmark_summary.md" - echo "### Benchmark Results Summary" > "$SUMMARY_FILE" - echo "Saving summary to $SUMMARY_FILE" - fi - - # Set up summary table header - echo "| Program | Total Time (ms) |" >> "$SUMMARY_FILE" - echo "| ------- | --------------- |" >> "$SUMMARY_FILE" - - # Variables to track current program and total time - current_program="" - total_time=0 - - # Process the output file line by line - while IFS= read -r line; do - # Check if line contains "Running program" message - if [[ $line =~ i\ \[info\]:\ Running\ program:\ ([a-zA-Z0-9_-]+) ]]; then - # If we were processing a program, output its results - if [[ -n "$current_program" ]]; then - echo "| $current_program | $total_time |" >> "$SUMMARY_FILE" - fi - - # Start tracking new program - current_program="${BASH_REMATCH[1]}" - total_time=0 - fi - - # Check for program completion to catch programs that might have no execution segments - if [[ $line =~ i\ \[info\]:\ Completed\ program:\ ([a-zA-Z0-9_-]+) ]]; then - completed_program="${BASH_REMATCH[1]}" - # If no segments were found for this program, ensure it's still in the output - if [[ "$current_program" == "$completed_program" && $total_time == 0 ]]; then - echo "| $current_program | 0 |" >> "$SUMMARY_FILE" - current_program="" - fi - fi - - # Check if line contains execution time (looking for the format with ms or s) - if [[ $line =~ execute_segment\ \[\ ([0-9.]+)(ms|s)\ \|\ [0-9.]+%\ \]\ segment ]]; then - segment_time="${BASH_REMATCH[1]}" - unit="${BASH_REMATCH[2]}" + run: cargo codspeed build --profile maxperf + - name: Run benchmarks + uses: CodSpeedHQ/action@v3 + with: + working-directory: benchmarks/execute + run: cargo codspeed run + token: ${{ secrets.CODSPEED_TOKEN }} + env: + CODSPEED_RUNNER_MODE: walltime + + codspeed-instrumentation-benchmarks: + name: Run codspeed instrumentation benchmarks + runs-on: + - runs-on=${{ github.run_id }} + - family=m5a.xlarge + - image=ubuntu24-full-x64 + - extras=s3-cache + if: github.event_name != 'pull_request' - # Convert to milliseconds if in seconds - if [[ "$unit" == "s" ]]; then - segment_time=$(echo "scale=6; $segment_time * 1000" | bc) - fi + env: + CODSPEED_RUNNER_MODE: instrumentation - # Add segment time to total - total_time=$(echo "scale=6; $total_time + $segment_time" | bc) - fi - done < benchmark_output.log + steps: + - uses: runs-on/action@v1 + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + with: + cache-on-failure: true - # Output the last program result if there was one - if [[ -n "$current_program" ]]; then - echo "| $current_program | $total_time |" >> "$SUMMARY_FILE" - fi + - name: Install cargo-binstall + uses: cargo-bins/cargo-binstall@main + - name: Install codspeed + run: cargo binstall --no-confirm --force cargo-codspeed - # If not in GitHub Actions, print the summary to the terminal - if [ -z "$GITHUB_STEP_SUMMARY" ]; then - echo -e "\nBenchmark Summary:" - cat "$SUMMARY_FILE" - fi + - name: Build benchmarks + working-directory: benchmarks/execute + run: cargo codspeed build + - name: Run benchmarks + uses: CodSpeedHQ/action@v3 + with: + working-directory: benchmarks/execute + run: cargo codspeed run + token: ${{ secrets.CODSPEED_TOKEN }} + env: + CODSPEED_RUNNER_MODE: instrumentation diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 3c2b02c574..7b41d09883 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -2,7 +2,7 @@ name: "OpenVM Benchmarks: Coordinate Runner & Reporting" on: push: - branches: ["main"] + branches: ["main", "feat/new-execution"] pull_request: types: [opened, synchronize, reopened, labeled] branches: ["**"] @@ -211,9 +211,21 @@ jobs: json_file_list=$(echo -n "$json_files" | paste -sd "," -) echo $json_file_list - prev_json_files=$(echo $matrix | jq -r ' + # For PRs, get the latest commit from the target branch + if [[ "${{ github.event_name }}" == "pull_request" ]]; then + if [[ "${{ github.base_ref }}" == "main" ]]; then + REF_HASH="main" + else + REF_HASH=$(echo "refs/heads/${{ github.base_ref }}" | sha256sum | cut -d' ' -f1) + fi + echo "Target branch REF_HASH: $REF_HASH" + else + REF_HASH="main" + fi + + prev_json_files=$(echo $matrix | jq -r --arg target "$REF_HASH" ' .[] | - "main-\(.id).json"') + "\($target)-\(.id).json"') prev_json_file_list=$(echo -n "$prev_json_files" | paste -sd "," -) echo $prev_json_file_list diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 574a49be15..7f3df63f6f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,6 +16,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: - uses: runs-on/action@v1 diff --git a/.github/workflows/cli.yml b/.github/workflows/cli.yml index 510a124092..d0816f6731 100644 --- a/.github/workflows/cli.yml +++ b/.github/workflows/cli.yml @@ -36,7 +36,8 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - disk=large - - runner=32cpu-linux-arm64 + - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: @@ -47,7 +48,8 @@ jobs: cache-on-failure: true - uses: taiki-e/install-action@nextest - name: Install solc # svm should support arm64 linux - run: (hash svm 2>/dev/null || cargo install --version 0.2.23 svm-rs) && svm install 0.8.19 && solc --version + run: | + (hash svm 2>/dev/null || cargo install --version 0.2.23 svm-rs) && svm install 0.8.19 && solc --version - name: Install cargo-openvm working-directory: crates/cli @@ -80,8 +82,7 @@ jobs: working-directory: crates/cli run: | export RUST_BACKTRACE=1 - cargo build - cargo run --bin cargo-openvm -- openvm keygen --config ./example/app_config.toml --output-dir . + cargo openvm keygen --config ./example/app_config.toml --output-dir . - name: Set USE_LOCAL_OPENVM environment variable run: | @@ -94,4 +95,5 @@ jobs: - name: Run CLI tests working-directory: crates/cli run: | - cargo nextest run --cargo-profile=fast + export SKIP_INSTALL=1 + cargo nextest run --cargo-profile=fast --test-threads=1 diff --git a/.github/workflows/extension-tests.yml b/.github/workflows/extension-tests.yml index ef13b840c6..2ac189374a 100644 --- a/.github/workflows/extension-tests.yml +++ b/.github/workflows/extension-tests.yml @@ -29,7 +29,7 @@ jobs: - { name: "rv32im", path: "rv32im" } - { name: "native", path: "native" } - { name: "keccak256", path: "keccak256" } - - { name: "sha256", path: "sha256" } + - { name: "sha2", path: "sha2" } - { name: "bigint", path: "bigint" } - { name: "algebra", path: "algebra" } - { name: "ecc", path: "ecc" } @@ -40,6 +40,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - tag=extension-${{ matrix.extension.name }} - extras=s3-cache @@ -69,7 +70,7 @@ jobs: - name: Run ${{ matrix.extension.name }} circuit crate tests working-directory: extensions/${{ matrix.extension.path }}/circuit - run: cargo nextest run --cargo-profile=fast + run: cargo nextest run --cargo-profile=fast --test-threads=32 - name: Run ${{ matrix.extension.name }} guest crate tests if: hashFiles(format('extensions/{0}/guest', matrix.extension.path)) != '' @@ -86,4 +87,4 @@ jobs: working-directory: extensions/${{ matrix.extension.path }}/tests run: | rustup component add rust-src --toolchain nightly-2025-02-14 - cargo nextest run --cargo-profile=fast --no-tests=pass + cargo nextest run --cargo-profile=fast --profile=heavy --no-tests=pass diff --git a/.github/workflows/guest-lib-tests.yml b/.github/workflows/guest-lib-tests.yml index 1b87b600e2..5c0e3deab3 100644 --- a/.github/workflows/guest-lib-tests.yml +++ b/.github/workflows/guest-lib-tests.yml @@ -41,6 +41,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - tag=crate-${{ matrix.crate.name }} - extras=s3-cache diff --git a/.github/workflows/native-compiler.yml b/.github/workflows/native-compiler.yml index af4f39ddff..b79a3cb1c9 100644 --- a/.github/workflows/native-compiler.yml +++ b/.github/workflows/native-compiler.yml @@ -25,6 +25,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: diff --git a/.github/workflows/primitives.yml b/.github/workflows/primitives.yml index 714230b8cd..4385b1ba5a 100644 --- a/.github/workflows/primitives.yml +++ b/.github/workflows/primitives.yml @@ -8,7 +8,7 @@ on: paths: - "crates/circuits/primitives/**" - "crates/circuits/poseidon2-air/**" - - "crates/circuits/sha256-air/**" + - "crates/circuits/sha2-air/**" - "crates/circuits/mod-builder/**" - "Cargo.toml" - ".github/workflows/primitives.yml" @@ -26,6 +26,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=32cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: @@ -47,8 +48,8 @@ jobs: run: | cargo nextest run --cargo-profile fast --features parallel - - name: Run tests for sha256-air - working-directory: crates/circuits/sha256-air + - name: Run tests for sha2-air + working-directory: crates/circuits/sha2-air run: | cargo nextest run --cargo-profile fast --features parallel diff --git a/.github/workflows/recursion.yml b/.github/workflows/recursion.yml index 814c1fa44a..64538c18c1 100644 --- a/.github/workflows/recursion.yml +++ b/.github/workflows/recursion.yml @@ -26,6 +26,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: diff --git a/.github/workflows/sdk.yml b/.github/workflows/sdk.yml index e24df21ffe..4d194a03df 100644 --- a/.github/workflows/sdk.yml +++ b/.github/workflows/sdk.yml @@ -26,6 +26,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - family=m7a.24xlarge + - image=ubuntu24-full-x64 - disk=large - extras=s3-cache @@ -97,4 +98,10 @@ jobs: working-directory: crates/sdk run: | export RUST_BACKTRACE=1 - cargo nextest run --cargo-profile=fast --test-threads=2 --features parallel,evm-verify + cargo nextest run --cargo-profile=fast --features parallel,evm-verify + + - name: Run ignored tests + working-directory: crates/sdk + if: ${{ github.event_name == 'push' }} + run: | + cargo nextest run --cargo-profile=fast --features parallel,evm-verify --ignored test_static_verifier_custom_pv_handler diff --git a/.github/workflows/vm.yml b/.github/workflows/vm.yml index cb7f2284ca..c8c03dc931 100644 --- a/.github/workflows/vm.yml +++ b/.github/workflows/vm.yml @@ -25,6 +25,7 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=64cpu-linux-arm64 + - image=ubuntu24-full-arm64 - extras=s3-cache steps: @@ -40,3 +41,8 @@ jobs: working-directory: crates/vm run: | cargo nextest run --cargo-profile=fast --features parallel + + - name: Run vm crate tests with basic memory + working-directory: crates/vm + run: | + cargo nextest run --cargo-profile=fast --features parallel,basic-memory diff --git a/.gitignore b/.gitignore index d794a5dc57..15fa79c61f 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,6 @@ guest.syms # openvm generated files crates/cli/openvm/ + +# samply profile +profile.json.gz diff --git a/Cargo.lock b/Cargo.lock index b1d036d006..68a5d7172b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -34,9 +34,9 @@ dependencies = [ [[package]] name = "adler2" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aes" @@ -51,9 +51,9 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "once_cell", @@ -76,28 +76,61 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "alloy-eip2124" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "741bdd7499908b3aa0b159bba11e71c8cddd009a2c2eb7a06e825f1ec87900a5" +dependencies = [ + "alloy-primitives 1.2.1", + "alloy-rlp", + "crc", + "serde", + "thiserror 2.0.12", +] + [[package]] name = "alloy-eip2930" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0069cf0642457f87a01a014f6dc29d5d893cd4fd8fddf0c3cdfad1bb3ebafc41" +checksum = "7b82752a889170df67bbb36d42ca63c531eb16274f0d7299ae2a680facba17bd" dependencies = [ - "alloy-primitives 0.8.25", + "alloy-primitives 1.2.1", "alloy-rlp", "serde", ] [[package]] name = "alloy-eip7702" -version = "0.4.2" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c986539255fb839d1533c128e190e557e52ff652c9ef62939e233a81dd93f7e" +checksum = "9d4769c6ffddca380b0070d71c8b7f30bed375543fe76bb2f74ec0acf4b7cd16" dependencies = [ - "alloy-primitives 0.8.25", + "alloy-primitives 1.2.1", "alloy-rlp", - "derive_more 1.0.0", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", "serde", + "thiserror 2.0.12", +] + +[[package]] +name = "alloy-eips" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4134375e533d095e045982cd7684a29c37089ab7a605ecf2b4aa17a5e61d72d3" +dependencies = [ + "alloy-eip2124", + "alloy-eip2930", + "alloy-eip7702", + "alloy-primitives 1.2.1", + "alloy-rlp", + "alloy-serde", + "auto_impl", + "c-kzg", + "derive_more 2.0.1", + "either", + "serde", + "sha2 0.10.9", ] [[package]] @@ -121,10 +154,10 @@ dependencies = [ "bytes", "cfg-if", "const-hex", - "derive_more 0.99.19", + "derive_more 0.99.20", "hex-literal 0.4.1", "itoa", - "ruint 1.12.3", + "ruint 1.15.0", "tiny-keccak", ] @@ -138,10 +171,10 @@ dependencies = [ "bytes", "cfg-if", "const-hex", - "derive_more 0.99.19", + "derive_more 0.99.20", "hex-literal 0.4.1", "itoa", - "ruint 1.12.3", + "ruint 1.15.0", "tiny-keccak", ] @@ -157,15 +190,42 @@ dependencies = [ "const-hex", "derive_more 2.0.1", "foldhash", - "hashbrown 0.15.2", - "indexmap 2.7.1", + "hashbrown 0.15.4", + "indexmap 2.10.0", "itoa", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", "keccak-asm", "paste", "proptest", "rand 0.8.5", - "ruint 1.12.3", + "ruint 1.15.0", + "rustc-hash 2.1.1", + "serde", + "sha3", + "tiny-keccak", +] + +[[package]] +name = "alloy-primitives" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6177ed26655d4e84e00b65cb494d4e0b8830e7cae7ef5d63087d445a2600fb55" +dependencies = [ + "alloy-rlp", + "bytes", + "cfg-if", + "const-hex", + "derive_more 2.0.1", + "foldhash", + "hashbrown 0.15.4", + "indexmap 2.10.0", + "itoa", + "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", + "keccak-asm", + "paste", + "proptest", + "rand 0.9.1", + "ruint 1.15.0", "rustc-hash 2.1.1", "serde", "sha3", @@ -174,9 +234,9 @@ dependencies = [ [[package]] name = "alloy-rlp" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6c1d995bff8d011f7cd6c81820d51825e6e06d6db73914c1630ecf544d83d6" +checksum = "5f70d83b765fdc080dbcd4f4db70d8d23fe4761f2f02ebfa9146b833900634b4" dependencies = [ "alloy-rlp-derive", "arrayvec", @@ -185,13 +245,24 @@ dependencies = [ [[package]] name = "alloy-rlp-derive" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a40e1ef334153322fd878d07e86af7a529bcb86b2439525920a88eba87bcf943" +checksum = "64b728d511962dda67c1bc7ea7c03736ec275ed2cf4c35d9585298ac9ccf3b73" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "alloy-serde" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06c02a06ae34d2354398dc9d2de0503129c3f0904a3eb791b5d0149f267c2688" +dependencies = [ + "alloy-primitives 1.2.1", + "serde", + "serde_json", ] [[package]] @@ -205,7 +276,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -218,11 +289,11 @@ dependencies = [ "alloy-sol-macro-input", "const-hex", "heck", - "indexmap 2.7.1", + "indexmap 2.10.0", "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "syn-solidity", "tiny-keccak", ] @@ -241,7 +312,7 @@ dependencies = [ "proc-macro2", "quote", "serde_json", - "syn 2.0.98", + "syn 2.0.104", "syn-solidity", ] @@ -252,7 +323,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d162f8524adfdfb0e4bd0505c734c985f3e2474eb022af32eef0d52a4f3935c" dependencies = [ "serde", - "winnow 0.7.3", + "winnow 0.7.12", ] [[package]] @@ -300,9 +371,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.18" +version = "0.6.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933" dependencies = [ "anstyle", "anstyle-parse", @@ -315,44 +386,44 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" [[package]] name = "anstyle-parse" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" dependencies = [ "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.7" +version = "3.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" dependencies = [ "anstyle", - "once_cell", + "once_cell_polyfill", "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.96" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "approx" @@ -379,6 +450,18 @@ dependencies = [ "yansi 0.5.1", ] +[[package]] +name = "ark-bls12-381" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df4dcc01ff89867cd86b0da835f23c3f02738353aaee7dde7495af71363b8d5" +dependencies = [ + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", +] + [[package]] name = "ark-bn254" version = "0.3.0" @@ -401,6 +484,18 @@ dependencies = [ "ark-std 0.4.0", ] +[[package]] +name = "ark-bn254" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d69eab57e8d2663efa5c63135b2af4f396d66424f88954c21104125ab6b3e6bc" +dependencies = [ + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-r1cs-std", + "ark-std 0.5.0", +] + [[package]] name = "ark-ec" version = "0.3.0" @@ -422,7 +517,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defd9a439d56ac24968cca0571f598a61bc8c55f71d50a89cda591cb750670ba" dependencies = [ "ark-ff 0.4.2", - "ark-poly", + "ark-poly 0.4.2", "ark-serialize 0.4.2", "ark-std 0.4.0", "derivative", @@ -432,6 +527,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" +dependencies = [ + "ahash", + "ark-ff 0.5.0", + "ark-poly 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.4", + "itertools 0.13.0", + "num-bigint 0.4.6", + "num-integer", + "num-traits", + "zeroize", +] + [[package]] name = "ark-ff" version = "0.3.0" @@ -470,6 +586,26 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm 0.5.0", + "ark-ff-macros 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "arrayvec", + "digest 0.10.7", + "educe", + "itertools 0.13.0", + "num-bigint 0.4.6", + "num-traits", + "paste", + "zeroize", +] + [[package]] name = "ark-ff-asm" version = "0.3.0" @@ -490,6 +626,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.104", +] + [[package]] name = "ark-ff-macros" version = "0.3.0" @@ -515,6 +661,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint 0.4.6", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "ark-poly" version = "0.4.2" @@ -528,6 +687,50 @@ dependencies = [ "hashbrown 0.13.2", ] +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" +dependencies = [ + "ahash", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.4", +] + +[[package]] +name = "ark-r1cs-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "941551ef1df4c7a401de7068758db6503598e6f01850bdb2cfdb614a1f9dbea1" +dependencies = [ + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-relations", + "ark-std 0.5.0", + "educe", + "num-bigint 0.4.6", + "num-integer", + "num-traits", + "tracing", +] + +[[package]] +name = "ark-relations" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec46ddc93e7af44bcab5230937635b06fb5744464dd6a7e7b083e80ebd274384" +dependencies = [ + "ark-ff 0.5.0", + "ark-std 0.5.0", + "tracing", + "tracing-subscriber 0.2.25", +] + [[package]] name = "ark-serialize" version = "0.3.0" @@ -544,12 +747,25 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" dependencies = [ - "ark-serialize-derive", + "ark-serialize-derive 0.4.2", "ark-std 0.4.0", "digest 0.10.7", "num-bigint 0.4.6", ] +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-serialize-derive 0.5.0", + "ark-std 0.5.0", + "arrayvec", + "digest 0.10.7", + "num-bigint 0.4.6", +] + [[package]] name = "ark-serialize-derive" version = "0.4.2" @@ -561,6 +777,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "ark-std" version = "0.3.0" @@ -581,6 +808,16 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + [[package]] name = "arrayref" version = "0.3.9" @@ -604,24 +841,30 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.86" +version = "0.1.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "atomic" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d818003e740b63afc82337e3160717f4f63078720a810b7b903e70a5d1d2994" +checksum = "a89cbf775b137e9b968e67227ef7f775587cde3fd31b0d8599dbd0f598a48340" dependencies = [ "bytemuck", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "aurora-engine-modexp" version = "1.2.0" @@ -634,26 +877,26 @@ dependencies = [ [[package]] name = "auto_impl" -version = "1.2.1" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e12882f59de5360c748c4cbf569a042d5fb0eb515f7bea9c1f470b47f6ffbd73" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "autocfg" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-config" -version = "1.5.18" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90aff65e86db5fe300752551c1b015ef72b708ac54bded8ef43d0d53cb7cb0b1" +checksum = "c18d005c70d2b9c0c1ea8876c039db0ec7fb71164d25c73ccea21bf41fd02171" dependencies = [ "aws-credential-types", "aws-runtime", @@ -661,7 +904,7 @@ dependencies = [ "aws-sdk-ssooidc", "aws-sdk-sts", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -670,7 +913,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 0.2.12", + "http 1.3.1", "ring", "time", "tokio", @@ -681,9 +924,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.1" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" +checksum = "687bc16bc431a8533fe0097c7f0182874767f920989d7260950172ae8e3c4465" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -691,17 +934,40 @@ dependencies = [ "zeroize", ] +[[package]] +name = "aws-lc-rs" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08b5d4e069cbc868041a64bd68dc8cb39a0d79585cd6c5a24caa8c2d622121be" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "aws-runtime" -version = "1.5.5" +version = "1.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76dd04d39cc12844c0994f2c9c5a6f5184c22e9188ec1ff723de41910a21dcad" +checksum = "4f6c68419d8ba16d9a7463671593c54f81ba58cab466e9b759418da606dcc2e2" dependencies = [ "aws-credential-types", "aws-sigv4", "aws-smithy-async", "aws-smithy-eventstream", - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -710,7 +976,6 @@ dependencies = [ "fastrand", "http 0.2.12", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -719,9 +984,9 @@ dependencies = [ [[package]] name = "aws-sdk-s3" -version = "1.78.0" +version = "1.96.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3038614b6cf7dd68d9a7b5b39563d04337eb3678d1d4173e356e927b0356158a" +checksum = "6e25d24de44b34dcdd5182ac4e4c6f07bcec2661c505acef94c0d293b65505fe" dependencies = [ "aws-credential-types", "aws-runtime", @@ -729,7 +994,7 @@ dependencies = [ "aws-smithy-async", "aws-smithy-checksums", "aws-smithy-eventstream", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -741,70 +1006,70 @@ dependencies = [ "hex", "hmac", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "lru", - "once_cell", "percent-encoding", "regex-lite", - "sha2", + "sha2 0.10.9", "tracing", "url", ] [[package]] name = "aws-sdk-sso" -version = "1.61.0" +version = "1.74.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e65ff295979977039a25f5a0bf067a64bc5e6aa38f3cef4037cf42516265553c" +checksum = "e0a69de9c1b9272da2872af60c7402683e7f45c06267735b4332deacb203239b" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.62.0" +version = "1.75.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91430a60f754f235688387b75ee798ef00cfd09709a582be2b7525ebb5306d4f" +checksum = "f0b161d836fac72bdd5ac1a4cd1cdc38ab888c7af26cfd95f661be4409505e63" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.62.0" +version = "1.76.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9276e139d39fff5a0b0c984fc2d30f970f9a202da67234f948fda02e5bea1dbe" +checksum = "cb1cd79a3412751a341a28e2cd0d6fa4345241976da427b075a0c0cd5409f886" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-query", "aws-smithy-runtime", @@ -812,21 +1077,21 @@ dependencies = [ "aws-smithy-types", "aws-smithy-xml", "aws-types", + "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.2.9" +version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051" +checksum = "ddfb9021f581b71870a17eac25b52335b82211cdc092e02b6876b2bcefa61666" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", @@ -835,12 +1100,11 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.2.0", - "once_cell", + "http 1.3.1", "p256 0.11.1", "percent-encoding", "ring", - "sha2", + "sha2 0.10.9", "subtle", "time", "tracing", @@ -849,9 +1113,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa59d1327d8b5053c54bf2eaae63bf629ba9e904434d0835a28ed3c0ed0a614e" +checksum = "1e190749ea56f8c42bf15dd76c65e14f8f765233e6df9b0506d9d934ebef867c" dependencies = [ "futures-util", "pin-project-lite", @@ -860,31 +1124,29 @@ dependencies = [ [[package]] name = "aws-smithy-checksums" -version = "0.63.0" +version = "0.63.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2dc8d842d872529355c72632de49ef8c5a2949a4472f10e802f28cf925770c" +checksum = "244f00666380d35c1c76b90f7b88a11935d11b84076ac22a4c014ea0939627af" dependencies = [ - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-types", "bytes", - "crc32c", - "crc32fast", - "crc64fast-nvme", + "crc-fast", "hex", "http 0.2.12", "http-body 0.4.6", "md-5", "pin-project-lite", "sha1", - "sha2", + "sha2 0.10.9", "tracing", ] [[package]] name = "aws-smithy-eventstream" -version = "0.60.7" +version = "0.60.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "461e5e02f9864cba17cff30f007c2e37ade94d01e87cdb5204e44a84e6d38c17" +checksum = "338a3642c399c0a5d157648426110e199ca7fd1c689cc395676b81aa563700c4" dependencies = [ "aws-smithy-types", "bytes", @@ -893,18 +1155,19 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.12" +version = "0.62.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7809c27ad8da6a6a68c454e651d4962479e81472aa19ae99e59f9aba1f9713cc" +checksum = "99335bec6cdc50a346fda1437f9fefe33abf8c99060739a546a16457f2862ca9" dependencies = [ + "aws-smithy-eventstream", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "bytes-utils", "futures-core", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "pin-utils", @@ -912,35 +1175,52 @@ dependencies = [ ] [[package]] -name = "aws-smithy-http" -version = "0.61.1" +name = "aws-smithy-http-client" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6f276f21c7921fe902826618d1423ae5bf74cf8c1b8472aee8434f3dfd31824" +checksum = "f108f1ca850f3feef3009bdcc977be201bca9a91058864d9de0684e64514bee0" dependencies = [ - "aws-smithy-eventstream", + "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", - "bytes", - "bytes-utils", - "futures-core", + "h2 0.3.27", + "h2 0.4.11", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", - "once_cell", - "percent-encoding", + "hyper 0.14.32", + "hyper 1.6.0", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.7", + "hyper-util", "pin-project-lite", - "pin-utils", + "rustls 0.21.12", + "rustls 0.23.29", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tower", "tracing", ] [[package]] name = "aws-smithy-json" -version = "0.61.2" +version = "0.61.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "623a51127f24c30776c8b374295f2df78d92517386f77ba30773f15a30ce1422" +checksum = "a16e040799d29c17412943bdbf488fd75db04112d0c0d4b9290bacf5ae0014b9" dependencies = [ "aws-smithy-types", ] +[[package]] +name = "aws-smithy-observability" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" +dependencies = [ + "aws-smithy-runtime-api", +] + [[package]] name = "aws-smithy-query" version = "0.60.7" @@ -953,42 +1233,39 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.8" +version = "1.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d526a12d9ed61fadefda24abe2e682892ba288c2018bcb38b1b4c111d13f6d92" +checksum = "c3aaec682eb189e43c8a19c3dab2fe54590ad5f2cc2d26ab27608a20f2acf81c" dependencies = [ "aws-smithy-async", - "aws-smithy-http 0.60.12", + "aws-smithy-http", + "aws-smithy-http-client", + "aws-smithy-observability", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "fastrand", - "h2", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", - "httparse", - "hyper", - "hyper-rustls", - "once_cell", "pin-project-lite", "pin-utils", - "rustls", "tokio", "tracing", ] [[package]] name = "aws-smithy-runtime-api" -version = "1.7.3" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" +checksum = "9852b9226cb60b78ce9369022c0df678af1cac231c882d5da97a0c4e03be6e67" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -997,16 +1274,16 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.13" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7b8a53819e42f10d0821f56da995e1470b199686a1809168db6ca485665f042" +checksum = "d498595448e43de7f4296b7b7a18a8a02c61ec9349128c80a368f7c3b4ab11a8" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -1023,18 +1300,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.9" +version = "0.60.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" +checksum = "3db87b96cb1b16c024980f133968d52882ca0daaee3a086c6decc500f6c99728" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.5" +version = "1.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbd0a668309ec1f66c0f6bda4840dd6d4796ae26d699ebc266d7cc95c6d040f" +checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -1046,9 +1323,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" dependencies = [ "addr2line", "cfg-if", @@ -1102,9 +1379,9 @@ dependencies = [ [[package]] name = "base64ct" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" [[package]] name = "bincode" @@ -1115,6 +1392,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.1", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.104", + "which", +] + [[package]] name = "bit-set" version = "0.5.3" @@ -1147,9 +1447,9 @@ checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" [[package]] name = "bitcode" -version = "0.6.5" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18c1406a27371b2f76232a2259df6ab607b91b5a0a7476a7729ff590df5a969a" +checksum = "cf300f4aa6e66f3bdff11f1236a88c622fe47ea814524792240b4d554d9858ee" dependencies = [ "arrayvec", "bitcode_derive", @@ -1166,7 +1466,23 @@ checksum = "42b6b4cb608b8282dc3b53d0f4c9ab404655d562674c682db7e6c0458cc83c23" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "bitcoin-io" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b47c4ab7a93edb0c7198c5535ed9b52b63095f4e9b45279c6736cec4b856baf" + +[[package]] +name = "bitcoin_hashes" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb18c03d0db0247e147a21a6faafd5a7eb851c743db062de72018b6b7e8e4d16" +dependencies = [ + "bitcoin-io", + "hex-conservative", ] [[package]] @@ -1177,9 +1493,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bitvec" @@ -1215,16 +1531,24 @@ dependencies = [ [[package]] name = "blake3" -version = "1.6.0" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1230237285e3e10cde447185e8975408ae24deaa67205ce684805c25bc0c7937" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" dependencies = [ "arrayref", "arrayvec", "cc", "cfg-if", "constant_time_eq 0.3.1", - "memmap2", +] + +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "generic-array", ] [[package]] @@ -1251,9 +1575,9 @@ dependencies = [ [[package]] name = "blst" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47c79a94619fade3c0b887670333513a67ac28a6a7e653eb260bf0d4103db38d" +checksum = "4fd49896f12ac9b6dcd7a5998466b9b58263a695a3dd1ecc1aaca2e12a90b080" dependencies = [ "cc", "glob", @@ -1267,7 +1591,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34e20109dce74b02019885a01edc8ca485380a297ed8d6eb9e63e657774074b" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "js-sys", "primitive-types", "rustc-hex", @@ -1278,9 +1602,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.3.2" +version = "3.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7acc34ff59877422326db7d6f2d845a582b16396b6b08194942bf34c6528ab" +checksum = "f61138465baf186c63e8d9b6b613b508cd832cba4ce93cf37ce5f096f91ac1a6" dependencies = [ "bon-macros", "rustversion", @@ -1288,9 +1612,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.3.2" +version = "3.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4159dd617a7fbc9be6a692fe69dc2954f8e6bb6bb5e4d7578467441390d77fd0" +checksum = "40d1dad34aa19bf02295382f08d9bc40651585bd497266831d40ee6296fb49ca" dependencies = [ "darling", "ident_case", @@ -1298,7 +1622,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -1321,7 +1645,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -1342,21 +1666,21 @@ checksum = "b4ae4235e6dac0694637c763029ecea1a2ec9e4e06ec2729bd21ba4d9c863eb7" [[package]] name = "bumpalo" -version = "3.17.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "byte-slice-cast" -version = "1.2.2" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" +checksum = "7575182f7272186991736b70173b0ea045398f984bf5ebbb3804736ce1330c9d" [[package]] name = "bytemuck" -version = "1.21.0" +version = "1.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" +checksum = "5c76a5792e44e4abe34d3abf15636779261d45a7450612059293d1d2cfc63422" [[package]] name = "byteorder" @@ -1366,9 +1690,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" dependencies = [ "serde", ] @@ -1405,9 +1729,9 @@ dependencies = [ [[package]] name = "c-kzg" -version = "1.0.3" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0307f72feab3300336fb803a57134159f6e20139af1357f36c54cb90d8e8928" +checksum = "7318cfa722931cb5fe0838b98d3ce5621e75f6a6408abc21721d80de9223f2e4" dependencies = [ "blst", "cc", @@ -1420,16 +1744,16 @@ dependencies = [ [[package]] name = "camino" -version = "1.1.9" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" +checksum = "0da45bc31171d8d6960122e222a67740df867c1dd53b4d51caa297084c185cab" dependencies = [ "serde", ] [[package]] name = "cargo-openvm" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "aws-config", "aws-sdk-s3", @@ -1450,8 +1774,8 @@ dependencies = [ "target-lexicon 0.12.16", "tempfile", "tokio", - "toml 0.8.20", - "toml_edit 0.22.24", + "toml 0.8.23", + "toml_edit 0.22.27", "tracing", "vergen", ] @@ -1473,7 +1797,7 @@ checksum = "2d886547e41f740c616ae73108f6eb70afe6d940c7bc697cb30f13daec073037" dependencies = [ "camino", "cargo-platform", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "thiserror 1.0.69", @@ -1487,20 +1811,29 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" dependencies = [ "jobserver", "libc", "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" [[package]] name = "cfg_aliases" @@ -1510,15 +1843,15 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.39" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.6", + "windows-link", ] [[package]] @@ -1558,11 +1891,22 @@ dependencies = [ "inout", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" -version = "4.5.30" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" +checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" dependencies = [ "clap_builder", "clap_derive", @@ -1570,39 +1914,117 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.30" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" +checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" dependencies = [ "anstream", "anstyle", "clap_lex", "strsim", + "terminal_size", ] [[package]] name = "clap_derive" -version = "4.5.28" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" +checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "clap_lex" -version = "0.7.4" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + +[[package]] +name = "codspeed" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7524e02ff6173bc143d9abc01b518711b77addb60de871bbe5686843f88fb48" +dependencies = [ + "anyhow", + "bincode", + "colored", + "glob", + "libc", + "nix", + "serde", + "serde_json", + "statrs", + "uuid", +] + +[[package]] +name = "codspeed-divan-compat" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "157f6307b7400d74f3e41bd429b751b53d05c138a6a0f35853055e2523440354" +dependencies = [ + "codspeed", + "codspeed-divan-compat-macros", + "codspeed-divan-compat-walltime", +] + +[[package]] +name = "codspeed-divan-compat-macros" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5e422ac666f5871ab86d17b0f7292696ef194138bab5b49f743d23799cd6c04" +dependencies = [ + "divan-macros", + "itertools 0.14.0", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "codspeed-divan-compat-walltime" +version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +checksum = "66715e496e52fe861695e2644577adc7573544a729585fba4737193a62fd5a8a" +dependencies = [ + "cfg-if", + "clap", + "codspeed", + "condtype", + "divan-macros", + "libc", + "regex-lite", +] [[package]] name = "colorchoice" -version = "1.0.3" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "colored" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +dependencies = [ + "lazy_static", + "windows-sys 0.59.0", +] [[package]] name = "concurrent-queue" @@ -1613,6 +2035,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "condtype" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" + [[package]] name = "const-default" version = "1.0.0" @@ -1621,9 +2049,9 @@ checksum = "0b396d1f76d455557e1218ec8066ae14bba60b4b36ecd55577ba979f5db7ecaa" [[package]] name = "const-hex" -version = "1.14.0" +version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b0485bab839b018a8f1723fc5391819fea5f8f0f32288ef8a735fd096b6160c" +checksum = "83e22e0ed40b96a48d3db274f72fd365bd78f67af39b6bbd47e8a15e1c6207ff" dependencies = [ "cfg-if", "cpufeatures", @@ -1686,6 +2114,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1703,9 +2141,9 @@ dependencies = [ [[package]] name = "crc" -version = "3.2.1" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +checksum = "9710d3b3739c2e349eb44fe848ad0b7c8cb1e42bd87ee49371df2f7acaf3e675" dependencies = [ "crc-catalog", ] @@ -1717,12 +2155,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] -name = "crc32c" -version = "0.6.8" +name = "crc-fast" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a47af21622d091a8f0fb295b88bc886ac74efcc613efc19f5d0b21de5c89e47" +checksum = "6bf62af4cc77d8fe1c22dde4e721d87f2f54056139d8c412e1366b740305f56f" dependencies = [ - "rustc_version 0.4.1", + "crc", + "digest 0.10.7", + "libc", + "rand 0.9.1", + "regex", ] [[package]] @@ -1734,15 +2176,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crc64fast-nvme" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4955638f00a809894c947f85a024020a20815b65a5eea633798ea7924edab2b3" -dependencies = [ - "crc", -] - [[package]] name = "criterion" version = "0.5.1" @@ -1843,9 +2276,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-bigint" @@ -1883,9 +2316,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ "darling_core", "darling_macro", @@ -1893,27 +2326,42 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "darling_macro" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", + "rayon", ] [[package]] @@ -1928,9 +2376,9 @@ dependencies = [ [[package]] name = "der" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" dependencies = [ "const-oid", "pem-rfc7468", @@ -1939,9 +2387,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.11" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", "serde", @@ -1966,7 +2414,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -1977,20 +2425,31 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "derive-where" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "510c292c8cf384b1a340b816a9a6cf2599eb8f566a44949024af88418000c50b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", ] [[package]] name = "derive_more" -version = "0.99.19" +version = "0.99.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da29a38df43d6f156149c9b43ded5e018ddff2a855cf2cfd62e8cd7d079c69f" +checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ "convert_case", "proc-macro2", "quote", "rustc_version 0.4.1", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2019,7 +2478,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "unicode-xid", ] @@ -2031,30 +2490,30 @@ checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "unicode-xid", ] [[package]] name = "diesel" -version = "2.2.10" +version = "2.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff3e1edb1f37b4953dd5176916347289ed43d7119cc2e6c7c3f7849ff44ea506" +checksum = "229850a212cd9b84d4f0290ad9d294afc0ae70fccaa8949dbe8b43ffafa1e20c" dependencies = [ "diesel_derives", ] [[package]] name = "diesel_derives" -version = "2.2.5" +version = "2.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68d4216021b3ea446fd2047f5c8f8fe6e98af34508a254a01e4d6bc1e844f84d" +checksum = "1b96984c469425cb577bf6f17121ecb3e4fe1e81de5d8f780dd372802858d756" dependencies = [ "diesel_table_macro_syntax", "dsl_auto_type", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2063,7 +2522,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25" dependencies = [ - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2081,7 +2540,7 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", + "block-buffer 0.10.4", "const-oid", "crypto-common", "subtle", @@ -2137,7 +2596,18 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "divan-macros" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dc51d98e636f5e3b0759a39257458b22619cac7e96d932da6eeb052891bb67c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", ] [[package]] @@ -2157,7 +2627,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2168,9 +2638,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "dyn-clone" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" [[package]] name = "ecdsa" @@ -2190,7 +2660,7 @@ version = "0.16.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" dependencies = [ - "der 0.7.9", + "der 0.7.10", "digest 0.10.7", "elliptic-curve 0.13.8", "rfc6979 0.4.0", @@ -2199,11 +2669,23 @@ dependencies = [ "spki 0.7.3", ] +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "either" -version = "1.13.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "elf" @@ -2292,6 +2774,26 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "enum_dispatch" version = "0.3.13" @@ -2301,7 +2803,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2312,7 +2814,7 @@ checksum = "2f9ed6b3789237c8a0c1c505af1c7eb2c560df6186f01b098c3a1064ea532f38" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2326,9 +2828,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.6" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", @@ -2344,12 +2846,12 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -2447,7 +2949,7 @@ dependencies = [ "chrono", "ethers-core", "reqwest", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "thiserror 1.0.69", @@ -2474,10 +2976,10 @@ dependencies = [ "path-slash", "rayon", "regex", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", - "sha2", + "sha2 0.10.9", "solang-parser", "svm-rs", "svm-rs-builds", @@ -2592,7 +3094,7 @@ dependencies = [ "atomic", "pear", "serde", - "toml 0.8.20", + "toml 0.8.23", "uncased", "version_check", ] @@ -2617,9 +3119,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.1.0" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d" dependencies = [ "crc32fast", "miniz_oxide", @@ -2633,9 +3135,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "forge-fmt" @@ -2682,7 +3184,7 @@ dependencies = [ "regex", "reqwest", "revm-primitives 1.3.0", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "serde_regex", @@ -2703,6 +3205,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "funty" version = "2.0.0" @@ -2750,7 +3258,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2801,39 +3309,39 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets 0.52.6", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", ] [[package]] name = "getset" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded738faa0e88d3abc9d1a13cb11adc2073c400969eeb8793cf7132589959fc" +checksum = "9cf0fc11e47561d47397154977bc219f4cf809b2974facc3ccb3b89e2436f912" dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -2848,7 +3356,7 @@ version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b903b73e45dc0c6c596f2d37eccece7c1c8bb6e4407b001096387c63d0d93724" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "libc", "libgit2-sys", "log", @@ -2857,9 +3365,9 @@ dependencies = [ [[package]] name = "glam" -version = "0.30.0" +version = "0.30.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17fcdf9683c406c2fc4d124afd29c0d595e22210d633cbdb8695ba9935ab1dc6" +checksum = "50a99dbe56b72736564cfa4b85bf9a33079f16ae8b74983ab06af3b1a3696b11" [[package]] name = "glob" @@ -2905,9 +3413,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" dependencies = [ "bytes", "fnv", @@ -2915,7 +3423,26 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.7.1", + "indexmap 2.10.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17da50a276f1e01e0ba6c029e47b7100754904ee8a278f886546e98575380785" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.3.1", + "indexmap 2.10.0", "slab", "tokio", "tokio-util", @@ -2924,9 +3451,9 @@ dependencies = [ [[package]] name = "half" -version = "2.4.1" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "cfg-if", "crunchy", @@ -2943,9 +3470,9 @@ dependencies = [ [[package]] name = "halo2-axiom" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f0ca78d12ac5c893f286d7cdfe3869290305ab8cac376e2592cdc8396da102" +checksum = "0aee3f8178b78275038e5ea0e2577140056d2c4c87fccaf6777dc0a8eebe455a" dependencies = [ "blake2b_simd", "crossbeam", @@ -3041,7 +3568,7 @@ dependencies = [ "rayon", "serde", "serde_arrays", - "sha2", + "sha2 0.10.9", "static_assertions", "subtle", "unroll", @@ -3069,7 +3596,7 @@ dependencies = [ "rayon", "serde", "serde_arrays", - "sha2", + "sha2 0.10.9", "static_assertions", "subtle", "unroll", @@ -3116,9 +3643,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" dependencies = [ "allocator-api2", "equivalent", @@ -3132,7 +3659,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.4", ] [[package]] @@ -3143,15 +3670,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - -[[package]] -name = "hermit-abi" -version = "0.4.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" [[package]] name = "hex" @@ -3162,6 +3683,15 @@ dependencies = [ "serde", ] +[[package]] +name = "hex-conservative" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5313b072ce3c597065a808dbf612c4c8e8590bdbf8b579508bf7a762c5eae6cd" +dependencies = [ + "arrayvec", +] + [[package]] name = "hex-literal" version = "0.4.1" @@ -3214,9 +3744,9 @@ dependencies = [ [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -3241,27 +3771,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http 1.3.1", ] [[package]] name = "http-body-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", - "futures-util", - "http 1.2.0", + "futures-core", + "http 1.3.1", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -3279,7 +3809,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.27", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -3293,6 +3823,26 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.11", + "http 1.3.1", + "http-body 1.0.1", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + [[package]] name = "hyper-rustls" version = "0.24.2" @@ -3301,24 +3851,63 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper", + "hyper 0.14.32", "log", - "rustls", - "rustls-native-certs", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", + "tokio", + "tokio-rustls 0.24.1", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http 1.3.1", + "hyper 1.6.0", + "hyper-util", + "rustls 0.23.29", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.2", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.6.0", + "libc", + "pin-project-lite", + "socket2", "tokio", - "tokio-rustls", + "tower-service", + "tracing", ] [[package]] name = "iana-time-zone" -version = "0.1.61" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", "windows-core", ] @@ -3334,21 +3923,22 @@ dependencies = [ [[package]] name = "icu_collections" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" dependencies = [ "displaydoc", + "potential_utf", "yoke", "zerofrom", "zerovec", ] [[package]] -name = "icu_locid" -version = "1.5.0" +name = "icu_locale_core" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" dependencies = [ "displaydoc", "litemap", @@ -3357,31 +3947,11 @@ dependencies = [ "zerovec", ] -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" dependencies = [ "displaydoc", "icu_collections", @@ -3389,67 +3959,54 @@ dependencies = [ "icu_properties", "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" dependencies = [ "displaydoc", "icu_collections", - "icu_locid_transform", + "icu_locale_core", "icu_properties_data", "icu_provider", - "tinystr", + "potential_utf", + "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" [[package]] name = "icu_provider" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" dependencies = [ "displaydoc", - "icu_locid", - "icu_provider_macros", + "icu_locale_core", "stable_deref_trait", "tinystr", "writeable", "yoke", "zerofrom", + "zerotrie", "zerovec", ] -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.98", -] - [[package]] name = "ident_case" version = "1.0.1" @@ -3469,9 +4026,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -3512,7 +4069,7 @@ checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -3553,12 +4110,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.1" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.15.4", "serde", ] @@ -3577,6 +4134,17 @@ dependencies = [ "generic-array", ] +[[package]] +name = "io-uring" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "libc", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -3585,11 +4153,11 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is-terminal" -version = "0.4.15" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ - "hermit-abi 0.4.0", + "hermit-abi", "libc", "windows-sys 0.59.0", ] @@ -3618,6 +4186,24 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -3629,16 +4215,17 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.3", "libc", ] @@ -3690,12 +4277,13 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", + "rand 0.8.5", "serde", "signature 2.2.0", ] @@ -3710,7 +4298,7 @@ dependencies = [ "ecdsa 0.16.9", "elliptic-curve 0.13.8", "once_cell", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -3771,11 +4359,17 @@ dependencies = [ "spin", ] +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" -version = "0.2.169" +version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] name = "libgit2-sys" @@ -3789,17 +4383,27 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets 0.53.2", +] + [[package]] name = "libm" -version = "0.2.11" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libmimalloc-sys" -version = "0.1.39" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44" +checksum = "bf88cd67e9de251c1781dbe2f641a1a3ad66eaae831b8a2c38fbdc5ddae16d4d" dependencies = [ "cc", "libc", @@ -3807,19 +4411,65 @@ dependencies = [ [[package]] name = "libredox" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +checksum = "1580801010e535496706ba011c15f8532df6b42297d2e471fec38ceadd8c0638" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "libc", ] +[[package]] +name = "libsecp256k1" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e79019718125edc905a079a70cfa5f3820bc76139fc91d6f9abc27ea2a887139" +dependencies = [ + "arrayref", + "base64 0.22.1", + "digest 0.9.0", + "libsecp256k1-core", + "libsecp256k1-gen-ecmult", + "libsecp256k1-gen-genmult", + "rand 0.8.5", + "serde", + "sha2 0.9.9", +] + +[[package]] +name = "libsecp256k1-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5be9b9bb642d8522a44d533eab56c16c738301965504753b03ad1de3425d5451" +dependencies = [ + "crunchy", + "digest 0.9.0", + "subtle", +] + +[[package]] +name = "libsecp256k1-gen-ecmult" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3038c808c55c87e8a172643a7d87187fc6c4174468159cb3090659d55bcb4809" +dependencies = [ + "libsecp256k1-core", +] + +[[package]] +name = "libsecp256k1-gen-genmult" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3db8d6ba2cec9eacc40e6e8ccc98931840301f1006e95647ceb2dd5c3aa06f7c" +dependencies = [ + "libsecp256k1-core", +] + [[package]] name = "libz-sys" -version = "1.1.21" +version = "1.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df9b68e50e6e0b26f672573834882eb57759f6db9b3be2ea3c35c91188bb4eaa" +checksum = "8b70e7a7df205e92a1a4cd9aaae7898dac0aa555503cc0a649494d0d60e7651d" dependencies = [ "cc", "libc", @@ -3839,17 +4489,23 @@ version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "litemap" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -3863,9 +4519,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" [[package]] name = "log" -version = "0.4.25" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "lru" @@ -3873,7 +4529,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.4", ] [[package]] @@ -3884,7 +4540,7 @@ checksum = "1b27834086c65ec3f9387b096d66e99f221cf081c2b738042aa252bcd41204e3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -3896,6 +4552,16 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -3918,9 +4584,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.4" +version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" [[package]] name = "memmap2" @@ -3948,9 +4614,9 @@ checksum = "3d97bbf43eb4f088f8ca469930cde17fa036207c9a5e02ccc5107c4e8b17c964" [[package]] name = "metrics" -version = "0.23.0" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "884adb57038347dfbaf2d5065887b6cf4312330dc8e94bc30a1a839bd79d3261" +checksum = "3045b4193fbdc5b5681f32f11070da9be3609f189a79f3390706d42587f46bb5" dependencies = [ "ahash", "portable-atomic", @@ -3962,7 +4628,7 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62a6a1f7141f1d9bc7a886b87536bbfc97752e08b369e1e0453a9acfab5f5da4" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "itoa", "lockfree-object-pool", "metrics", @@ -3970,7 +4636,7 @@ dependencies = [ "once_cell", "tracing", "tracing-core", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -3983,7 +4649,7 @@ dependencies = [ "crossbeam-epoch", "crossbeam-utils", "hashbrown 0.14.5", - "indexmap 2.7.1", + "indexmap 2.10.0", "metrics", "num_cpus", "ordered-float", @@ -3994,9 +4660,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.43" +version = "0.1.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633" +checksum = "b1791cbe101e95af5764f06f20f6760521f7158f69dbf9d6baf941ee1bf6bc40" dependencies = [ "libmimalloc-sys", ] @@ -4007,24 +4673,45 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" -version = "0.8.4" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", ] [[package]] name = "mio" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.52.0", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.59.0", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", ] [[package]] @@ -4042,6 +4729,28 @@ dependencies = [ "smallvec", ] +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -4184,33 +4893,34 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" dependencies = [ - "hermit-abi 0.3.9", + "hermit-abi", "libc", ] [[package]] name = "num_enum" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +checksum = "a973b4e44ce6cad84ce69d797acf9a044532e4184c4f267913d1b546a0727b7a" dependencies = [ "num_enum_derive", + "rustversion", ] [[package]] name = "num_enum_derive" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +checksum = "77e878c846a8abae00dd069496dbe8751b16ac1c3d6bd2a7283a938e8228f90d" dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -4251,19 +4961,31 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.3" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" dependencies = [ "critical-section", "portable-atomic", ] +[[package]] +name = "once_cell_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" + [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "opaque-debug" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] name = "open-fastrlp" @@ -4298,12 +5020,12 @@ checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openvm" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "bytemuck", "chrono", - "getrandom 0.2.15", - "getrandom 0.3.1", + "getrandom 0.2.16", + "getrandom 0.3.3", "num-bigint 0.4.6", "openvm-custom-insn", "openvm-platform", @@ -4313,7 +5035,7 @@ dependencies = [ [[package]] name = "openvm-algebra-circuit" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -4336,23 +5058,23 @@ dependencies = [ "openvm-stark-sdk", "rand 0.8.5", "serde", - "serde-big-array", "serde_with", "strum", + "test-case", ] [[package]] name = "openvm-algebra-complex-macros" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-algebra-guest" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "halo2curves-axiom", "num-bigint 0.4.6", @@ -4367,18 +5089,18 @@ dependencies = [ [[package]] name = "openvm-algebra-moduli-macros" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "num-bigint 0.4.6", "num-prime", "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-algebra-tests" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "eyre", "num-bigint 0.4.6", @@ -4395,7 +5117,7 @@ dependencies = [ [[package]] name = "openvm-algebra-transpiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-algebra-guest", "openvm-instructions", @@ -4408,29 +5130,37 @@ dependencies = [ [[package]] name = "openvm-benchmarks-execute" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ - "cargo-openvm", - "clap", - "criterion", + "codspeed-divan-compat", "derive_more 1.0.0", "eyre", + "openvm-algebra-circuit", + "openvm-algebra-transpiler", "openvm-benchmarks-utils", + "openvm-bigint-circuit", + "openvm-bigint-transpiler", "openvm-circuit", + "openvm-ecc-circuit", + "openvm-ecc-transpiler", "openvm-keccak256-circuit", "openvm-keccak256-transpiler", + "openvm-pairing-circuit", + "openvm-pairing-guest", + "openvm-pairing-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sdk", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-transpiler", - "tracing", - "tracing-subscriber", + "rand 0.8.5", + "serde", ] [[package]] name = "openvm-benchmarks-prove" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "clap", "derive-new 0.6.0", @@ -4442,6 +5172,7 @@ dependencies = [ "openvm-algebra-transpiler", "openvm-benchmarks-utils", "openvm-circuit", + "openvm-continuations", "openvm-ecc-circuit", "openvm-ecc-transpiler", "openvm-keccak256-circuit", @@ -4457,6 +5188,7 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", + "rand 0.8.5", "rand_chacha 0.3.1", "serde", "tiny-keccak", @@ -4466,7 +5198,7 @@ dependencies = [ [[package]] name = "openvm-benchmarks-utils" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "cargo_metadata", "clap", @@ -4475,13 +5207,14 @@ dependencies = [ "openvm-transpiler", "tempfile", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] name = "openvm-bigint-circuit" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ + "alloy-primitives 1.2.1", "derive-new 0.6.0", "derive_more 1.0.0", "openvm-bigint-transpiler", @@ -4497,11 +5230,12 @@ dependencies = [ "openvm-stark-sdk", "rand 0.8.5", "serde", + "test-case", ] [[package]] name = "openvm-bigint-guest" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-platform", "strum_macros", @@ -4509,7 +5243,7 @@ dependencies = [ [[package]] name = "openvm-bigint-transpiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-bigint-guest", "openvm-instructions", @@ -4523,7 +5257,7 @@ dependencies = [ [[package]] name = "openvm-build" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "cargo_metadata", "eyre", @@ -4534,10 +5268,11 @@ dependencies = [ [[package]] name = "openvm-circuit" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "backtrace", "cfg-if", + "dashmap", "derivative", "derive-new 0.6.0", "derive_more 1.0.0", @@ -4545,6 +5280,7 @@ dependencies = [ "eyre", "getset", "itertools 0.14.0", + "memmap2", "metrics", "openvm-circuit", "openvm-circuit-derive", @@ -4570,16 +5306,16 @@ dependencies = [ [[package]] name = "openvm-circuit-derive" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "itertools 0.14.0", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-circuit-primitives" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "itertools 0.14.0", @@ -4595,16 +5331,18 @@ dependencies = [ [[package]] name = "openvm-circuit-primitives-derive" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "itertools 0.14.0", + "ndarray", + "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-continuations" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derivative", "openvm-circuit", @@ -4622,12 +5360,12 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-ecc-circuit" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -4648,6 +5386,7 @@ dependencies = [ "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", + "rand 0.8.5", "serde", "serde_with", "strum", @@ -4655,7 +5394,7 @@ dependencies = [ [[package]] name = "openvm-ecc-guest" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "ecdsa 0.16.9", "elliptic-curve 0.13.8", @@ -4673,7 +5412,7 @@ dependencies = [ [[package]] name = "openvm-ecc-integration-tests" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "eyre", "halo2curves-axiom", @@ -4690,21 +5429,21 @@ dependencies = [ "openvm-transpiler", "serde", "serde_with", - "toml 0.8.20", + "toml 0.8.23", ] [[package]] name = "openvm-ecc-sw-macros" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-ecc-transpiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-ecc-guest", "openvm-instructions", @@ -4717,7 +5456,7 @@ dependencies = [ [[package]] name = "openvm-ff-derive" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "addchain", "eyre", @@ -4740,7 +5479,7 @@ dependencies = [ [[package]] name = "openvm-instructions" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "backtrace", "bitcode", @@ -4760,18 +5499,18 @@ dependencies = [ [[package]] name = "openvm-instructions-derive" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-instructions", "quote", "strum", "strum_macros", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-keccak256" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "eyre", "openvm-circuit", @@ -4788,7 +5527,7 @@ dependencies = [ [[package]] name = "openvm-keccak256-circuit" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -4806,22 +5545,20 @@ dependencies = [ "p3-keccak-air", "rand 0.8.5", "serde", - "serde-big-array", "strum", "tiny-keccak", - "tracing", ] [[package]] name = "openvm-keccak256-guest" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-platform", ] [[package]] name = "openvm-keccak256-transpiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -4834,14 +5571,14 @@ dependencies = [ [[package]] name = "openvm-macros-common" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-mod-circuit-builder" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "halo2curves-axiom", "itertools 0.14.0", @@ -4854,14 +5591,12 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", - "serde", - "serde_with", "tracing", ] [[package]] name = "openvm-native-circuit" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -4875,19 +5610,19 @@ dependencies = [ "openvm-native-compiler", "openvm-poseidon2-air", "openvm-rv32im-circuit", + "openvm-rv32im-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", "serde", - "serde-big-array", "static_assertions", "strum", - "tracing", + "test-case", ] [[package]] name = "openvm-native-compiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "backtrace", "itertools 0.14.0", @@ -4913,15 +5648,15 @@ dependencies = [ [[package]] name = "openvm-native-compiler-derive" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "openvm-native-recursion" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "bitcode", "cfg-if", @@ -4951,7 +5686,7 @@ dependencies = [ [[package]] name = "openvm-native-transpiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-instructions", "openvm-transpiler", @@ -4960,7 +5695,7 @@ dependencies = [ [[package]] name = "openvm-pairing" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "eyre", "group 0.13.0", @@ -4997,13 +5732,12 @@ dependencies = [ [[package]] name = "openvm-pairing-circuit" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", "eyre", "halo2curves-axiom", - "itertools 0.14.0", "num-bigint 0.4.6", "num-traits", "openvm-algebra-circuit", @@ -5017,7 +5751,6 @@ dependencies = [ "openvm-mod-circuit-builder", "openvm-pairing-guest", "openvm-pairing-transpiler", - "openvm-rv32-adapters", "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", @@ -5028,7 +5761,7 @@ dependencies = [ [[package]] name = "openvm-pairing-guest" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "halo2curves-axiom", "hex-literal 0.4.1", @@ -5049,10 +5782,9 @@ dependencies = [ [[package]] name = "openvm-pairing-transpiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-instructions", - "openvm-instructions-derive", "openvm-pairing-guest", "openvm-stark-backend", "openvm-transpiler", @@ -5062,7 +5794,7 @@ dependencies = [ [[package]] name = "openvm-platform" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "critical-section", "embedded-alloc", @@ -5073,7 +5805,7 @@ dependencies = [ [[package]] name = "openvm-poseidon2-air" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derivative", "lazy_static", @@ -5089,7 +5821,7 @@ dependencies = [ [[package]] name = "openvm-prof" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "clap", "eyre", @@ -5102,7 +5834,7 @@ dependencies = [ [[package]] name = "openvm-rv32-adapters" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "itertools 0.14.0", @@ -5114,14 +5846,11 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", - "serde", - "serde-big-array", - "serde_with", ] [[package]] name = "openvm-rv32im-circuit" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -5138,13 +5867,13 @@ dependencies = [ "openvm-stark-sdk", "rand 0.8.5", "serde", - "serde-big-array", "strum", + "test-case", ] [[package]] name = "openvm-rv32im-guest" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-custom-insn", "p3-field", @@ -5153,7 +5882,7 @@ dependencies = [ [[package]] name = "openvm-rv32im-integration-tests" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "eyre", "openvm", @@ -5171,7 +5900,7 @@ dependencies = [ [[package]] name = "openvm-rv32im-transpiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -5186,7 +5915,7 @@ dependencies = [ [[package]] name = "openvm-sdk" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "alloy-sol-types", "async-trait", @@ -5222,12 +5951,13 @@ dependencies = [ "openvm-pairing-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", "p3-fri", + "rand 0.8.5", "rrs-lib", "serde", "serde_json", @@ -5241,69 +5971,72 @@ dependencies = [ [[package]] name = "openvm-sha2" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "eyre", "openvm-circuit", "openvm-instructions", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-guest", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-guest", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", - "sha2", + "sha2 0.10.9", ] [[package]] -name = "openvm-sha256-air" -version = "1.3.0-rc.1" +name = "openvm-sha2-air" +version = "1.4.0-rc.0" dependencies = [ + "ndarray", + "num_enum", "openvm-circuit", "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", - "sha2", + "sha2 0.10.9", ] [[package]] -name = "openvm-sha256-circuit" -version = "1.3.0-rc.1" +name = "openvm-sha2-circuit" +version = "1.4.0-rc.0" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", + "ndarray", "openvm-circuit", "openvm-circuit-derive", "openvm-circuit-primitives", "openvm-circuit-primitives-derive", "openvm-instructions", "openvm-rv32im-circuit", - "openvm-sha256-air", - "openvm-sha256-transpiler", + "openvm-sha2-air", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", "serde", - "sha2", - "strum", + "sha2 0.10.9", ] [[package]] -name = "openvm-sha256-guest" -version = "1.3.0-rc.1" +name = "openvm-sha2-guest" +version = "1.4.0-rc.0" dependencies = [ "openvm-platform", ] [[package]] -name = "openvm-sha256-transpiler" -version = "1.3.0-rc.1" +name = "openvm-sha2-transpiler" +version = "1.4.0-rc.0" dependencies = [ "openvm-instructions", "openvm-instructions-derive", - "openvm-sha256-guest", + "openvm-sha2-guest", "openvm-stark-backend", "openvm-transpiler", "rrs-lib", @@ -5312,8 +6045,8 @@ dependencies = [ [[package]] name = "openvm-stark-backend" -version = "1.1.1" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.1.1#0879de162658b797b8dd6b6ee4429cbb8dd78ba1" +version = "1.1.2" +source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.1.2#b0bec8739d249370f91862f99c2ecc2c03d33240" dependencies = [ "bitcode", "cfg-if", @@ -5340,11 +6073,12 @@ dependencies = [ [[package]] name = "openvm-stark-sdk" -version = "1.1.1" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.1.1#0879de162658b797b8dd6b6ee4429cbb8dd78ba1" +version = "1.1.2" +source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.1.2#b0bec8739d249370f91862f99c2ecc2c03d33240" dependencies = [ + "dashmap", "derivative", - "derive_more 0.99.19", + "derive_more 0.99.20", "ff 0.13.1", "itertools 0.14.0", "metrics", @@ -5367,16 +6101,16 @@ dependencies = [ "serde", "serde_json", "static_assertions", - "toml 0.8.20", + "toml 0.8.23", "tracing", "tracing-forest", - "tracing-subscriber", + "tracing-subscriber 0.3.19", "zkhash", ] [[package]] name = "openvm-toolchain-tests" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "derive_more 1.0.0", "eyre", @@ -5394,6 +6128,7 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", + "rand 0.8.5", "serde", "tempfile", "test-case", @@ -5401,7 +6136,7 @@ dependencies = [ [[package]] name = "openvm-transpiler" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "elf", "eyre", @@ -5415,7 +6150,7 @@ dependencies = [ [[package]] name = "openvm-verify-stark" -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" dependencies = [ "eyre", "openvm-circuit", @@ -5462,7 +6197,7 @@ checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" dependencies = [ "ecdsa 0.14.8", "elliptic-curve 0.12.3", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -5488,15 +6223,28 @@ dependencies = [ "openvm-ecc-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", + "rand 0.8.5", "serde", ] +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa 0.16.9", + "elliptic-curve 0.13.8", + "primeorder", + "sha2 0.10.9", +] + [[package]] name = "p3-air" version = "0.1.0" @@ -5858,9 +6606,9 @@ dependencies = [ [[package]] name = "parity-scale-codec" -version = "3.7.4" +version = "3.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9fde3d0718baf5bc92f577d652001da0f8d54cd03a7974e118d04fc888dc23d" +checksum = "799781ae679d79a948e13d4824a40970bfa500058d245760dd857301059810fa" dependencies = [ "arrayvec", "bitvec", @@ -5874,14 +6622,14 @@ dependencies = [ [[package]] name = "parity-scale-codec-derive" -version = "3.7.4" +version = "3.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581c837bb6b9541ce7faa9377c20616e4fb7650f6b0f68bc93c827ee504fb7b3" +checksum = "34b4653168b563151153c9e4c08ebed57fb8262bebfa79711552fa983c623e7a" dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -5892,9 +6640,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -5902,9 +6650,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -5975,7 +6723,7 @@ dependencies = [ "digest 0.10.7", "hmac", "password-hash", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -5998,7 +6746,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -6018,12 +6766,12 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" -version = "2.7.15" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" +checksum = "1db05f56d34358a8b1066f67cbb203ee3e7ed2ba674a6263a1d5ec6db2204323" dependencies = [ "memchr", - "thiserror 2.0.11", + "thiserror 2.0.12", "ucd-trie", ] @@ -6034,7 +6782,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.7.1", + "indexmap 2.10.0", ] [[package]] @@ -6067,7 +6815,7 @@ dependencies = [ "phf_shared", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -6107,15 +6855,15 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" dependencies = [ - "der 0.7.9", + "der 0.7.10", "spki 0.7.3", ] [[package]] name = "pkg-config" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "plotters" @@ -6147,9 +6895,18 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] [[package]] name = "poseidon-primitives" @@ -6162,7 +6919,7 @@ dependencies = [ "lazy_static", "log", "rand 0.8.5", - "rand_xorshift", + "rand_xorshift 0.3.0", "thiserror 1.0.69", ] @@ -6194,7 +6951,7 @@ dependencies = [ "md-5", "memchr", "rand 0.9.1", - "sha2", + "sha2 0.10.9", "stringprep", ] @@ -6209,6 +6966,15 @@ dependencies = [ "postgres-protocol", ] +[[package]] +name = "potential_utf" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -6217,9 +6983,9 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ "zerocopy", ] @@ -6232,12 +6998,21 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "prettyplease" -version = "0.2.29" +version = "0.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" dependencies = [ "proc-macro2", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve 0.13.8", ] [[package]] @@ -6256,11 +7031,11 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" dependencies = [ - "toml_edit 0.22.24", + "toml_edit 0.22.27", ] [[package]] @@ -6282,14 +7057,14 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "proc-macro2" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -6302,25 +7077,25 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "version_check", "yansi 1.0.1", ] [[package]] name = "proptest" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14cae93065090804185d3b75f0bf93b8eeda30c7a9b4a33d3bdb3988d6229e50" +checksum = "6fcdab19deb5195a31cf7726a210015ff1496ba1464fd42cb4f537b8b01b471f" dependencies = [ "bit-set 0.8.0", "bit-vec 0.8.0", - "bitflags 2.8.0", + "bitflags 2.9.1", "lazy_static", "num-traits", - "rand 0.8.5", - "rand_chacha 0.3.1", - "rand_xorshift", + "rand 0.9.1", + "rand_chacha 0.9.0", + "rand_xorshift 0.4.0", "regex-syntax 0.8.5", "rusty-fork", "tempfile", @@ -6329,9 +7104,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.25.0" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f239d656363bcee73afef85277f1b281e8ac6212a1d42aa90e55b90ed43c47a4" +checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" dependencies = [ "libc", "memoffset", @@ -6343,9 +7118,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.25.0" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "755ea671a1c34044fa165247aaf6f419ca39caa6003aee791a0df2713d8f1b6d" +checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" dependencies = [ "once_cell", "target-lexicon 0.13.2", @@ -6353,9 +7128,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.25.0" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc95a2e67091e44791d4ea300ff744be5293f394f1bafd9f78c080814d35956e" +checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" dependencies = [ "libc", "pyo3-build-config", @@ -6363,15 +7138,15 @@ dependencies = [ [[package]] name = "quanta" -version = "0.12.5" +version = "0.12.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" dependencies = [ "crossbeam-utils", "libc", "once_cell", "raw-cpuid", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -6393,13 +7168,19 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "radium" version = "0.7.0" @@ -6436,6 +7217,7 @@ checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", + "serde", ] [[package]] @@ -6464,7 +7246,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] @@ -6473,7 +7255,8 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.1", + "getrandom 0.3.3", + "serde", ] [[package]] @@ -6485,15 +7268,30 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "raw-cpuid" -version = "11.4.0" +version = "11.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc" +checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -6516,11 +7314,11 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.11" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" +checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] [[package]] @@ -6529,11 +7327,31 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "libredox", "thiserror 1.0.69", ] +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "regex" version = "1.11.1" @@ -6595,11 +7413,11 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.27", "http 0.2.12", "http-body 0.4.6", - "hyper", - "hyper-rustls", + "hyper 0.14.32", + "hyper-rustls 0.24.2", "ipnet", "js-sys", "log", @@ -6607,7 +7425,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.21.12", "rustls-pemfile", "serde", "serde_json", @@ -6615,7 +7433,7 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tower-service", "url", "wasm-bindgen", @@ -6627,46 +7445,164 @@ dependencies = [ [[package]] name = "revm" -version = "18.0.0" +version = "24.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15689a3c6a8d14b647b4666f2e236ef47b5a5133cdfd423f545947986fff7013" +checksum = "01d277408ff8d6f747665ad9e52150ab4caf8d5eaf0d787614cf84633c8337b4" +dependencies = [ + "revm-bytecode", + "revm-context", + "revm-context-interface", + "revm-database", + "revm-database-interface", + "revm-handler", + "revm-inspector", + "revm-interpreter", + "revm-precompile", + "revm-primitives 19.2.0", + "revm-state", +] + +[[package]] +name = "revm-bytecode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "942fe4724cf552fd28db6b0a2ca5b79e884d40dd8288a4027ed1e9090e0c6f49" +dependencies = [ + "bitvec", + "once_cell", + "phf", + "revm-primitives 19.2.0", + "serde", +] + +[[package]] +name = "revm-context" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01aad49e1233f94cebda48a4e5cef022f7c7ed29b4edf0d202b081af23435ef" dependencies = [ - "auto_impl", "cfg-if", - "dyn-clone", + "derive-where", + "revm-bytecode", + "revm-context-interface", + "revm-database-interface", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-context-interface" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b844f48a411e62c7dde0f757bf5cce49c85b86d6fc1d3b2722c07f2bec4c3ce" +dependencies = [ + "alloy-eip2930", + "alloy-eip7702", + "auto_impl", + "either", + "revm-database-interface", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-database" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad3fbe34f6bb00a9c3155723b3718b9cb9f17066ba38f9eb101b678cd3626775" +dependencies = [ + "alloy-eips", + "revm-bytecode", + "revm-database-interface", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-database-interface" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b8acd36784a6d95d5b9e1b7be3ce014f1e759abb59df1fa08396b30f71adc2a" +dependencies = [ + "auto_impl", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-handler" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "481e8c3290ff4fa1c066592fdfeb2b172edfd14d12e6cade6f6f5588cad9359a" +dependencies = [ + "auto_impl", + "revm-bytecode", + "revm-context", + "revm-context-interface", + "revm-database-interface", "revm-interpreter", "revm-precompile", + "revm-primitives 19.2.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-inspector" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc1167ef8937d8867888e63581d8ece729a72073d322119ef4627d813d99ecb" +dependencies = [ + "auto_impl", + "revm-context", + "revm-database-interface", + "revm-handler", + "revm-interpreter", + "revm-primitives 19.2.0", + "revm-state", "serde", "serde_json", ] [[package]] name = "revm-interpreter" -version = "14.0.0" +version = "20.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74e3f11d0fed049a4a10f79820c59113a79b38aed4ebec786a79d5c667bfeb51" +checksum = "b5ee65e57375c6639b0f50555e92a4f1b2434349dd32f52e2176f5c711171697" dependencies = [ - "revm-primitives 14.0.0", + "revm-bytecode", + "revm-context-interface", + "revm-primitives 19.2.0", "serde", ] [[package]] name = "revm-precompile" -version = "15.0.0" +version = "21.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e381060af24b750069a2b2d2c54bba273d84e8f5f9e8026fc9262298e26cc336" +checksum = "0f9311e735123d8d53a02af2aa81877bba185be7c141be7f931bb3d2f3af449c" dependencies = [ + "ark-bls12-381", + "ark-bn254 0.5.0", + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "aurora-engine-modexp", "blst", "c-kzg", "cfg-if", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", + "libsecp256k1", "once_cell", - "revm-primitives 14.0.0", + "p256 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)", + "revm-primitives 19.2.0", "ripemd", "secp256k1", - "sha2", - "substrate-bn", + "sha2 0.10.9", ] [[package]] @@ -6678,7 +7614,7 @@ dependencies = [ "alloy-primitives 0.4.2", "alloy-rlp", "auto_impl", - "bitflags 2.8.0", + "bitflags 2.9.1", "bitvec", "enumn", "hashbrown 0.14.5", @@ -6687,21 +7623,24 @@ dependencies = [ [[package]] name = "revm-primitives" -version = "14.0.0" +version = "19.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3702f132bb484f4f0d0ca4f6fbde3c82cfd745041abbedd6eda67730e1868ef0" +checksum = "1c1588093530ec4442461163be49c433c07a3235d1ca6f6799fef338dacc50d3" dependencies = [ - "alloy-eip2930", - "alloy-eip7702", - "alloy-primitives 0.8.25", - "auto_impl", - "bitflags 2.8.0", - "bitvec", - "c-kzg", - "cfg-if", - "dyn-clone", - "enumn", - "hex", + "alloy-primitives 1.2.1", + "num_enum", + "serde", +] + +[[package]] +name = "revm-state" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0040c61c30319254b34507383ba33d85f92949933adf6525a2cede05d165e1fa" +dependencies = [ + "bitflags 2.9.1", + "revm-bytecode", + "revm-primitives 19.2.0", "serde", ] @@ -6728,13 +7667,13 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -6793,30 +7732,6 @@ dependencies = [ "paste", ] -[[package]] -name = "ruint" -version = "1.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c3cc4c2511671f327125da14133d0c5c5d137f006a1017a16f557bc85b16286" -dependencies = [ - "alloy-rlp", - "ark-ff 0.3.0", - "ark-ff 0.4.2", - "bytes", - "fastrlp 0.3.1", - "num-bigint 0.4.6", - "num-traits", - "parity-scale-codec", - "primitive-types", - "proptest", - "rand 0.8.5", - "rlp", - "ruint-macro 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", - "serde", - "valuable", - "zeroize", -] - [[package]] name = "ruint" version = "1.14.0" @@ -6834,7 +7749,7 @@ dependencies = [ "bytemuck", "bytes", "criterion", - "der 0.7.9", + "der 0.7.10", "diesel", "ethereum_ssz", "eyre", @@ -6870,7 +7785,34 @@ dependencies = [ "serde_json", "sqlx-core", "subtle", - "thiserror 2.0.11", + "thiserror 2.0.12", + "valuable", + "zeroize", +] + +[[package]] +name = "ruint" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11256b5fe8c68f56ac6f39ef0720e592f33d2367a4782740d9c9142e889c7fb4" +dependencies = [ + "alloy-rlp", + "ark-ff 0.3.0", + "ark-ff 0.4.2", + "bytes", + "fastrlp 0.3.1", + "fastrlp 0.4.0", + "num-bigint 0.4.6", + "num-integer", + "num-traits", + "parity-scale-codec", + "primitive-types", + "proptest", + "rand 0.8.5", + "rand 0.9.1", + "rlp", + "ruint-macro 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "serde", "valuable", "zeroize", ] @@ -6890,9 +7832,9 @@ checksum = "48fd7bd8a6377e15ad9d42a8ec25371b94ddc67abe7c8b9127bec79bebaaae18" [[package]] name = "rustc-demangle" -version = "0.1.24" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" [[package]] name = "rustc-hash" @@ -6927,7 +7869,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ - "semver 1.0.25", + "semver 1.0.26", ] [[package]] @@ -6936,10 +7878,23 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags 2.9.1", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] @@ -6951,10 +7906,24 @@ checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.23.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" +dependencies = [ + "aws-lc-rs", + "once_cell", + "rustls-pki-types", + "rustls-webpki 0.103.4", + "subtle", + "zeroize", +] + [[package]] name = "rustls-native-certs" version = "0.6.3" @@ -6964,33 +7933,66 @@ dependencies = [ "openssl-probe", "rustls-pemfile", "schannel", - "security-framework", + "security-framework 2.11.1", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.2.0", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + +[[package]] +name = "rustls-pki-types" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", ] [[package]] -name = "rustls-pemfile" -version = "1.0.4" +name = "rustls-webpki" +version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "base64 0.21.7", + "ring", + "untrusted", ] [[package]] name = "rustls-webpki" -version = "0.101.7" +version = "0.103.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" dependencies = [ + "aws-lc-rs", "ring", + "rustls-pki-types", "untrusted", ] [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" [[package]] name = "rusty-fork" @@ -7006,9 +8008,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -7040,7 +8042,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7052,6 +8054,30 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "schemars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + +[[package]] +name = "schemars" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -7089,7 +8115,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" dependencies = [ "base16ct 0.2.0", - "der 0.7.9", + "der 0.7.10", "generic-array", "pkcs8 0.10.2", "serdect", @@ -7099,10 +8125,11 @@ dependencies = [ [[package]] name = "secp256k1" -version = "0.29.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9465315bc9d4566e1724f0fffcbcc446268cb522e60f9a27bcded6b19c108113" +checksum = "b50c5943d326858130af85e049f2661ba3c78b26589b8ab98e65e80ae44a1252" dependencies = [ + "bitcoin_hashes", "rand 0.8.5", "secp256k1-sys", ] @@ -7122,8 +8149,21 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.8.0", - "core-foundation", + "bitflags 2.9.1", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags 2.9.1", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -7150,9 +8190,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.25" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" dependencies = [ "serde", ] @@ -7168,9 +8208,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.218" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] @@ -7195,22 +8235,22 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.218" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "serde_json" -version = "1.0.139" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "itoa", "memchr", "ryu", @@ -7229,9 +8269,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" dependencies = [ "serde", ] @@ -7250,15 +8290,17 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.12.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa" +checksum = "f2c45cd61fefa9db6f254525d46e392b852e0e61d9a1fd36e5bd183450a556d5" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.7.1", + "indexmap 2.10.0", + "schemars 0.9.0", + "schemars 1.0.4", "serde", "serde_derive", "serde_json", @@ -7268,14 +8310,14 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.12.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e" +checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7301,9 +8343,22 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.8" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" +dependencies = [ + "block-buffer 0.9.0", + "cfg-if", + "cpufeatures", + "digest 0.9.0", + "opaque-debug", +] + +[[package]] +name = "sha2" +version = "0.10.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures", @@ -7347,9 +8402,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" dependencies = [ "libc", ] @@ -7388,24 +8443,21 @@ checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" [[package]] name = "slab" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] +checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" [[package]] name = "smallvec" -version = "1.14.0" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "snark-verifier" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28e4c4ed1edca41687fe2d8a09ba30badb0a5cc7fa56dd1159d62aeab7c99ace" +checksum = "c9203c416ff9de0762667270b21573ba5e6edaeda08743b3ca37dc8a5e0a4480" dependencies = [ "halo2-base", "halo2-ecc", @@ -7418,16 +8470,16 @@ dependencies = [ "pairing 0.23.0", "rand 0.8.5", "revm", - "ruint 1.12.3", + "ruint 1.15.0", "serde", "sha3", ] [[package]] name = "snark-verifier-sdk" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "babff70ce6292fce03f692d68569f76b8f6710dbac7be7fe5f32c915909c9065" +checksum = "290ae6e750d9d5fdf05393bbcae6bf7a63e3408eab023abf7d466156a234ac85" dependencies = [ "bincode", "ethereum-types", @@ -7448,9 +8500,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", "windows-sys 0.52.0", @@ -7493,7 +8545,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ "base64ct", - "der 0.7.9", + "der 0.7.10", ] [[package]] @@ -7511,15 +8563,15 @@ dependencies = [ "futures-intrusive", "futures-io", "futures-util", - "hashbrown 0.15.2", + "hashbrown 0.15.4", "hashlink", - "indexmap 2.7.1", + "indexmap 2.10.0", "log", "memchr", "once_cell", "percent-encoding", "smallvec", - "thiserror 2.0.11", + "thiserror 2.0.12", "tracing", "url", ] @@ -7536,6 +8588,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "statrs" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a3fe7c28c6512e766b0874335db33c94ad7b8f9054228ae1c2abd47ce7d335e" +dependencies = [ + "approx", + "num-traits", +] + [[package]] name = "strength_reduce" version = "0.2.4" @@ -7590,20 +8652,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.98", -] - -[[package]] -name = "substrate-bn" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b5bbfa79abbae15dd642ea8176a21a635ff3c00059961d1ea27ad04e5b441c" -dependencies = [ - "byteorder", - "crunchy", - "lazy_static", - "rand 0.8.5", - "rustc-hex", + "syn 2.0.104", ] [[package]] @@ -7636,10 +8685,10 @@ dependencies = [ "hex", "once_cell", "reqwest", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", - "sha2", + "sha2 0.10.9", "thiserror 1.0.69", "url", "zip", @@ -7653,7 +8702,7 @@ checksum = "aa64b5e8eecd3a8af7cfc311e29db31a268a62d5953233d3e8243ec77a71c4e3" dependencies = [ "build_const", "hex", - "semver 1.0.25", + "semver 1.0.26", "serde_json", "svm-rs", ] @@ -7671,9 +8720,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.98" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -7689,7 +8738,7 @@ dependencies = [ "paste", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7700,13 +8749,13 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "synstructure" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7716,7 +8765,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -7750,15 +8799,14 @@ checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] name = "tempfile" -version = "3.17.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ - "cfg-if", "fastrand", - "getrandom 0.3.1", + "getrandom 0.3.3", "once_cell", - "rustix", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -7773,6 +8821,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "terminal_size" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed" +dependencies = [ + "rustix 1.0.7", + "windows-sys 0.59.0", +] + [[package]] name = "test-case" version = "3.3.1" @@ -7791,7 +8849,7 @@ dependencies = [ "cfg-if", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7802,30 +8860,30 @@ checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "test-case-core", ] [[package]] name = "test-log" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f46083d221181166e5b6f6b1e5f1d499f3a76888826e6cb1d057554157cd0f" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" dependencies = [ "env_logger", "test-log-macros", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] name = "test-log-macros" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "888d0c3c6db53c0fdab160d2ed5e12ba745383d3e85813f2ea0f2b1475ab553f" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -7839,11 +8897,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -7854,28 +8912,27 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "thread_local" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" dependencies = [ "cfg-if", - "once_cell", ] [[package]] @@ -7909,9 +8966,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.37" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "itoa", @@ -7926,15 +8983,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-macros" -version = "0.2.19" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" dependencies = [ "num-conv", "time-core", @@ -7951,9 +9008,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.7.6" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" dependencies = [ "displaydoc", "zerovec", @@ -7986,16 +9043,18 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.2" +version = "1.46.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" +checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "pin-project-lite", "signal-hook-registry", + "slab", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -8009,7 +9068,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -8044,15 +9103,25 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls", + "rustls 0.21.12", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +dependencies = [ + "rustls 0.23.29", "tokio", ] [[package]] name = "tokio-util" -version = "0.7.13" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", @@ -8067,7 +9136,7 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "serde", "serde_spanned", "toml_datetime", @@ -8076,21 +9145,21 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.20" +version = "0.8.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit 0.22.24", + "toml_edit 0.22.27", ] [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" dependencies = [ "serde", ] @@ -8101,7 +9170,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "serde", "serde_spanned", "toml_datetime", @@ -8110,17 +9179,40 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.10.0", "serde", "serde_spanned", "toml_datetime", - "winnow 0.7.3", + "toml_write", + "winnow 0.7.12", +] + +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "tower-layer", + "tower-service", ] +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.3" @@ -8141,20 +9233,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.28" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "tracing-core" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", "valuable", @@ -8170,7 +9262,7 @@ dependencies = [ "smallvec", "thiserror 1.0.69", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -8184,6 +9276,15 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "tracing-core", +] + [[package]] name = "tracing-subscriber" version = "0.3.19" @@ -8265,9 +9366,9 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-normalization" @@ -8329,12 +9430,6 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -8349,9 +9444,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.13.2" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" +dependencies = [ + "getrandom 0.3.3", + "js-sys", + "wasm-bindgen", +] [[package]] name = "valuable" @@ -8420,15 +9520,15 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ "wit-bindgen-rt", ] @@ -8461,7 +9561,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "wasm-bindgen-shared", ] @@ -8496,7 +9596,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -8526,6 +9626,18 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "whoami" version = "1.6.0" @@ -8570,11 +9682,61 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-core" -version = "0.52.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ - "windows-targets 0.52.6", + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", ] [[package]] @@ -8604,6 +9766,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -8628,13 +9799,29 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-targets" +version = "0.53.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -8647,6 +9834,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -8659,6 +9852,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -8671,12 +9870,24 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -8689,6 +9900,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -8701,6 +9918,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -8713,6 +9936,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -8725,6 +9954,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" version = "0.5.40" @@ -8736,9 +9971,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.3" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" dependencies = [ "memchr", ] @@ -8755,24 +9990,18 @@ dependencies = [ [[package]] name = "wit-bindgen-rt" -version = "0.33.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - [[package]] name = "writeable" -version = "0.5.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "wyz" @@ -8803,9 +10032,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ "serde", "stable_deref_trait", @@ -8815,55 +10044,54 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "synstructure", ] [[package]] name = "zerocopy" -version = "0.7.35" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" dependencies = [ - "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.35" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] name = "zerofrom" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", "synstructure", ] @@ -8884,14 +10112,25 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", +] + +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", ] [[package]] name = "zerovec" -version = "0.10.4" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" dependencies = [ "yoke", "zerofrom", @@ -8900,13 +10139,13 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.10.3" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.104", ] [[package]] @@ -8950,7 +10189,7 @@ dependencies = [ "pasta_curves 0.5.1", "rand 0.8.5", "serde", - "sha2", + "sha2 0.10.9", "sha3", "subtle", ] diff --git a/Cargo.toml b/Cargo.toml index 3548b66419..a7402e6e1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace.package] -version = "1.3.0-rc.1" +version = "1.4.0-rc.0" edition = "2021" rust-version = "1.82" authors = ["OpenVM Authors"] @@ -51,9 +51,9 @@ members = [ "extensions/keccak256/circuit", "extensions/keccak256/transpiler", "extensions/keccak256/guest", - "extensions/sha256/circuit", - "extensions/sha256/transpiler", - "extensions/sha256/guest", + "extensions/sha2/circuit", + "extensions/sha2/transpiler", + "extensions/sha2/guest", "extensions/ecc/circuit", "extensions/ecc/transpiler", "extensions/ecc/guest", @@ -85,6 +85,7 @@ codegen-units = 16 [profile.profiling] inherits = "release" debug = 2 +debug-assertions = false strip = false # Make sure debug symbols are in the bench profile for flamegraphs @@ -99,6 +100,7 @@ codegen-units = 1 [profile.dev] opt-level = 1 +debug = 2 # For O1 optimization but still fast(ish) compile times [profile.fast] @@ -110,14 +112,14 @@ lto = "thin" [workspace.dependencies] # Stark Backend -openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.1.1", default-features = false } -openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.1.1", default-features = false } +openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.1.2", default-features = false } +openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.1.2", default-features = false } # OpenVM openvm-sdk = { path = "crates/sdk", default-features = false } openvm-mod-circuit-builder = { path = "crates/circuits/mod-builder", default-features = false } openvm-poseidon2-air = { path = "crates/circuits/poseidon2-air", default-features = false } -openvm-sha256-air = { path = "crates/circuits/sha256-air", default-features = false } +openvm-sha2-air = { path = "crates/circuits/sha2-air", default-features = false } openvm-circuit-primitives = { path = "crates/circuits/primitives", default-features = false } openvm-circuit-primitives-derive = { path = "crates/circuits/primitives/derive", default-features = false } openvm = { path = "crates/toolchain/openvm", default-features = false } @@ -147,9 +149,9 @@ openvm-native-transpiler = { path = "extensions/native/transpiler", default-feat openvm-keccak256-circuit = { path = "extensions/keccak256/circuit", default-features = false } openvm-keccak256-transpiler = { path = "extensions/keccak256/transpiler", default-features = false } openvm-keccak256-guest = { path = "extensions/keccak256/guest", default-features = false } -openvm-sha256-circuit = { path = "extensions/sha256/circuit", default-features = false } -openvm-sha256-transpiler = { path = "extensions/sha256/transpiler", default-features = false } -openvm-sha256-guest = { path = "extensions/sha256/guest", default-features = false } +openvm-sha2-circuit = { path = "extensions/sha2/circuit", default-features = false } +openvm-sha2-transpiler = { path = "extensions/sha2/transpiler", default-features = false } +openvm-sha2-guest = { path = "extensions/sha2/guest", default-features = false } openvm-bigint-circuit = { path = "extensions/bigint/circuit", default-features = false } openvm-bigint-transpiler = { path = "extensions/bigint/transpiler", default-features = false } openvm-bigint-guest = { path = "extensions/bigint/guest", default-features = false } @@ -171,18 +173,16 @@ openvm-verify-stark = { path = "guest-libs/verify_stark", default-features = fal openvm-benchmarks-utils = { path = "benchmarks/utils", default-features = false } # Plonky3 -p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", features = [ - "nightly-features", -], rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-dft = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-poseidon2-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } -p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-dft = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-keccak-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-poseidon2-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb", default-features = false } zkhash = { git = "https://github.com/HorizenLabs/poseidon2.git", rev = "bb476b9" } snark-verifier-sdk = { version = "0.2.0", default-features = false, features = [ @@ -226,6 +226,10 @@ rrs-lib = "0.1.0" rand = { version = "0.8.5", default-features = false } hex = { version = "0.4.3", default-features = false } serde-big-array = "0.5.1" +dashmap = "6.1.0" +memmap2 = "0.9.5" +ndarray = { version = "0.16.1", default-features = false } +num_enum = { version = "0.7.4", default-features = false } # default-features = false for no_std for use in guest programs itertools = { version = "0.14.0", default-features = false } @@ -258,3 +262,6 @@ sha2 = { version = "0.10", default-features = false } # p3-poseidon2 = { path = "../Plonky3/poseidon2" } # p3-poseidon2-air = { path = "../Plonky3/poseidon2-air" } # p3-symmetric = { path = "../Plonky3/symmetric" } + +[workspace.metadata.cargo-shear] +ignored = ["cargo-openvm"] diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index 319490220a..f943f73dcb 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -9,41 +9,51 @@ license.workspace = true [dependencies] openvm-benchmarks-utils.workspace = true -cargo-openvm.workspace = true openvm-circuit.workspace = true -openvm-sdk.workspace = true openvm-stark-sdk.workspace = true openvm-transpiler.workspace = true -openvm-rv32im-circuit.workspace = true -openvm-rv32im-transpiler.workspace = true +openvm-algebra-circuit.workspace = true +openvm-algebra-transpiler.workspace = true +openvm-bigint-circuit.workspace = true +openvm-bigint-transpiler.workspace = true +openvm-ecc-circuit.workspace = true +openvm-ecc-transpiler.workspace = true +openvm-pairing-circuit.workspace = true +openvm-pairing-guest.workspace = true +openvm-pairing-transpiler.workspace = true openvm-keccak256-circuit.workspace = true openvm-keccak256-transpiler.workspace = true - -clap = { version = "4.5.9", features = ["derive", "env"] } +openvm-rv32im-circuit.workspace = true +openvm-rv32im-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true eyre.workspace = true -tracing.workspace = true derive_more = { workspace = true, features = ["from"] } - -tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } +rand.workspace = true +serde = { workspace = true, features = ["derive"] } [dev-dependencies] -criterion = { version = "0.5", features = ["html_reports"] } +divan = { package = "codspeed-divan-compat", version = "3.0.2" } [features] default = ["jemalloc"] -profiling = ["openvm-sdk/profiling"] mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] jemalloc-prof = ["openvm-circuit/jemalloc-prof"] nightly-features = ["openvm-circuit/nightly-features"] +profiling = ["openvm-circuit/function-span", "openvm-transpiler/function-span"] -[[bench]] -name = "fibonacci_execute" -harness = false +# [[bench]] +# name = "fibonacci_execute" +# harness = false + +# [[bench]] +# name = "regex_execute" +# harness = false [[bench]] -name = "regex_execute" +name = "execute" harness = false [package.metadata.cargo-shear] -ignored = ["derive_more"] +ignored = ["derive_more", "rand"] diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs new file mode 100644 index 0000000000..eb91bc54f1 --- /dev/null +++ b/benchmarks/execute/benches/execute.rs @@ -0,0 +1,240 @@ +use std::{path::Path, sync::OnceLock}; + +use divan::Bencher; +use eyre::Result; +use openvm_algebra_circuit::{ + Fp2Extension, Fp2ExtensionExecutor, Fp2ExtensionPeriphery, ModularExtension, + ModularExtensionExecutor, ModularExtensionPeriphery, +}; +use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; +use openvm_benchmarks_utils::{get_elf_path, get_programs_dir, read_elf_file}; +use openvm_bigint_circuit::{Int256, Int256Executor, Int256Periphery}; +use openvm_bigint_transpiler::Int256TranspilerExtension; +use openvm_circuit::{ + arch::{ + execution_mode::{ + e1::E1Ctx, + metered::{ctx::DEFAULT_PAGE_BITS, MeteredCtx}, + }, + instructions::exe::VmExe, + interpreter::InterpretedInstance, + InitFileGenerator, SystemConfig, VirtualMachine, VmChipComplex, VmConfig, + }, + derive::VmConfig, +}; +use openvm_ecc_circuit::{ + WeierstrassExtension, WeierstrassExtensionExecutor, WeierstrassExtensionPeriphery, +}; +use openvm_ecc_transpiler::EccTranspilerExtension; +use openvm_keccak256_circuit::{Keccak256, Keccak256Executor, Keccak256Periphery}; +use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +use openvm_pairing_circuit::{ + PairingCurve, PairingExtension, PairingExtensionExecutor, PairingExtensionPeriphery, +}; +use openvm_pairing_guest::bn254::BN254_COMPLEX_STRUCT_NAME; +use openvm_pairing_transpiler::PairingTranspilerExtension; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, + Rv32MExecutor, Rv32MPeriphery, +}; +use openvm_rv32im_transpiler::{ + Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +}; +use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; +use openvm_sha2_transpiler::Sha2TranspilerExtension; +use openvm_stark_sdk::{ + config::baby_bear_poseidon2::{ + default_engine, BabyBearPoseidon2Config, BabyBearPoseidon2Engine, + }, + openvm_stark_backend::{self, p3_field::PrimeField32}, + p3_baby_bear::BabyBear, +}; +use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use serde::{Deserialize, Serialize}; + +static AVAILABLE_PROGRAMS: &[&str] = &[ + "fibonacci_recursive", + "fibonacci_iterative", + "quicksort", + "bubblesort", + "factorial_iterative_u256", + "revm_snailtracer", + "keccak256", + "keccak256_iter", + "sha256", + "sha256_iter", + "revm_transfer", + "pairing", +]; + +static SHARED_INTERACTIONS: OnceLock> = OnceLock::new(); + +#[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] +pub struct ExecuteConfig { + #[system] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub bigint: Int256, + #[extension] + pub keccak: Keccak256, + #[extension] + pub sha2: Sha2, + #[extension] + pub modular: ModularExtension, + #[extension] + pub fp2: Fp2Extension, + #[extension] + pub weierstrass: WeierstrassExtension, + #[extension] + pub pairing: PairingExtension, +} + +impl Default for ExecuteConfig { + fn default() -> Self { + let bn_config = PairingCurve::Bn254.curve_config(); + Self { + system: SystemConfig::default().with_continuations(), + rv32i: Rv32I, + rv32m: Rv32M::default(), + io: Rv32Io, + bigint: Int256::default(), + keccak: Keccak256, + sha2: Sha2, + modular: ModularExtension::new(vec![ + bn_config.modulus.clone(), + bn_config.scalar.clone(), + ]), + fp2: Fp2Extension::new(vec![( + BN254_COMPLEX_STRUCT_NAME.to_string(), + bn_config.modulus.clone(), + )]), + weierstrass: WeierstrassExtension::new(vec![bn_config.clone()]), + pairing: PairingExtension::new(vec![PairingCurve::Bn254]), + } + } +} + +impl InitFileGenerator for ExecuteConfig { + fn write_to_init_file( + &self, + _manifest_dir: &Path, + _init_file_name: Option<&str>, + ) -> eyre::Result<()> { + Ok(()) + } +} + +fn main() { + divan::main(); +} + +fn create_default_vm( +) -> VirtualMachine { + let vm_config = ExecuteConfig::default(); + VirtualMachine::new(default_engine(), vm_config) +} + +fn create_default_transpiler() -> Transpiler { + Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32IoTranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Int256TranspilerExtension) + .with_extension(Keccak256TranspilerExtension) + .with_extension(Sha2TranspilerExtension) + .with_extension(ModularTranspilerExtension) + .with_extension(Fp2TranspilerExtension) + .with_extension(EccTranspilerExtension) + .with_extension(PairingTranspilerExtension) +} + +fn load_program_executable(program: &str) -> Result> { + let transpiler = create_default_transpiler(); + let program_dir = get_programs_dir().join(program); + let elf_path = get_elf_path(&program_dir); + let elf = read_elf_file(&elf_path)?; + Ok(VmExe::from_elf(elf, transpiler)?) +} + +fn shared_interactions() -> &'static Vec { + SHARED_INTERACTIONS.get_or_init(|| { + let vm = create_default_vm(); + let pk = vm.keygen(); + let vk = pk.get_vk(); + vk.num_interactions() + }) +} + +#[divan::bench(args = AVAILABLE_PROGRAMS, sample_count=10)] +fn benchmark_execute(bencher: Bencher, program: &str) { + bencher + .with_inputs(|| { + let vm_config = ExecuteConfig::default(); + let exe = load_program_executable(program).expect("Failed to load program executable"); + let interpreter = InterpretedInstance::new(vm_config, exe); + (interpreter, vec![]) + }) + .bench_values(|(interpreter, input)| { + interpreter + .execute(E1Ctx::new(None), input) + .expect("Failed to execute program in interpreted mode"); + }); +} + +#[divan::bench(args = AVAILABLE_PROGRAMS, sample_count=5)] +fn benchmark_execute_metered(bencher: Bencher, program: &str) { + bencher + .with_inputs(|| { + let vm_config = ExecuteConfig::default(); + let exe = load_program_executable(program).expect("Failed to load program executable"); + + let chip_complex: VmChipComplex = + vm_config.create_chip_complex().unwrap(); + let interactions = shared_interactions(); + let segmentation_strategy = + &>::system(&vm_config).segmentation_strategy; + + let ctx: MeteredCtx = + MeteredCtx::new(&chip_complex, interactions.to_vec()) + .with_max_trace_height(segmentation_strategy.max_trace_height() as u32) + .with_max_cells(segmentation_strategy.max_cells()); + let interpreter = InterpretedInstance::new(vm_config, exe); + + (interpreter, vec![], ctx) + }) + .bench_values(|(interpreter, input, ctx)| { + interpreter + .execute_e2(ctx, input) + .expect("Failed to execute program"); + }); +} + +// #[divan::bench(args = AVAILABLE_PROGRAMS, sample_count=3)] +// fn benchmark_execute_e3(bencher: Bencher, program: &str) { +// bencher +// .with_inputs(|| { +// let vm = create_default_vm(); +// let exe = load_program_executable(program).expect("Failed to load program +// executable"); let state = create_initial_state(&vm.config().system.memory_config, +// &exe, vec![], 0); + +// let (widths, interactions) = shared_widths_and_interactions(); +// let segments = vm +// .executor +// .execute_metered(exe.clone(), vec![], interactions) +// .expect("Failed to execute program"); + +// (vm.executor, exe, state, segments) +// }) +// .bench_values(|(executor, exe, state, segments)| { +// executor +// .execute_from_state(exe, state, &segments) +// .expect("Failed to execute program"); +// }); +// } diff --git a/benchmarks/execute/benches/fibonacci_execute.rs b/benchmarks/execute/benches/fibonacci_execute.rs index 70952b53c9..49b453d028 100644 --- a/benchmarks/execute/benches/fibonacci_execute.rs +++ b/benchmarks/execute/benches/fibonacci_execute.rs @@ -1,42 +1,44 @@ -use criterion::{criterion_group, criterion_main, Criterion}; -use openvm_benchmarks_utils::{build_elf, get_programs_dir}; -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_rv32im_circuit::Rv32ImConfig; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +// use criterion::{criterion_group, criterion_main, Criterion}; +// use openvm_benchmarks_utils::{build_elf, get_programs_dir}; +// use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; +// use openvm_rv32im_circuit::Rv32ImConfig; +// use openvm_rv32im_transpiler::{ +// Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +// }; +// // TODO(ayush): add this back +// // use openvm_sdk::StdIn; +// use openvm_stark_sdk::p3_baby_bear::BabyBear; +// use openvm_transpiler::{transpiler::Transpiler, FromElf}; -fn benchmark_function(c: &mut Criterion) { - let program_dir = get_programs_dir().join("fibonacci"); - let elf = build_elf(&program_dir, "release").unwrap(); +// fn benchmark_function(c: &mut Criterion) { +// let program_dir = get_programs_dir().join("fibonacci"); +// let elf = build_elf(&program_dir, "release").unwrap(); - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension), - ) - .unwrap(); +// let exe = VmExe::from_elf( +// elf, +// Transpiler::::default() +// .with_extension(Rv32ITranspilerExtension) +// .with_extension(Rv32MTranspilerExtension) +// .with_extension(Rv32IoTranspilerExtension), +// ) +// .unwrap(); - let mut group = c.benchmark_group("fibonacci"); - let config = Rv32ImConfig::default(); - let executor = VmExecutor::::new(config); +// let mut group = c.benchmark_group("fibonacci"); +// let config = Rv32ImConfig::default(); +// let executor = VmExecutor::::new(config); - group.bench_function("execute", |b| { - b.iter(|| { - let n = 100_000u64; - let mut stdin = StdIn::default(); - stdin.write(&n); - executor.execute(exe.clone(), stdin).unwrap(); - }) - }); +// group.bench_function("execute", |b| { +// b.iter(|| { +// // TODO(ayush): add this back +// // let n = 100_000u64; +// // let mut stdin = StdIn::default(); +// // stdin.write(&n); +// executor.execute(exe.clone(), vec![]).unwrap(); +// }) +// }); - group.finish(); -} +// group.finish(); +// } -criterion_group!(benches, benchmark_function); -criterion_main!(benches); +// criterion_group!(benches, benchmark_function); +// criterion_main!(benches); diff --git a/benchmarks/execute/benches/regex_execute.rs b/benchmarks/execute/benches/regex_execute.rs index a3a110e344..d4116b5aab 100644 --- a/benchmarks/execute/benches/regex_execute.rs +++ b/benchmarks/execute/benches/regex_execute.rs @@ -1,47 +1,47 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use openvm_benchmarks_utils::{build_elf, get_programs_dir}; -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +// TODO(ayush): add this back +// use criterion::{black_box, criterion_group, criterion_main, Criterion}; +// use openvm_benchmarks_utils::{build_elf, get_programs_dir}; +// use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; +// use openvm_keccak256_circuit::Keccak256Rv32Config; +// use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +// use openvm_rv32im_transpiler::{ +// Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +// }; +// use openvm_sdk::StdIn; +// use openvm_stark_sdk::p3_baby_bear::BabyBear; +// use openvm_transpiler::{transpiler::Transpiler, FromElf}; -fn benchmark_function(c: &mut Criterion) { - let program_dir = get_programs_dir().join("regex"); - let elf = build_elf(&program_dir, "release").unwrap(); +// fn benchmark_function(c: &mut Criterion) { +// let program_dir = get_programs_dir().join("regex"); +// let elf = build_elf(&program_dir, "release").unwrap(); - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - ) - .unwrap(); +// let exe = VmExe::from_elf( +// elf, +// Transpiler::::default() +// .with_extension(Rv32ITranspilerExtension) +// .with_extension(Rv32MTranspilerExtension) +// .with_extension(Rv32IoTranspilerExtension) +// .with_extension(Keccak256TranspilerExtension), +// ) +// .unwrap(); - let mut group = c.benchmark_group("regex"); - group.sample_size(10); - let config = Keccak256Rv32Config::default(); - let executor = VmExecutor::::new(config); +// let mut group = c.benchmark_group("regex"); +// group.sample_size(10); +// let config = Keccak256Rv32Config::default(); +// let executor = VmExecutor::::new(config); - let data = include_str!("../../guest/regex/regex_email.txt"); +// let data = include_str!("../../guest/regex/regex_email.txt"); - let fe_bytes = data.to_owned().into_bytes(); - group.bench_function("execute", |b| { - b.iter(|| { - executor - .execute(exe.clone(), black_box(StdIn::from_bytes(&fe_bytes))) - .unwrap(); - }) - }); +// let fe_bytes = data.to_owned().into_bytes(); +// group.bench_function("execute", |b| { +// b.iter(|| { +// let input = black_box(Stdin::from_bytes(&fe_bytes)); +// executor.execute(exe.clone(), input).unwrap(); +// }) +// }); - group.finish(); -} +// group.finish(); +// } -criterion_group!(benches, benchmark_function); -criterion_main!(benches); +// criterion_group!(benches, benchmark_function); +// criterion_main!(benches); diff --git a/benchmarks/execute/examples/regex_execute.rs b/benchmarks/execute/examples/regex_execute.rs index 59705a19fd..3a6fd4162f 100644 --- a/benchmarks/execute/examples/regex_execute.rs +++ b/benchmarks/execute/examples/regex_execute.rs @@ -1,35 +1,35 @@ -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use openvm_transpiler::{ - elf::Elf, openvm_platform::memory::MEM_SIZE, transpiler::Transpiler, FromElf, -}; +// use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; +// use openvm_keccak256_circuit::Keccak256Rv32Config; +// use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +// use openvm_rv32im_transpiler::{ +// Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +// }; +// use openvm_sdk::StdIn; +// use openvm_stark_sdk::p3_baby_bear::BabyBear; +// use openvm_transpiler::{ +// elf::Elf, openvm_platform::memory::MEM_SIZE, transpiler::Transpiler, FromElf, +// }; fn main() { - let elf = Elf::decode(include_bytes!("regex-elf"), MEM_SIZE as u32).unwrap(); - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - ) - .unwrap(); + // let elf = Elf::decode(include_bytes!("regex-elf"), MEM_SIZE as u32).unwrap(); + // let exe = VmExe::from_elf( + // elf, + // Transpiler::::default() + // .with_extension(Rv32ITranspilerExtension) + // .with_extension(Rv32MTranspilerExtension) + // .with_extension(Rv32IoTranspilerExtension) + // .with_extension(Keccak256TranspilerExtension), + // ) + // .unwrap(); - let config = Keccak256Rv32Config::default(); - let executor = VmExecutor::::new(config); + // let config = Keccak256Rv32Config::default(); + // let executor = VmExecutor::::new(config); - let data = include_str!("../../guest/regex/regex_email.txt"); + // let data = include_str!("../../guest/regex/regex_email.txt"); - let timer = std::time::Instant::now(); - executor - .execute(exe.clone(), StdIn::from_bytes(data.as_bytes())) - .unwrap(); - println!("execute_time: {:?}", timer.elapsed()); + // let timer = std::time::Instant::now(); + // executor + // .execute(exe.clone(), StdIn::from_bytes(data.as_bytes())) + // .unwrap(); + // println!("execute_time: {:?}", timer.elapsed()); } diff --git a/benchmarks/execute/src/main.rs b/benchmarks/execute/src/main.rs deleted file mode 100644 index a05baeea44..0000000000 --- a/benchmarks/execute/src/main.rs +++ /dev/null @@ -1,121 +0,0 @@ -use cargo_openvm::util::read_config_toml_or_default; -use clap::{Parser, ValueEnum}; -use eyre::Result; -use openvm_benchmarks_utils::{get_elf_path, get_programs_dir, read_elf_file}; -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::bench::run_with_metric_collection; -use openvm_transpiler::FromElf; - -#[derive(Debug, Clone, ValueEnum)] -enum BuildProfile { - Debug, - Release, -} - -static AVAILABLE_PROGRAMS: &[&str] = &[ - "fibonacci_recursive", - "fibonacci_iterative", - "quicksort", - "bubblesort", - "pairing", - "keccak256", - "keccak256_iter", - "sha256", - "sha256_iter", - "revm_transfer", - "revm_snailtracer", -]; - -#[derive(Parser)] -#[command(author, version, about = "OpenVM Benchmark CLI", long_about = None)] -struct Cli { - /// Programs to benchmark (if not specified, all programs will be run) - #[arg(short, long)] - programs: Vec, - - /// Programs to skip from benchmarking - #[arg(short, long)] - skip: Vec, - - /// Output path for benchmark results - #[arg(short, long, default_value = "OUTPUT_PATH")] - output: String, - - /// List available benchmark programs and exit - #[arg(short, long)] - list: bool, - - /// Verbose output - #[arg(short, long)] - verbose: bool, -} - -fn main() -> Result<()> { - let cli = Cli::parse(); - - if cli.list { - println!("Available benchmark programs:"); - for program in AVAILABLE_PROGRAMS { - println!(" {}", program); - } - return Ok(()); - } - - // Set up logging based on verbosity - if cli.verbose { - tracing_subscriber::fmt::init(); - } - - let mut programs_to_run = if cli.programs.is_empty() { - AVAILABLE_PROGRAMS.to_vec() - } else { - // Validate provided programs - for program in &cli.programs { - if !AVAILABLE_PROGRAMS.contains(&program.as_str()) { - eprintln!("Unknown program: {}", program); - eprintln!("Use --list to see available programs"); - std::process::exit(1); - } - } - cli.programs.iter().map(|s| s.as_str()).collect() - }; - - // Remove programs that should be skipped - if !cli.skip.is_empty() { - // Validate skipped programs - for program in &cli.skip { - if !AVAILABLE_PROGRAMS.contains(&program.as_str()) { - eprintln!("Unknown program to skip: {}", program); - eprintln!("Use --list to see available programs"); - std::process::exit(1); - } - } - - let skip_set: Vec<&str> = cli.skip.iter().map(|s| s.as_str()).collect(); - programs_to_run.retain(|&program| !skip_set.contains(&program)); - } - - tracing::info!("Starting benchmarks with metric collection"); - - run_with_metric_collection(&cli.output, || -> Result<()> { - for program in &programs_to_run { - tracing::info!("Running program: {}", program); - - let program_dir = get_programs_dir().join(program); - let elf_path = get_elf_path(&program_dir); - let elf = read_elf_file(&elf_path)?; - - let config_path = program_dir.join("openvm.toml"); - let vm_config = read_config_toml_or_default(&config_path)?.app_vm_config; - - let exe = VmExe::from_elf(elf, vm_config.transpiler())?; - - let executor = VmExecutor::new(vm_config); - executor.execute(exe, StdIn::default())?; - tracing::info!("Completed program: {}", program); - } - tracing::info!("All programs executed successfully"); - Ok(()) - }) -} diff --git a/benchmarks/guest/base64_json/elf/openvm-json-program.elf b/benchmarks/guest/base64_json/elf/openvm-json-program.elf index 29e6cac131..55335dca15 100755 Binary files a/benchmarks/guest/base64_json/elf/openvm-json-program.elf and b/benchmarks/guest/base64_json/elf/openvm-json-program.elf differ diff --git a/benchmarks/guest/bincode/elf/openvm-bincode-program.elf b/benchmarks/guest/bincode/elf/openvm-bincode-program.elf index 085eb7ee4f..2d4b2ae67a 100755 Binary files a/benchmarks/guest/bincode/elf/openvm-bincode-program.elf and b/benchmarks/guest/bincode/elf/openvm-bincode-program.elf differ diff --git a/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf b/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf index 0f81a3926f..cec789e279 100755 Binary files a/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf and b/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf differ diff --git a/benchmarks/guest/bubblesort/src/main.rs b/benchmarks/guest/bubblesort/src/main.rs index 0dd7e51146..d859641504 100644 --- a/benchmarks/guest/bubblesort/src/main.rs +++ b/benchmarks/guest/bubblesort/src/main.rs @@ -1,7 +1,7 @@ use core::hint::black_box; use openvm as _; -const ARRAY_SIZE: usize = 100; +const ARRAY_SIZE: usize = 1_000; fn bubblesort(arr: &mut [T]) { let len = arr.len(); diff --git a/benchmarks/guest/ecrecover/elf/openvm-ecdsa-recover-key-program.elf b/benchmarks/guest/ecrecover/elf/openvm-ecdsa-recover-key-program.elf index 4e54268ea4..88c87c6abc 100755 Binary files a/benchmarks/guest/ecrecover/elf/openvm-ecdsa-recover-key-program.elf and b/benchmarks/guest/ecrecover/elf/openvm-ecdsa-recover-key-program.elf differ diff --git a/benchmarks/guest/factorial_iterative_u256/Cargo.toml b/benchmarks/guest/factorial_iterative_u256/Cargo.toml new file mode 100644 index 0000000000..260d22985b --- /dev/null +++ b/benchmarks/guest/factorial_iterative_u256/Cargo.toml @@ -0,0 +1,17 @@ +[workspace] +[package] +name = "openvm-factorial-iterative-u256-program" +version = "0.0.0" +edition = "2021" + +[dependencies] +openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm-ruint = { path = "../../../guest-libs/ruint/", package = "ruint", default-features = false } + +[features] +default = [] + +[profile.profiling] +inherits = "release" +debug = 2 +strip = false diff --git a/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf b/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf new file mode 100755 index 0000000000..572f71b182 Binary files /dev/null and b/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf differ diff --git a/examples/sha256/openvm.toml b/benchmarks/guest/factorial_iterative_u256/openvm.toml similarity index 73% rename from examples/sha256/openvm.toml rename to benchmarks/guest/factorial_iterative_u256/openvm.toml index 656bf52414..b226887890 100644 --- a/examples/sha256/openvm.toml +++ b/benchmarks/guest/factorial_iterative_u256/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.bigint] diff --git a/benchmarks/guest/factorial_iterative_u256/src/main.rs b/benchmarks/guest/factorial_iterative_u256/src/main.rs new file mode 100644 index 0000000000..c92491d2da --- /dev/null +++ b/benchmarks/guest/factorial_iterative_u256/src/main.rs @@ -0,0 +1,16 @@ +use core::hint::black_box; +use openvm as _; +use openvm_ruint::aliases::U256; + +// This will overflow but that is fine +const N: u32 = 65_000; + +pub fn main() { + let mut acc = U256::from(1u32); + let mut i = U256::from(N); + while i > black_box(U256::ZERO) { + acc *= i.clone(); + i -= U256::from(1u32); + } + black_box(acc); +} diff --git a/benchmarks/guest/fibonacci/elf/openvm-fibonacci-program.elf b/benchmarks/guest/fibonacci/elf/openvm-fibonacci-program.elf index 36ad8d359c..20335618e4 100755 Binary files a/benchmarks/guest/fibonacci/elf/openvm-fibonacci-program.elf and b/benchmarks/guest/fibonacci/elf/openvm-fibonacci-program.elf differ diff --git a/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf b/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf index ac9fbf3e89..7c681ee313 100755 Binary files a/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf and b/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf differ diff --git a/benchmarks/guest/fibonacci_iterative/src/main.rs b/benchmarks/guest/fibonacci_iterative/src/main.rs index 09ceb5df41..f7ab8ec0f6 100644 --- a/benchmarks/guest/fibonacci_iterative/src/main.rs +++ b/benchmarks/guest/fibonacci_iterative/src/main.rs @@ -1,15 +1,15 @@ use core::hint::black_box; -use openvm as _; +use openvm::io::reveal_u32; -const N: u64 = 100_000; +const N: u32 = 900_000; pub fn main() { - let mut a: u64 = 0; - let mut b: u64 = 1; + let mut a: u32 = 0; + let mut b: u32 = 1; for _ in 0..black_box(N) { - let c: u64 = a.wrapping_add(b); + let c: u32 = a.wrapping_add(b); a = b; b = c; } - black_box(a); + reveal_u32(a, 0); } diff --git a/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf b/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf index 7dee9d4286..d14372657c 100755 Binary files a/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf and b/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf differ diff --git a/benchmarks/guest/fibonacci_recursive/src/main.rs b/benchmarks/guest/fibonacci_recursive/src/main.rs index fae64a1b0f..9020bc91ef 100644 --- a/benchmarks/guest/fibonacci_recursive/src/main.rs +++ b/benchmarks/guest/fibonacci_recursive/src/main.rs @@ -1,14 +1,15 @@ use core::hint::black_box; -use openvm as _; +use openvm::io::reveal_u32; -const N: u64 = 25; +const N: u32 = 27; pub fn main() { let n = black_box(N); - black_box(fibonacci(n)); + let result = fibonacci(n); + reveal_u32(result, 0); } -fn fibonacci(n: u64) -> u64 { +fn fibonacci(n: u32) -> u32 { if n == 0 { 0 } else if n == 1 { diff --git a/benchmarks/guest/keccak256/Cargo.toml b/benchmarks/guest/keccak256/Cargo.toml index 35bc10320a..5ff9a6337e 100644 --- a/benchmarks/guest/keccak256/Cargo.toml +++ b/benchmarks/guest/keccak256/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-keccak256 = { path = "../../../guest-libs/keccak256" } +openvm-keccak256 = { path = "../../../guest-libs/keccak256/", default-features = false } [features] default = [] diff --git a/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf b/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf index 7425897f99..6e0fc26837 100755 Binary files a/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf and b/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf differ diff --git a/benchmarks/guest/keccak256/src/main.rs b/benchmarks/guest/keccak256/src/main.rs index 5a00ba4067..0d8c6d17b4 100644 --- a/benchmarks/guest/keccak256/src/main.rs +++ b/benchmarks/guest/keccak256/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_keccak256::keccak256; -const INPUT_LENGTH_BYTES: usize = 100 * 1024; // 100 KB +const INPUT_LENGTH_BYTES: usize = 384 * 1024; pub fn main() { let mut input = Vec::with_capacity(INPUT_LENGTH_BYTES); diff --git a/benchmarks/guest/keccak256_iter/Cargo.toml b/benchmarks/guest/keccak256_iter/Cargo.toml index 68c2cbb5dd..56884f4869 100644 --- a/benchmarks/guest/keccak256_iter/Cargo.toml +++ b/benchmarks/guest/keccak256_iter/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-keccak256 = { path = "../../../guest-libs/keccak256" } +openvm-keccak256 = { path = "../../../guest-libs/keccak256/", default-features = false } [features] default = [] diff --git a/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf b/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf index 0cf372eec3..7a267a02ab 100755 Binary files a/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf and b/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf differ diff --git a/benchmarks/guest/keccak256_iter/src/main.rs b/benchmarks/guest/keccak256_iter/src/main.rs index ef36ff1d64..554179819a 100644 --- a/benchmarks/guest/keccak256_iter/src/main.rs +++ b/benchmarks/guest/keccak256_iter/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_keccak256::keccak256; -const ITERATIONS: usize = 10_000; +const ITERATIONS: usize = 65_000; pub fn main() { // Initialize with hash of an empty vector diff --git a/benchmarks/guest/kitchen-sink/elf/openvm-kitchen-sink-program.elf b/benchmarks/guest/kitchen-sink/elf/openvm-kitchen-sink-program.elf index 85f3509fa5..fb59df5d0a 100755 Binary files a/benchmarks/guest/kitchen-sink/elf/openvm-kitchen-sink-program.elf and b/benchmarks/guest/kitchen-sink/elf/openvm-kitchen-sink-program.elf differ diff --git a/benchmarks/guest/kitchen-sink/openvm.toml b/benchmarks/guest/kitchen-sink/openvm.toml index 2d1b307eef..e6cafcf57f 100644 --- a/benchmarks/guest/kitchen-sink/openvm.toml +++ b/benchmarks/guest/kitchen-sink/openvm.toml @@ -2,7 +2,7 @@ [app_vm_config.rv32m] [app_vm_config.io] [app_vm_config.keccak] -[app_vm_config.sha256] +[app_vm_config.sha2] [app_vm_config.bigint] [app_vm_config.modular] diff --git a/benchmarks/guest/pairing/elf/openvm-pairing-program.elf b/benchmarks/guest/pairing/elf/openvm-pairing-program.elf index bf30d5a003..69c3cd0106 100755 Binary files a/benchmarks/guest/pairing/elf/openvm-pairing-program.elf and b/benchmarks/guest/pairing/elf/openvm-pairing-program.elf differ diff --git a/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf b/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf index 54af6272d6..0e7d6e6143 100755 Binary files a/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf and b/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf differ diff --git a/benchmarks/guest/quicksort/src/main.rs b/benchmarks/guest/quicksort/src/main.rs index 30218cf40e..a6579306c7 100644 --- a/benchmarks/guest/quicksort/src/main.rs +++ b/benchmarks/guest/quicksort/src/main.rs @@ -1,7 +1,7 @@ use core::hint::black_box; use openvm as _; -const ARRAY_SIZE: usize = 1_000; +const ARRAY_SIZE: usize = 3_500; fn quicksort(arr: &mut [T]) { if arr.len() <= 1 { diff --git a/benchmarks/guest/regex/elf/openvm-regex-program.elf b/benchmarks/guest/regex/elf/openvm-regex-program.elf index 6e6074e079..05388a8223 100755 Binary files a/benchmarks/guest/regex/elf/openvm-regex-program.elf and b/benchmarks/guest/regex/elf/openvm-regex-program.elf differ diff --git a/benchmarks/guest/revm_snailtracer/elf/openvm-revm-snailtracer.elf b/benchmarks/guest/revm_snailtracer/elf/openvm-revm-snailtracer.elf index 9255290412..26e1d4c515 100755 Binary files a/benchmarks/guest/revm_snailtracer/elf/openvm-revm-snailtracer.elf and b/benchmarks/guest/revm_snailtracer/elf/openvm-revm-snailtracer.elf differ diff --git a/benchmarks/guest/revm_transfer/elf/openvm-revm-transfer.elf b/benchmarks/guest/revm_transfer/elf/openvm-revm-transfer.elf index 0aa22396e6..96f7d328e9 100755 Binary files a/benchmarks/guest/revm_transfer/elf/openvm-revm-transfer.elf and b/benchmarks/guest/revm_transfer/elf/openvm-revm-transfer.elf differ diff --git a/benchmarks/guest/rkyv/elf/openvm-rkyv-program.elf b/benchmarks/guest/rkyv/elf/openvm-rkyv-program.elf index 528106e233..f2b7f8d95d 100755 Binary files a/benchmarks/guest/rkyv/elf/openvm-rkyv-program.elf and b/benchmarks/guest/rkyv/elf/openvm-rkyv-program.elf differ diff --git a/benchmarks/guest/sha256/Cargo.toml b/benchmarks/guest/sha256/Cargo.toml index 1d5491f35a..064923f3c8 100644 --- a/benchmarks/guest/sha256/Cargo.toml +++ b/benchmarks/guest/sha256/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-sha2 = { path = "../../../guest-libs/sha2" } +openvm-sha2 = { path = "../../../guest-libs/sha2/", default-features = false } [features] default = [] diff --git a/benchmarks/guest/sha256/elf/openvm-sha256-program.elf b/benchmarks/guest/sha256/elf/openvm-sha256-program.elf index 9524e8f552..2c03e2dad6 100755 Binary files a/benchmarks/guest/sha256/elf/openvm-sha256-program.elf and b/benchmarks/guest/sha256/elf/openvm-sha256-program.elf differ diff --git a/benchmarks/guest/sha256/openvm.toml b/benchmarks/guest/sha256/openvm.toml index 656bf52414..35f92b7195 100644 --- a/benchmarks/guest/sha256/openvm.toml +++ b/benchmarks/guest/sha256/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/benchmarks/guest/sha256/src/main.rs b/benchmarks/guest/sha256/src/main.rs index 0178771d09..fc0b3fab78 100644 --- a/benchmarks/guest/sha256/src/main.rs +++ b/benchmarks/guest/sha256/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_sha2::sha256; -const INPUT_LENGTH_BYTES: usize = 100 * 1024; // 100 KB +const INPUT_LENGTH_BYTES: usize = 384 * 1024; pub fn main() { let mut input = Vec::with_capacity(INPUT_LENGTH_BYTES); diff --git a/benchmarks/guest/sha256_iter/Cargo.toml b/benchmarks/guest/sha256_iter/Cargo.toml index 8e0273858a..ae10694c67 100644 --- a/benchmarks/guest/sha256_iter/Cargo.toml +++ b/benchmarks/guest/sha256_iter/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } -openvm-sha2 = { path = "../../../guest-libs/sha2" } +openvm-sha2 = { path = "../../../guest-libs/sha2/", default-features = false } [features] default = [] diff --git a/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf b/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf index 95b469ece5..677d9a3b7a 100755 Binary files a/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf and b/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf differ diff --git a/benchmarks/guest/sha256_iter/openvm.toml b/benchmarks/guest/sha256_iter/openvm.toml index 656bf52414..35f92b7195 100644 --- a/benchmarks/guest/sha256_iter/openvm.toml +++ b/benchmarks/guest/sha256_iter/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/benchmarks/guest/sha256_iter/src/main.rs b/benchmarks/guest/sha256_iter/src/main.rs index 0b495a58a8..6ffafae9a3 100644 --- a/benchmarks/guest/sha256_iter/src/main.rs +++ b/benchmarks/guest/sha256_iter/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_sha2::sha256; -const ITERATIONS: usize = 20_000; +const ITERATIONS: usize = 150_000; pub fn main() { // Initialize with hash of an empty vector diff --git a/benchmarks/prove/Cargo.toml b/benchmarks/prove/Cargo.toml index 9e745d3d80..eabfcefcc9 100644 --- a/benchmarks/prove/Cargo.toml +++ b/benchmarks/prove/Cargo.toml @@ -10,6 +10,7 @@ license.workspace = true [dependencies] openvm-benchmarks-utils.workspace = true openvm-circuit.workspace = true +openvm-continuations.workspace = true openvm-sdk.workspace = true openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true @@ -37,6 +38,7 @@ tiny-keccak.workspace = true derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } num-bigint = { workspace = true, features = ["std", "serde"] } +rand.workspace = true serde.workspace = true tracing.workspace = true @@ -55,7 +57,7 @@ jemalloc-prof = ["openvm-sdk/jemalloc-prof"] nightly-features = ["openvm-sdk/nightly-features"] [package.metadata.cargo-shear] -ignored = ["derive_more"] +ignored = ["derive_more", "rand"] [[bin]] name = "fib_e2e" diff --git a/benchmarks/prove/src/bin/kitchen_sink.rs b/benchmarks/prove/src/bin/kitchen_sink.rs index 3102c9e3fe..0b928fad17 100644 --- a/benchmarks/prove/src/bin/kitchen_sink.rs +++ b/benchmarks/prove/src/bin/kitchen_sink.rs @@ -5,22 +5,84 @@ use eyre::Result; use num_bigint::BigUint; use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; use openvm_benchmarks_prove::util::BenchmarkCli; -use openvm_circuit::arch::{instructions::exe::VmExe, SystemConfig}; +use openvm_circuit::arch::{instructions::exe::VmExe, SingleSegmentVmExecutor, SystemConfig}; +use openvm_continuations::verifier::leaf::types::LeafVmVerifierInput; use openvm_ecc_circuit::{WeierstrassExtension, P256_CONFIG, SECP256K1_CONFIG}; +use openvm_native_circuit::{NativeConfig, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_recursion::halo2::utils::{CacheHalo2ParamsReader, DEFAULT_PARAMS_DIR}; use openvm_pairing_circuit::{PairingCurve, PairingExtension}; use openvm_pairing_guest::{ bls12_381::BLS12_381_COMPLEX_STRUCT_NAME, bn254::BN254_COMPLEX_STRUCT_NAME, }; use openvm_sdk::{ - commit::commit_app_exe, config::SdkVmConfig, prover::EvmHalo2Prover, - DefaultStaticVerifierPvHandler, Sdk, StdIn, + commit::commit_app_exe, + config::SdkVmConfig, + keygen::AppProvingKey, + prover::{vm::types::VmProvingKey, EvmHalo2Prover}, + DefaultStaticVerifierPvHandler, NonRootCommittedExe, Sdk, StdIn, SC, }; use openvm_stark_sdk::{ bench::run_with_metric_collection, config::baby_bear_poseidon2::BabyBearPoseidon2Engine, }; use openvm_transpiler::FromElf; +fn verify_native_max_trace_heights( + sdk: &Sdk, + app_pk: Arc>, + app_committed_exe: Arc, + leaf_vm_pk: Arc>, + num_children_leaf: usize, +) -> Result<()> { + let app_proof = + sdk.generate_app_proof(app_pk.clone(), app_committed_exe.clone(), StdIn::default())?; + let leaf_inputs = + LeafVmVerifierInput::chunk_continuation_vm_proof(&app_proof, num_children_leaf); + let vm_vk = leaf_vm_pk.vm_pk.get_vk(); + + leaf_inputs.iter().for_each(|leaf_input| { + let executor = { + let mut executor = SingleSegmentVmExecutor::new(leaf_vm_pk.vm_config.clone()); + executor + .set_trace_height_constraints(leaf_vm_pk.vm_pk.trace_height_constraints.clone()); + executor + }; + let max_trace_heights = executor + .execute_metered( + app_pk.leaf_committed_exe.exe.clone(), + leaf_input.write_to_stream(), + &vm_vk.num_interactions(), + ) + .expect("execute_metered failed"); + println!("max_trace_heights: {:?}", max_trace_heights); + + let actual_trace_heights = executor + .execute_and_generate( + app_pk.leaf_committed_exe.clone(), + leaf_input.write_to_stream(), + &max_trace_heights, + ) + .expect("execute_and_generate failed") + .per_air + .iter() + .map(|(_, air)| air.raw.height()) + .collect::>(); + println!("actual_trace_heights: {:?}", actual_trace_heights); + + actual_trace_heights + .iter() + .zip(NATIVE_MAX_TRACE_HEIGHTS) + .for_each(|(&actual, &expected)| { + assert!( + actual <= (expected as usize), + "Actual trace height {} exceeds expected height {}", + actual, + expected + ); + }); + }); + Ok(()) +} + fn main() -> Result<()> { let args = BenchmarkCli::parse(); @@ -32,7 +94,7 @@ fn main() -> Result<()> { .rv32m(Default::default()) .io(Default::default()) .keccak(Default::default()) - .sha256(Default::default()) + .sha2(Default::default()) .bigint(Default::default()) .modular(ModularExtension::new(vec![ BigUint::from_str("1000000000000000003").unwrap(), @@ -88,6 +150,15 @@ fn main() -> Result<()> { &DefaultStaticVerifierPvHandler, )?; + // Verify that NATIVE_MAX_TRACE_HEIGHTS remains valid + verify_native_max_trace_heights( + &sdk, + app_pk.clone(), + app_committed_exe.clone(), + full_agg_pk.agg_stark_pk.leaf_vm_pk.clone(), + args.agg_tree_config.num_children_leaf, + )?; + run_with_metric_collection("OUTPUT_PATH", || -> Result<()> { let mut prover = EvmHalo2Prover::<_, BabyBearPoseidon2Engine>::new( &halo2_params_reader, diff --git a/benchmarks/prove/src/util.rs b/benchmarks/prove/src/util.rs index b3c17ead85..c26c76f1a7 100644 --- a/benchmarks/prove/src/util.rs +++ b/benchmarks/prove/src/util.rs @@ -3,7 +3,9 @@ use std::{path::PathBuf, sync::Arc}; use clap::{command, Parser}; use eyre::Result; use openvm_benchmarks_utils::{build_elf, get_programs_dir}; -use openvm_circuit::arch::{instructions::exe::VmExe, DefaultSegmentationStrategy, VmConfig}; +use openvm_circuit::arch::{ + instructions::exe::VmExe, DefaultSegmentationStrategy, InsExecutorE1, VmConfig, +}; use openvm_native_circuit::NativeConfig; use openvm_native_compiler::conversion::CompilerOptions; use openvm_sdk::{ @@ -17,7 +19,7 @@ use openvm_sdk::{ prover::{vm::local::VmLocalProver, AppProver, LeafProvingController}, Sdk, StdIn, }; -use openvm_stark_backend::utils::metrics_span; +use openvm_stark_backend::config::Val; use openvm_stark_sdk::{ config::{ baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, @@ -165,7 +167,7 @@ impl BenchmarkCli { ) -> Result<()> where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let app_config = self.app_config(vm_config); @@ -199,22 +201,16 @@ pub fn bench_from_exe>( ) -> Result<()> where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let bench_name = bench_name.to_string(); // 1. Generate proving key from config. - let app_pk = info_span!("keygen", group = &bench_name).in_scope(|| { - metrics_span("keygen_time_ms", || { - AppProvingKey::keygen(app_config.clone()) - }) - }); + let app_pk = info_span!("keygen", group = &bench_name) + .in_scope(|| AppProvingKey::keygen(app_config.clone())); // 2. Commit to the exe by generating cached trace for program. - let committed_exe = info_span!("commit_exe", group = &bench_name).in_scope(|| { - metrics_span("commit_exe_time_ms", || { - commit_app_exe(app_config.app_fri_params.fri_params, exe) - }) - }); + let committed_exe = info_span!("commit_exe", group = &bench_name) + .in_scope(|| commit_app_exe(app_config.app_fri_params.fri_params, exe)); // 3. Executes runtime // 4. Generate trace // 5. Generate STARK proofs for each segment (segmentation is determined by `config`), with diff --git a/benchmarks/utils/src/build-elfs.rs b/benchmarks/utils/src/build-elfs.rs index 3bed7cf6fd..3ce24c7c5c 100644 --- a/benchmarks/utils/src/build-elfs.rs +++ b/benchmarks/utils/src/build-elfs.rs @@ -63,6 +63,12 @@ fn main() -> Result<()> { let programs_to_build = if cli.programs.is_empty() { available_programs } else { + for prog in &cli.programs { + if !available_programs.iter().any(|(name, _)| name == prog) { + tracing::warn!("Program '{}' not found in available programs", prog); + } + } + available_programs .into_iter() .filter(|(name, _)| cli.programs.contains(name)) @@ -70,6 +76,12 @@ fn main() -> Result<()> { }; // Filter out skipped programs + for prog in &cli.skip { + if !programs_to_build.iter().any(|(name, _)| name == prog) { + tracing::warn!("Program '{}' not found in programs to skip", prog); + } + } + let programs_to_build = programs_to_build .into_iter() .filter(|(name, _)| !cli.skip.contains(name)) diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index 08c9faefb1..953d1fbe2d 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -21,7 +21,7 @@ - [Overview](./custom-extensions/overview.md) - [Keccak](./custom-extensions/keccak.md) -- [SHA-256](./custom-extensions/sha256.md) +- [SHA-2](./custom-extensions/sha2.md) - [Big Integer](./custom-extensions/bigint.md) - [Algebra (Modular Arithmetic)](./custom-extensions/algebra.md) - [Elliptic Curve Cryptography](./custom-extensions/ecc.md) diff --git a/book/src/custom-extensions/overview.md b/book/src/custom-extensions/overview.md index 2b07a73ec4..9ccfe35f3f 100644 --- a/book/src/custom-extensions/overview.md +++ b/book/src/custom-extensions/overview.md @@ -3,7 +3,7 @@ OpenVM ships with a set of pre-built extensions maintained by the OpenVM team. Below, we highlight six of these extensions designed to accelerate common arithmetic and cryptographic operations that are notoriously expensive to execute. Some of these extensions have corresponding guest libraries which provide convenient, high-level interfaces for your guest program to interact with the extension. - [`openvm-keccak-guest`](./keccak.md) - Keccak256 hash function. See the [Keccak256 guest library](../guest-libs/keccak256.md) for usage details. -- [`openvm-sha256-guest`](./sha256.md) - SHA-256 hash function. See the [SHA-2 guest library](../guest-libs/sha2.md) for usage details. +- [`openvm-sha2-guest`](./sha2.md) - SHA-2 family of hash functions. See the [SHA-2 guest library](../guest-libs/sha2.md) for usage details. - [`openvm-bigint-guest`](./bigint.md) - Big integer arithmetic for 256-bit signed and unsigned integers. See the [ruint guest library](../guest-libs/ruint.md) for using accelerated 256-bit integer ops in rust. - [`openvm-algebra-guest`](./algebra.md) - Modular arithmetic and complex field extensions. - [`openvm-ecc-guest`](./ecc.md) - Elliptic curve cryptography. See the [k256](../guest-libs/k256.md) and [p256](../guest-libs/p256.md) guest libraries for using this extension over the respective curves. @@ -43,9 +43,7 @@ range_tuple_checker_sizes = [256, 8192] [app_vm_config.io] [app_vm_config.keccak] - -[app_vm_config.sha256] - +[app_vm_config.sha2] [app_vm_config.native] [app_vm_config.bigint] diff --git a/book/src/custom-extensions/sha256.md b/book/src/custom-extensions/sha2.md similarity index 52% rename from book/src/custom-extensions/sha256.md rename to book/src/custom-extensions/sha2.md index a4a7f46261..de845fe25f 100644 --- a/book/src/custom-extensions/sha256.md +++ b/book/src/custom-extensions/sha2.md @@ -1,8 +1,8 @@ -# SHA-256 +# SHA-2 -The SHA-256 extension guest provides a function that is meant to be linked to other external libraries. The external libraries can use this function as a hook for the SHA-256 intrinsic. This is enabled only when the target is `zkvm`. +The SHA-2 extension guest provides functions that are meant to be linked to other external libraries. The external libraries can use these functions as a hook for SHA-2 intrinsics. This is enabled only when the target is `zkvm`. We support the SHA-256, SHA-512, and SHA-384 hash functions. -- `zkvm_sha256_impl(input: *const u8, len: usize, output: *mut u8)`: This function has `C` ABI. It takes in a pointer to the input, the length of the input, and a pointer to the output buffer. +- `zkvm_shaXXX_impl(input: *const u8, len: usize, output: *mut u8)` where XXX is one of `256`, `512`, or `384`. These functions have `C` ABI. They take in a pointer to the input, the length of the input, and a pointer to the output buffer. In the external library, you can do the following: @@ -31,5 +31,5 @@ fn sha256(input: &[u8]) -> [u8; 32] { For the guest program to build successfully add the following to your `.toml` file: ```toml -[app_vm_config.sha256] +[app_vm_config.sha2] ``` diff --git a/book/src/guest-libs/sha2.md b/book/src/guest-libs/sha2.md index cd35cf2e02..1ce0f46a89 100644 --- a/book/src/guest-libs/sha2.md +++ b/book/src/guest-libs/sha2.md @@ -3,26 +3,33 @@ The OpenVM SHA-2 guest library provides access to a set of accelerated SHA-2 family hash functions. Currently, it supports the following: - SHA-256 +- SHA-512 +- SHA-384 -## SHA-256 - -Refer [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for more details on SHA-256. +Refer [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for more details on the SHA-2 family of hash functions. For SHA-256, the SHA2 guest library provides two functions for use in your guest code: - - `sha256(input: &[u8]) -> [u8; 32]`: Computes the SHA-256 hash of the input data and returns it as an array of 32 bytes. - `set_sha256(input: &[u8], output: &mut [u8; 32])`: Sets the output to the SHA-256 hash of the input data into the provided output buffer. -See the full example [here](https://github.com/openvm-org/openvm/blob/main/examples/sha256/src/main.rs). +For SHA-512, we provide: +- `sha512(input: &[u8]) -> [u8; 46]`: Computes the SHA-512 hash of the input data and returns it as an array of 64 bytes. +- `set_sha512(input: &[u8], output: &mut [u8; 64])`: Sets the output to the SHA-512 hash of the input data into the provided output buffer. + +For SHA-384, we provide: +- `sha384(input: &[u8]) -> [u8; 48]`: Computes the SHA-384 hash of the input data and returns it as an array of 48 bytes. +- `set_sha384(input: &[u8], output: &mut [u8; 48])`: Sets the output to the SHA-384 hash of the input data into the provided output buffer. + +See the full example [here](https://github.com/openvm-org/openvm/blob/feat/sha-512-new-execution/examples/sha2/src/main.rs). ### Example ```rust,no_run,noplayground -{{ #include ../../../examples/sha256/src/main.rs:imports }} -{{ #include ../../../examples/sha256/src/main.rs:main }} +{{ #include ../../../examples/sha2/src/main.rs:imports }} +{{ #include ../../../examples/sha2/src/main.rs:main }} ``` -To be able to import the `sha256` function, add the following to your `Cargo.toml` file: +To be able to import the `shaXXX` functions and run the example, add the following to your `Cargo.toml` file: ```toml openvm-sha2 = { git = "https://github.com/openvm-org/openvm.git" } @@ -34,4 +41,4 @@ hex = { version = "0.4.3" } For the guest program to build successfully add the following to your `.toml` file: ```toml -[app_vm_config.sha256] \ No newline at end of file +[app_vm_config.sha2] \ No newline at end of file diff --git a/book/src/introduction.md b/book/src/introduction.md index ed39cbe33a..833151d66e 100644 --- a/book/src/introduction.md +++ b/book/src/introduction.md @@ -12,7 +12,7 @@ OpenVM is an open-source zero-knowledge virtual machine (zkVM) framework focused - RISC-V support via RV32IM - A native field arithmetic extension for proof recursion and aggregation - - The Keccak-256 and SHA2-256 hash functions + - The Keccak-256, SHA-256, SHA-512, and SHA-384 hash functions - Int256 arithmetic - Modular arithmetic over arbitrary fields - Elliptic curve operations, including multi-scalar multiplication and ECDSA signature verification, including for the secp256k1 and secp256r1 curves diff --git a/crates/circuits/mod-builder/Cargo.toml b/crates/circuits/mod-builder/Cargo.toml index d756db326b..cfd5434dde 100644 --- a/crates/circuits/mod-builder/Cargo.toml +++ b/crates/circuits/mod-builder/Cargo.toml @@ -23,8 +23,6 @@ num-traits.workspace = true tracing.workspace = true itertools.workspace = true -serde = { workspace = true, features = ["derive"] } -serde_with.workspace = true [dev-dependencies] openvm-circuit-primitives = { workspace = true } @@ -35,4 +33,8 @@ openvm-circuit = { workspace = true, features = ["test-utils"] } [features] default = [] parallel = ["openvm-stark-backend/parallel"] -test-utils = ["dep:halo2curves-axiom", "dep:openvm-pairing-guest"] +test-utils = [ + "dep:halo2curves-axiom", + "dep:openvm-pairing-guest", + "openvm-circuit/test-utils", +] diff --git a/crates/circuits/mod-builder/src/builder.rs b/crates/circuits/mod-builder/src/builder.rs index 6e1c22a009..5d337130bb 100644 --- a/crates/circuits/mod-builder/src/builder.rs +++ b/crates/circuits/mod-builder/src/builder.rs @@ -289,6 +289,22 @@ impl FieldExpr { ret.setup_values = setup_values; ret } + + pub fn num_inputs(&self) -> usize { + self.builder.num_input + } + + pub fn num_vars(&self) -> usize { + self.builder.num_variables + } + + pub fn num_flags(&self) -> usize { + self.builder.num_flags + } + + pub fn output_indices(&self) -> &[usize] { + &self.builder.output_indices + } } impl Deref for FieldExpr { diff --git a/crates/circuits/mod-builder/src/core_chip.rs b/crates/circuits/mod-builder/src/core_chip.rs index 30e9c65dbb..a8e65d7770 100644 --- a/crates/circuits/mod-builder/src/core_chip.rs +++ b/crates/circuits/mod-builder/src/core_chip.rs @@ -1,28 +1,35 @@ +use std::{ + marker::PhantomData, + mem::{align_of, size_of}, +}; + use itertools::Itertools; use num_bigint::BigUint; use num_traits::Zero; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, MinimalInstruction, - Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + get_record_from_slice, AdapterAirContext, AdapterCoreLayout, AdapterCoreMetadata, + AdapterTraceFiller, AdapterTraceStep, CustomBorrow, DynAdapterInterface, DynArray, + MinimalInstruction, RecordArena, Result, SizedRecord, TraceFiller, TraceStep, + VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ var_range::SharedVariableRangeCheckerChip, SubAir, TraceSubRowGenerator, }; -use openvm_instructions::instruction::Instruction; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, rap::BaseAirWithPublicValues, }; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use serde::{Deserialize, Serialize}; -use serde_with::{serde_as, DisplayFromStr}; use crate::{ - utils::{biguint_to_limbs_vec, limbs_to_biguint}, - FieldExpr, FieldExprCols, + builder::{FieldExpr, FieldExprCols}, + utils::biguint_to_limbs_vec, }; #[derive(Clone)] @@ -165,27 +172,125 @@ where } } -#[serde_as] -#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] -pub struct FieldExpressionRecord { - #[serde_as(as = "Vec")] - pub inputs: Vec, - pub flags: Vec, +pub struct FieldExpressionMetadata { + pub total_input_limbs: usize, // num_inputs * limbs_per_input + _phantom: PhantomData<(F, A)>, } -pub struct FieldExpressionCoreChip { - pub air: FieldExpressionCoreAir, - pub range_checker: SharedVariableRangeCheckerChip, +impl Clone for FieldExpressionMetadata { + fn clone(&self) -> Self { + Self { + total_input_limbs: self.total_input_limbs, + _phantom: PhantomData, + } + } +} - pub name: String, +impl Default for FieldExpressionMetadata { + fn default() -> Self { + Self { + total_input_limbs: 0, + _phantom: PhantomData, + } + } +} + +impl FieldExpressionMetadata { + pub fn new(total_input_limbs: usize) -> Self { + Self { + total_input_limbs, + _phantom: PhantomData, + } + } +} + +impl AdapterCoreMetadata for FieldExpressionMetadata +where + A: AdapterTraceStep, +{ + #[inline(always)] + fn get_adapter_width() -> usize { + A::WIDTH * size_of::() + } +} + +pub type FieldExpressionRecordLayout = AdapterCoreLayout>; + +pub struct FieldExpressionCoreRecordMut<'a> { + pub opcode: &'a mut u8, + pub input_limbs: &'a mut [u8], +} + +impl<'a, F, A> CustomBorrow<'a, FieldExpressionCoreRecordMut<'a>, FieldExpressionRecordLayout> + for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: FieldExpressionRecordLayout, + ) -> FieldExpressionCoreRecordMut<'a> { + let (opcode_buf, input_limbs_buff) = unsafe { self.split_at_mut_unchecked(1) }; + + FieldExpressionCoreRecordMut { + opcode: &mut opcode_buf[0], + input_limbs: &mut input_limbs_buff[..layout.metadata.total_input_limbs], + } + } + + unsafe fn extract_layout(&self) -> FieldExpressionRecordLayout { + panic!("Should get the Layout information from FieldExpressionStep"); + } +} + +impl SizedRecord> for FieldExpressionCoreRecordMut<'_> { + fn size(layout: &FieldExpressionRecordLayout) -> usize { + layout.metadata.total_input_limbs + 1 + } - /// Whether to finalize the trace. True if all-zero rows don't satisfy the constraints (e.g. - /// there is int_add) + fn alignment(_layout: &FieldExpressionRecordLayout) -> usize { + align_of::() + } +} + +impl<'a> FieldExpressionCoreRecordMut<'a> { + // This method is only used in testing + pub fn new_from_execution_data( + buffer: &'a mut [u8], + inputs: &[BigUint], + limbs_per_input: usize, + ) -> Self { + let record_info = FieldExpressionMetadata::<(), ()>::new(inputs.len() * limbs_per_input); + + let record: Self = buffer.custom_borrow(FieldExpressionRecordLayout { + metadata: record_info, + }); + record + } + + #[inline(always)] + pub fn fill_from_execution_data(&mut self, opcode: u8, data: &[u8]) { + // Rust will assert that length of `data` and `self.input_limbs` are the same + // That is `data.len() == num_inputs * limbs_per_input` + *self.opcode = opcode; + self.input_limbs.copy_from_slice(data); + } +} + +// TODO(arayi): use lifetimes and references for fields +pub struct FieldExpressionStep { + adapter: A, + pub expr: FieldExpr, + pub offset: usize, + pub local_opcode_idx: Vec, + pub opcode_flag_idx: Vec, + pub range_checker: SharedVariableRangeCheckerChip, + pub name: String, pub should_finalize: bool, } -impl FieldExpressionCoreChip { +impl FieldExpressionStep { + #[allow(clippy::too_many_arguments)] pub fn new( + adapter: A, expr: FieldExpr, offset: usize, local_opcode_idx: Vec, @@ -194,145 +299,242 @@ impl FieldExpressionCoreChip { name: &str, should_finalize: bool, ) -> Self { - let air = FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx); + let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() { + // single op chip that needs setup, so there is only one default flag, must be 0. + vec![0] + } else { + // multi ops chip or no-setup chip, use as is. + opcode_flag_idx + }; + assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1); tracing::info!( - "FieldExpressionCoreChip: opcode={name}, main_width={}", - BaseAir::::width(&air) + "FieldExpressionCoreStep: opcode={name}, main_width={}", + BaseAir::::width(&expr) ); Self { - air, + adapter, + expr, + offset, + local_opcode_idx, + opcode_flag_idx, range_checker, name: name.to_string(), should_finalize, } } + pub fn num_inputs(&self) -> usize { + self.expr.builder.num_input + } - pub fn expr(&self) -> &FieldExpr { - &self.air.expr + pub fn num_vars(&self) -> usize { + self.expr.builder.num_variables + } + + pub fn num_flags(&self) -> usize { + self.expr.builder.num_flags + } + + pub fn output_indices(&self) -> &[usize] { + &self.expr.builder.output_indices + } + pub fn get_record_layout(&self) -> FieldExpressionRecordLayout { + FieldExpressionRecordLayout { + metadata: FieldExpressionMetadata::new( + self.num_inputs() * self.expr.canonical_num_limbs(), + ), + } } } -impl VmCoreChip for FieldExpressionCoreChip +impl TraceStep for FieldExpressionStep where - I: VmAdapterInterface, - I::Reads: Into>, - AdapterRuntimeContext: From>>, + F: PrimeField32, + A: 'static + + AdapterTraceStep>, WriteData: From>>, { - type Record = FieldExpressionRecord; - type Air = FieldExpressionCoreAir; + type RecordLayout = FieldExpressionRecordLayout; + type RecordMut<'a> = (A::RecordMut<'a>, FieldExpressionCoreRecordMut<'a>); - fn execute_instruction( - &self, + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let field_element_limbs = self.air.expr.canonical_num_limbs(); - let limb_bits = self.air.expr.canonical_limb_bits(); - let data: DynArray<_> = reads.into(); - let data = data.0; - assert_eq!(data.len(), self.air.num_inputs() * field_element_limbs); - let data_u32: Vec = data.iter().map(|x| x.as_canonical_u32()).collect(); - - let mut inputs = vec![]; - for i in 0..self.air.num_inputs() { - let start = i * field_element_limbs; - let end = start + field_element_limbs; - let limb_slice = &data_u32[start..end]; - let input = limbs_to_biguint(limb_slice, limb_bits); - inputs.push(input); - } - - let Instruction { opcode, .. } = instruction; - let local_opcode_idx = opcode.local_opcode_idx(self.air.offset); - let mut flags = vec![]; - - // If the chip doesn't need setup, (right now) it must be single op chip and thus no flag is - // needed. Otherwise, there is a flag for each opcode and will be derived by - // is_valid - sum(flags). - if self.expr().needs_setup() { - flags = vec![false; self.air.num_flags()]; - self.air - .opcode_flag_idx - .iter() - .enumerate() - .for_each(|(i, &flag_idx)| { - flags[flag_idx] = local_opcode_idx == self.air.local_opcode_idx[i] - }); - } + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let (mut adapter_record, mut core_record) = arena.alloc(self.get_record_layout()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let data: DynArray<_> = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.fill_from_execution_data( + instruction.opcode.local_opcode_idx(self.offset) as u8, + &data.0, + ); - let vars = self.air.expr.execute(inputs.clone(), flags.clone()); - assert_eq!(vars.len(), self.air.num_vars()); + let (writes, _, _) = + run_field_expression(self, core_record.input_limbs, *core_record.opcode as usize); - let outputs: Vec = self - .air - .output_indices() - .iter() - .map(|&i| vars[i].clone()) - .collect(); - let writes: Vec = outputs - .iter() - .map(|x| biguint_to_limbs_vec(x.clone(), limb_bits, field_element_limbs)) - .concat() - .into_iter() - .map(|x| F::from_canonical_u32(x)) - .collect(); + self.adapter.write( + state.memory, + instruction, + writes.into(), + &mut adapter_record, + ); - let ctx = AdapterRuntimeContext::<_, DynAdapterInterface<_>>::without_pc(writes); - Ok((ctx.into(), FieldExpressionRecord { inputs, flags })) + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) } fn get_opcode_name(&self, _opcode: usize) -> String { self.name.clone() } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - self.air.expr.generate_subrow( - (self.range_checker.as_ref(), record.inputs, record.flags), - row_slice, - ); - } +impl TraceFiller for FieldExpressionStep +where + F: PrimeField32 + Send + Sync + Clone, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + // Get the core record from the row slice + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: FieldExpressionCoreRecordMut = + unsafe { get_record_from_slice(&mut core_row, self.get_record_layout::()) }; + + let (_, inputs, flags) = + run_field_expression(self, record.input_limbs, *record.opcode as usize); - fn air(&self) -> &Self::Air { - &self.air + let range_checker = self.range_checker.as_ref(); + self.expr + .generate_subrow((range_checker, inputs, flags), core_row); } - fn finalize(&self, trace: &mut RowMajorMatrix, num_records: usize) { - if !self.should_finalize || num_records == 0 { + fn fill_dummy_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + if !self.should_finalize { return; } - let core_width = >::width(&self.air); - let adapter_width = trace.width() - core_width; - let dummy_row = self.generate_dummy_trace_row(adapter_width, core_width); - for row in trace.rows_mut().skip(num_records) { - row.copy_from_slice(&dummy_row); - } - } -} - -impl FieldExpressionCoreChip { - // We will be setting is_valid = 0. That forces all flags be 0 (otherwise setup will be -1). - // We generate a dummy row with all flags set to 0, then we set is_valid = 0. - fn generate_dummy_trace_row( - &self, - adapter_width: usize, - core_width: usize, - ) -> Vec { - let record = FieldExpressionRecord { - inputs: vec![BigUint::zero(); self.air.num_inputs()], - flags: vec![false; self.air.num_flags()], - }; - let mut row = vec![F::ZERO; adapter_width + core_width]; - let core_row = &mut row[adapter_width..]; + let inputs: Vec = vec![BigUint::zero(); self.num_inputs()]; + let flags: Vec = vec![false; self.num_flags()]; + let core_row = &mut row_slice[A::WIDTH..]; // We **do not** want this trace row to update the range checker // so we must create a temporary range checker let tmp_range_checker = SharedVariableRangeCheckerChip::new(self.range_checker.bus()); - self.air.expr.generate_subrow( - (tmp_range_checker.as_ref(), record.inputs, record.flags), - core_row, - ); + self.expr + .generate_subrow((tmp_range_checker.as_ref(), inputs, flags), core_row); core_row[0] = F::ZERO; // is_valid = 0 - row } } + +fn run_field_expression( + step: &FieldExpressionStep, + data: &[u8], + local_opcode_idx: usize, +) -> (DynArray, Vec, Vec) { + let field_element_limbs = step.expr.canonical_num_limbs(); + assert_eq!(data.len(), step.num_inputs() * field_element_limbs); + + let mut inputs = Vec::with_capacity(step.num_inputs()); + for i in 0..step.num_inputs() { + let start = i * field_element_limbs; + let end = start + field_element_limbs; + let limb_slice = &data[start..end]; + let input = BigUint::from_bytes_le(limb_slice); + inputs.push(input); + } + + let mut flags = vec![]; + if step.expr.needs_setup() { + flags = vec![false; step.num_flags()]; + + // Find which opcode this is in our local_opcode_idx list + + if let Some(opcode_position) = step + .local_opcode_idx + .iter() + .position(|&idx| idx == local_opcode_idx) + { + // If this is NOT the last opcode (setup), set the corresponding flag + if opcode_position < step.opcode_flag_idx.len() { + let flag_idx = step.opcode_flag_idx[opcode_position]; + flags[flag_idx] = true; + } + // If opcode_position == step.opcode_flag_idx.len(), it's the setup operation + // and all flags should remain false (which they already are) + } + } + + let vars = step.expr.execute(inputs.clone(), flags.clone()); + assert_eq!(vars.len(), step.num_vars()); + + let outputs: Vec = step + .output_indices() + .iter() + .map(|&i| vars[i].clone()) + .collect(); + let writes: DynArray<_> = outputs + .iter() + .map(|x| biguint_to_limbs_vec(x, field_element_limbs)) + .concat() + .into_iter() + .collect::>() + .into(); + + (writes, inputs, flags) +} + +#[inline(always)] +pub fn run_field_expression_precomputed( + expr: &FieldExpr, + flag_idx: usize, + data: &[u8], +) -> DynArray { + let field_element_limbs = expr.canonical_num_limbs(); + assert_eq!(data.len(), expr.num_inputs() * field_element_limbs); + + let mut inputs = Vec::with_capacity(expr.num_inputs()); + for i in 0..expr.num_inputs() { + let start = i * expr.canonical_num_limbs(); + let end = start + expr.canonical_num_limbs(); + let limb_slice = &data[start..end]; + let input = BigUint::from_bytes_le(limb_slice); + inputs.push(input); + } + + let flags = if NEEDS_SETUP { + let mut flags = vec![false; expr.num_flags()]; + if flag_idx < expr.num_flags() { + flags[flag_idx] = true; + } + flags + } else { + vec![] + }; + + let vars = expr.execute(inputs, flags); + assert_eq!(vars.len(), expr.num_vars()); + + let outputs: Vec = expr + .output_indices() + .iter() + .map(|&i| vars[i].clone()) + .collect(); + + outputs + .iter() + .map(|x| biguint_to_limbs_vec(x, field_element_limbs)) + .concat() + .into_iter() + .collect::>() + .into() +} diff --git a/crates/circuits/mod-builder/src/tests.rs b/crates/circuits/mod-builder/src/tests.rs index d217c0c5c2..628043256d 100644 --- a/crates/circuits/mod-builder/src/tests.rs +++ b/crates/circuits/mod-builder/src/tests.rs @@ -11,14 +11,126 @@ use openvm_stark_sdk::{ p3_baby_bear::BabyBear, }; -use crate::{test_utils::*, ExprBuilder, FieldExpr, FieldExprCols, FieldVariable, SymbolicExpr}; +use crate::{ + test_utils::*, utils::biguint_to_limbs_vec, ExprBuilder, FieldExpr, FieldExprCols, + FieldExpressionCoreRecordMut, FieldVariable, SymbolicExpr, +}; const LIMB_BITS: usize = 8; +use std::sync::Arc; + +use openvm_circuit_primitives::var_range::VariableRangeCheckerChip; + +fn create_field_expr_with_setup( + builder: ExprBuilder, +) -> (FieldExpr, Arc, usize) { + let prime = secp256k1_coord_prime(); + let (range_checker, _) = setup(&prime); + let expr = FieldExpr::new(builder, range_checker.bus(), false); + let width = BaseAir::::width(&expr); + (expr, range_checker, width) +} + +fn create_field_expr_with_flags_setup( + builder: ExprBuilder, +) -> (FieldExpr, Arc, usize) { + let prime = secp256k1_coord_prime(); + let (range_checker, _) = setup(&prime); + let expr = FieldExpr::new(builder, range_checker.bus(), true); + let width = BaseAir::::width(&expr); + (expr, range_checker, width) +} + +fn generate_direct_trace( + expr: &FieldExpr, + range_checker: &Arc, + inputs: Vec, + flags: Vec, + width: usize, +) -> Vec { + let mut row = BabyBear::zero_vec(width); + expr.generate_subrow((range_checker, inputs, flags), &mut row); + row +} + +fn generate_recorded_trace( + expr: &FieldExpr, + range_checker: &Arc, + inputs: &[BigUint], + flags: Vec, + width: usize, +) -> Vec { + let mut buffer = vec![0u8; 1024]; + let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( + &mut buffer, + inputs, + expr.canonical_num_limbs(), + ); + let data: Vec = inputs + .iter() + .flat_map(|x| biguint_to_limbs_vec(x, expr.canonical_num_limbs())) + .collect(); + record.fill_from_execution_data(0, &data); + + let reconstructed_inputs: Vec = record + .input_limbs + .chunks(expr.canonical_num_limbs()) + .map(BigUint::from_bytes_le) + .collect(); + + let mut row = BabyBear::zero_vec(width); + expr.generate_subrow((range_checker, reconstructed_inputs, flags), &mut row); + row +} + +fn verify_stark_with_traces( + expr: FieldExpr, + range_checker: Arc, + trace: Vec, + width: usize, +) { + let trace_matrix = RowMajorMatrix::new(trace, width); + let range_trace = range_checker.generate_trace(); + BabyBearBlake3Engine::run_simple_test_no_pis_fast( + any_rap_arc_vec![expr, range_checker.air], + vec![trace_matrix, range_trace], + ) + .expect("Verification failed"); +} + +fn extract_and_verify_result( + expr: &FieldExpr, + trace: &[BabyBear], + expected: &BigUint, + var_index: usize, +) { + let FieldExprCols { vars, .. } = expr.load_vars(trace); + assert!(var_index < vars.len(), "Variable index out of bounds"); + let generated = evaluate_biguint(&vars[var_index], LIMB_BITS); + assert_eq!(generated, *expected); +} + +fn test_trace_equivalence( + expr: &FieldExpr, + range_checker: &Arc, + inputs: Vec, + flags: Vec, + width: usize, +) { + let direct_trace = + generate_direct_trace(expr, range_checker, inputs.clone(), flags.clone(), width); + let recorded_trace = generate_recorded_trace(expr, range_checker, &inputs, flags, width); + assert_eq!( + direct_trace, recorded_trace, + "Direct and recorded traces must be identical for inputs: {:?}", + inputs + ); +} #[test] fn test_add() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let x1 = ExprBuilder::new_input(builder.clone()); let x2 = ExprBuilder::new_input(builder.clone()); @@ -26,70 +138,45 @@ fn test_add() { x3.save(); let builder = builder.borrow().clone(); - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x + &y) % prime; + let expected = (&x + &y) % ′ let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); - assert_eq!(vars.len(), 1); - let generated = evaluate_biguint(&vars[0], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_div() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let x1 = ExprBuilder::new_input(builder.clone()); let x2 = ExprBuilder::new_input(builder.clone()); let _x3 = x1 / x2; // auto save on division. let builder = builder.borrow().clone(); - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); + + let (expr, range_checker, width) = create_field_expr_with_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); let y_inv = y.modinv(&prime).unwrap(); - let expected = (&x * &y_inv) % prime; + let expected = (&x * &y_inv) % ′ let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); - assert_eq!(vars.len(), 1); - let generated = evaluate_biguint(&vars[0], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_auto_carry_mul() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let mut x1 = ExprBuilder::new_input(builder.clone()); let mut x2 = ExprBuilder::new_input(builder.clone()); @@ -101,36 +188,25 @@ fn test_auto_carry_mul() { assert_eq!(x4.expr, SymbolicExpr::Var(1)); let builder = builder.borrow().clone(); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x * &x * &y) % prime; // x4 = x3 * x1 = (x1 * x2) * x1 + let expected = (&x * &x * &y) % ′ // x4 = x3 * x1 = (x1 * x2) * x1 let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + let FieldExprCols { vars, .. } = expr.load_vars(&trace); assert_eq!(vars.len(), 2); - let generated = evaluate_biguint(&vars[1], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + extract_and_verify_result(&expr, &trace, &expected, 1); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_auto_carry_intmul() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); - let mut x1 = ExprBuilder::new_input(builder.clone()); + let (_, builder) = setup(&prime); + let mut x1: FieldVariable = ExprBuilder::new_input(builder.clone()); let mut x2 = ExprBuilder::new_input(builder.clone()); let mut x3 = &mut x1 * &mut x2; // The int_mul below will overflow: @@ -143,35 +219,24 @@ fn test_auto_carry_intmul() { assert_eq!(x4.expr, SymbolicExpr::Var(1)); let builder = builder.borrow().clone(); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x * &x * BigUint::from(9u32)) % prime; + let expected = (&x * &x * BigUint::from(9u32)) % ′ let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + let FieldExprCols { vars, .. } = expr.load_vars(&trace); assert_eq!(vars.len(), 2); - let generated = evaluate_biguint(&vars[1], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + extract_and_verify_result(&expr, &trace, &expected, 1); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_auto_carry_add() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let mut x1 = ExprBuilder::new_input(builder.clone()); let mut x2 = ExprBuilder::new_input(builder.clone()); @@ -194,36 +259,24 @@ fn test_auto_carry_add() { assert_eq!(x5.expr, SymbolicExpr::Var(1)); let builder = builder.borrow().clone(); - - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x * &x * BigUint::from(10u32)) % prime; + let expected = (&x * &x * BigUint::from(10u32)) % ′ let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + let FieldExprCols { vars, .. } = expr.load_vars(&trace); assert_eq!(vars.len(), 2); - let generated = evaluate_biguint(&vars[x5_id], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + extract_and_verify_result(&expr, &trace, &expected, x5_id); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_auto_carry_div() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let mut x1 = ExprBuilder::new_input(builder.clone()); let x2 = ExprBuilder::new_input(builder.clone()); @@ -237,29 +290,16 @@ fn test_auto_carry_div() { let builder = builder.borrow().clone(); assert_eq!(builder.num_variables, 2); // numerator autosaved, and the final division - let expr = FieldExpr::new(builder, range_checker.bus(), false); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - // let expected = (&x * &x * BigUint::from(10u32)) % prime; let inputs = vec![x, y]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, vec![]), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); + let trace = generate_direct_trace(&expr, &range_checker, inputs, vec![], width); + let FieldExprCols { vars, .. } = expr.load_vars(&trace); assert_eq!(vars.len(), 2); - // let generated = evaluate_biguint(&vars[x5_id], LIMB_BITS); - // assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + verify_stark_with_traces(expr, range_checker, trace, width); } fn make_addsub_chip(builder: Rc>) -> ExprBuilder { @@ -283,65 +323,39 @@ fn make_addsub_chip(builder: Rc>) -> ExprBuilder { #[test] fn test_select() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let builder = make_addsub_chip(builder); - let expr = FieldExpr::new(builder, range_checker.bus(), true); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_flags_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x + &prime - &y) % prime; + let expected = (&x + &prime - &y) % ′ let inputs = vec![x, y]; - let flags = vec![false, true]; + let flags: Vec = vec![false, true]; - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, flags), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); - assert_eq!(vars.len(), 1); - let generated = evaluate_biguint(&vars[0], LIMB_BITS); - assert_eq!(generated, expected); - - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + let trace = generate_direct_trace(&expr, &range_checker, inputs, flags, width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); } #[test] fn test_select2() { let prime = secp256k1_coord_prime(); - let (range_checker, builder) = setup(&prime); + let (_, builder) = setup(&prime); let builder = make_addsub_chip(builder); - let expr = FieldExpr::new(builder, range_checker.bus(), true); - let width = BaseAir::::width(&expr); + let (expr, range_checker, width) = create_field_expr_with_flags_setup(builder); let x = generate_random_biguint(&prime); let y = generate_random_biguint(&prime); - let expected = (&x + &y) % prime; + let expected = (&x + &y) % ′ let inputs = vec![x, y]; - let flags = vec![true, false]; - - let mut row = BabyBear::zero_vec(width); - expr.generate_subrow((&range_checker, inputs, flags), &mut row); - let FieldExprCols { vars, .. } = expr.load_vars(&row); - assert_eq!(vars.len(), 1); - let generated = evaluate_biguint(&vars[0], LIMB_BITS); - assert_eq!(generated, expected); + let flags: Vec = vec![true, false]; - let trace = RowMajorMatrix::new(row, width); - let range_trace = range_checker.generate_trace(); - - BabyBearBlake3Engine::run_simple_test_no_pis_fast( - any_rap_arc_vec![expr, range_checker.air], - vec![trace, range_trace], - ) - .expect("Verification failed"); + let trace = generate_direct_trace(&expr, &range_checker, inputs, flags, width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); } fn test_symbolic_limbs(expr: SymbolicExpr, expected_q: usize, expected_carry: usize) { @@ -395,3 +409,299 @@ fn test_symbolic_limbs_mul() { let expected_carry = 64; test_symbolic_limbs(expr, expected_q, expected_carry); } + +#[test] +fn test_recorded_execution_records() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let mut x3 = x1 + x2; + x3.save(); + let builder = builder.borrow().clone(); + + let (expr, range_checker, width) = create_field_expr_with_setup(builder); + + let x = generate_random_biguint(&prime); + let y = generate_random_biguint(&prime); + let expected = (&x + &y) % ′ + let inputs = vec![x.clone(), y.clone()]; + let flags: Vec = vec![]; + + // Test record creation and reconstruction + let mut buffer = vec![0u8; 1024]; + let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( + &mut buffer, + &inputs, + expr.canonical_num_limbs(), + ); + let data: Vec = inputs + .iter() + .flat_map(|x| biguint_to_limbs_vec(x, expr.canonical_num_limbs())) + .collect(); + record.fill_from_execution_data(0, &data); + assert_eq!(*record.opcode, 0); + + // Verify input reconstruction preserves data + let reconstructed_inputs: Vec = record + .input_limbs + .chunks(expr.canonical_num_limbs()) + .map(BigUint::from_bytes_le) + .collect(); + assert_eq!(reconstructed_inputs.len(), inputs.len()); + for (original, reconstructed) in inputs.iter().zip(reconstructed_inputs.iter()) { + assert_eq!(original, reconstructed); + } + + // Test standard execution and verification using reconstructed inputs + let trace = generate_direct_trace(&expr, &range_checker, reconstructed_inputs, flags, width); + extract_and_verify_result(&expr, &trace, &expected, 0); + verify_stark_with_traces(expr, range_checker, trace, width); +} + +#[test] +fn test_trace_mathematical_equivalence() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let x3 = &mut (x1.clone() * x2.clone()) + &mut (x1.clone().square()); + let mut x4 = x3.clone() / x2.clone(); // This will trigger auto-save + x4.save(); + let builder = builder.borrow().clone(); + + let (expr, range_checker, width) = create_field_expr_with_setup(builder); + + for _ in 0..10 { + let x = generate_random_biguint(&prime); + let y = generate_random_biguint(&prime); + + let expected = { + let temp = (&x * &y + &x * &x) % ′ + let y_inv = y.modinv(&prime).unwrap(); + (temp * y_inv) % &prime + }; + + let inputs = vec![x.clone(), y.clone()]; + let flags: Vec = vec![]; + + // Test direct/recorded equivalence + test_trace_equivalence(&expr, &range_checker, inputs.clone(), flags.clone(), width); + + // Verify the actual computation is correct + let direct_row = generate_direct_trace(&expr, &range_checker, inputs.clone(), flags, width); + let FieldExprCols { vars, .. } = expr.load_vars(&direct_row); + extract_and_verify_result(&expr, &direct_row, &expected, vars.len() - 1); + } +} + +#[test] +fn test_record_arena_allocation_patterns() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let mut x3 = x1 + x2; + x3.save(); + let builder = builder.borrow().clone(); + + let (expr, _range_checker, _width) = create_field_expr_with_setup(builder); + + let inputs = vec![ + generate_random_biguint(&prime), + generate_random_biguint(&prime), + ]; + + // Test record creation with various input sizes + let mut buffer = vec![0u8; 1024]; + let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( + &mut buffer, + &inputs, + expr.canonical_num_limbs(), + ); + let data: Vec = inputs + .iter() + .flat_map(|x| biguint_to_limbs_vec(x, expr.canonical_num_limbs())) + .collect(); + record.fill_from_execution_data(0, &data); + assert_eq!(*record.opcode, 0); + + // Test with maximum inputs + let max_inputs = vec![BigUint::one(); 40]; // MAX_INPUT_LIMBS / 4 + let mut max_buffer = vec![0u8; 2048]; + let max_record = + FieldExpressionCoreRecordMut::new_from_execution_data(&mut max_buffer, &max_inputs, 4); + assert_eq!(*max_record.opcode, 0); + + // Test input reconstruction + let reconstructed_inputs: Vec = record + .input_limbs + .chunks(expr.canonical_num_limbs()) + .map(BigUint::from_bytes_le) + .collect(); + assert_eq!(reconstructed_inputs.len(), inputs.len()); + for (original, reconstructed) in inputs.iter().zip(reconstructed_inputs.iter()) { + assert_eq!(original, reconstructed); + } +} + +#[test] +fn test_tracestep_tracefiller_roundtrip() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let x3 = x1.clone() * x2.clone(); + let x4 = x3.clone() + x1.clone(); + let mut x5 = x4.clone(); + x5.save(); + let builder_data = builder.borrow().clone(); + + let (expr, _range_checker, _width) = create_field_expr_with_setup(builder_data); + + let inputs = vec![ + generate_random_biguint(&prime), + generate_random_biguint(&prime), + ]; + + let vars_direct = expr.execute(inputs.clone(), vec![]); + + // Test record creation and reconstruction roundtrip + let mut buffer = vec![0u8; 1024]; + let mut record = FieldExpressionCoreRecordMut::new_from_execution_data( + &mut buffer, + &inputs, + expr.canonical_num_limbs(), + ); + let data: Vec = inputs + .iter() + .flat_map(|x| biguint_to_limbs_vec(x, expr.canonical_num_limbs())) + .collect(); + record.fill_from_execution_data(0, &data); + + let reconstructed_inputs: Vec = record + .input_limbs + .chunks(expr.canonical_num_limbs()) + .map(BigUint::from_bytes_le) + .collect(); + let vars_reconstructed = expr.execute(reconstructed_inputs, vec![]); + + // All intermediate variables must be preserved + assert_eq!(vars_direct.len(), vars_reconstructed.len()); + for (direct, reconstructed) in vars_direct.iter().zip(vars_reconstructed.iter()) { + assert_eq!( + direct, reconstructed, + "Variable preservation failed in roundtrip" + ); + } +} + +#[test] +fn test_direct_recorded_with_complex_operations() { + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let x3 = ExprBuilder::new_input(builder.clone()); + + let numerator = x1.clone() * x2.clone() + x3.clone(); + let denominator = x1.clone() + x2.clone(); + let mut result = numerator / denominator; + result.save(); + + let builder_data = builder.borrow().clone(); + let (expr, range_checker, width) = create_field_expr_with_setup(builder_data); + + // Test edge cases with small and large numbers + let test_cases = vec![ + ( + BigUint::from(1u32), + BigUint::from(2u32), + BigUint::from(3u32), + ), + ( + BigUint::from(100u32), + BigUint::from(200u32), + BigUint::from(300u32), + ), + ( + generate_random_biguint(&prime), + generate_random_biguint(&prime), + generate_random_biguint(&prime), + ), + ]; + + for (x, y, z) in test_cases { + let inputs = vec![x.clone(), y.clone(), z.clone()]; + let flags = vec![]; + + // Test direct/recorded equivalence + test_trace_equivalence(&expr, &range_checker, inputs.clone(), flags.clone(), width); + + // Verify mathematical correctness + let expected = { + let num = (&x * &y + &z) % ′ + let den_inv = (&x + &y).modinv(&prime).unwrap(); + (num * den_inv) % &prime + }; + + let direct_row = generate_direct_trace(&expr, &range_checker, inputs, flags, width); + let FieldExprCols { vars, .. } = expr.load_vars(&direct_row); + extract_and_verify_result(&expr, &direct_row, &expected, vars.len() - 1); + } +} + +#[test] +fn test_concurrent_direct_recorded_simulation() { + // Simulate mixed direct/recorded execution to ensure RecordArena abstraction works correctly + let prime = secp256k1_coord_prime(); + let (_, builder) = setup(&prime); + + let x1 = ExprBuilder::new_input(builder.clone()); + let x2 = ExprBuilder::new_input(builder.clone()); + let mut x3 = x1 + x2; + x3.save(); + let builder_data = builder.borrow().clone(); + + let (expr, range_checker, width) = create_field_expr_with_setup(builder_data); + + // Simulate multiple "concurrent" executions with different modes + let execution_scenarios = vec![ + ("direct", true), + ("recorded", false), + ("direct", true), + ("recorded", false), + ]; + + let mut all_traces = Vec::new(); + + for (name, is_direct) in execution_scenarios { + let inputs = vec![ + generate_random_biguint(&prime), + generate_random_biguint(&prime), + ]; + + let trace = if is_direct { + generate_direct_trace(&expr, &range_checker, inputs.clone(), vec![], width) + } else { + generate_recorded_trace(&expr, &range_checker, &inputs, vec![], width) + }; + + all_traces.push((name, inputs, trace)); + } + + // Verify each trace is mathematically valid + for (_, inputs, trace) in &all_traces { + let expected = (&inputs[0] + &inputs[1]) % ′ + extract_and_verify_result(&expr, trace, &expected, 0); + } + + // Verify that direct and recorded with same inputs produce same results + let same_inputs = vec![BigUint::from(123u32), BigUint::from(456u32)]; + test_trace_equivalence(&expr, &range_checker, same_inputs, vec![], width); +} diff --git a/crates/circuits/mod-builder/src/utils.rs b/crates/circuits/mod-builder/src/utils.rs index 7540f0ae2c..2f2561ba87 100644 --- a/crates/circuits/mod-builder/src/utils.rs +++ b/crates/circuits/mod-builder/src/utils.rs @@ -1,27 +1,14 @@ use num_bigint::BigUint; -use num_traits::{FromPrimitive, ToPrimitive, Zero}; - -// little endian. -pub fn limbs_to_biguint(x: &[u32], limb_size: usize) -> BigUint { - let mut result = BigUint::zero(); - let base = BigUint::from_u32(1 << limb_size).unwrap(); - for limb in x.iter().rev() { - result = result * &base + BigUint::from_u32(*limb).unwrap(); - } - result -} // Use this when num_limbs is not a constant. // little endian. -// Warning: This function only returns the last NUM_LIMBS*LIMB_SIZE bits of +// Warning: This function only returns the last NUM_LIMBS bytes of // the input, while the input can have more than that. -pub fn biguint_to_limbs_vec(mut x: BigUint, limb_size: usize, num_limbs: usize) -> Vec { - let mut result = vec![0; num_limbs]; - let base = BigUint::from_u32(1 << limb_size).unwrap(); - for r in result.iter_mut() { - *r = (x.clone() % &base).to_u32().unwrap(); - x /= &base; - } - assert!(x.is_zero()); - result +#[inline(always)] +pub fn biguint_to_limbs_vec(x: &BigUint, num_limbs: usize) -> Vec { + x.to_bytes_le() + .into_iter() + .chain(std::iter::repeat(0u8)) + .take(num_limbs) + .collect() } diff --git a/crates/circuits/primitives/derive/Cargo.toml b/crates/circuits/primitives/derive/Cargo.toml index 06d4c00aed..2e91772fd5 100644 --- a/crates/circuits/primitives/derive/Cargo.toml +++ b/crates/circuits/primitives/derive/Cargo.toml @@ -12,6 +12,13 @@ license.workspace = true proc-macro = true [dependencies] -syn = { version = "2.0", features = ["parsing"] } +syn = { version = "2.0", features = ["parsing", "extra-traits"] } quote = "1.0" -itertools = { workspace = true } +itertools = { workspace = true, default-features = true } +proc-macro2 = "1.0" + +[dev-dependencies] +ndarray.workspace = true + +[package.metadata.cargo-shear] +ignored = ["ndarray"] diff --git a/crates/circuits/primitives/derive/src/cols_ref/README.md b/crates/circuits/primitives/derive/src/cols_ref/README.md new file mode 100644 index 0000000000..82812f7b90 --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/README.md @@ -0,0 +1,113 @@ +# ColsRef macro + +The `ColsRef` procedural macro is used in constraint generation to create column structs that have dynamic sizes. + +Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384). +See the [SHA-2 VM extension](../../../../../../extensions/sha2/circuit/src/sha2_chip/air.rs) for an example of how to use the `ColsRef` macro to reuse constraint generation code over multiple circuits. + +## Overview + +As an illustrative example, consider the following columns struct: +```rust +struct ExampleCols { + arr: [T; N], + sum: T, +} +``` +Let's say we want to constrain `sum` to be the sum of the elements of `arr`, and `N` can be either 5 or 10. +We can define a trait that stores the config parameters. +```rust +pub trait ExampleConfig { + const N: usize; +} +``` +and then implement it for the two different configs. +```rust +pub struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} +pub struct ExampleConfigImplB; +impl ExampleConfig for ExampleConfigImplB { + const N: usize = 10; +} +``` +Then we can use the `ColsRef` macro like this +```rust +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + sum: T, +} +``` +which will generate a columns struct that uses references to the fields. +```rust +struct ExampleColsRef<'a, T, const N: usize> { + arr: ndarray::ArrayView1<'a, T>, // an n-dimensional view into the input slice (ArrayView2 for 2D arrays, etc.) + sum: &'a T, +} +``` +The `ColsRef` macro will also generate a `from` method that takes a slice of the correct length and returns an instance of the columns struct. +The `from` method is parameterized by a struct that implements the `ExampleConfig` trait, and it uses the associated constants to determine how to split the input slice into the fields of the columns struct. + +So, the constraint generation code can be written as +```rust +impl Air for ExampleAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, _) = (main.row_slice(0), main.row_slice(1)); + let local_cols = ExampleColsRef::::from::(&local[..C::N + 1]); + let sum = local_cols.arr.iter().sum(); + builder.assert_eq(local_cols.sum, sum); + } +} +``` +Notes: +- the `arr` and `sum` fields of `ExampleColsRef` are references to the elements of the `local` slice. +- the name, `N`, of the const generic parameter must match the name of the associated constant `N` in the `ExampleConfig` trait. + +The `ColsRef` macro also generates a `ExampleColsRefMut` struct that stores mutable references to the fields, for use in trace generation. + +The `ColsRef` macro supports more than just variable-length array fields. +The field types can also be: +- any type that derives `AlignedBorrow` via `#[derive(AlignedBorrow)]` +- any type that derives `ColsRef` via `#[derive(ColsRef)]` +- (possibly nested) arrays of `T` or (possibly nested) arrays of a type that derives `AlignedBorrow` + +Note that we currently do not support arrays of types that derive `ColsRef`. + +## Specification + +Annotating a struct named `ExampleCols` with `#[derive(ColsRef)]` and `#[config(ExampleConfig)]` produces two structs, `ExampleColsRef` and `ExampleColsRefMut`. +- we assume `ExampleCols` has exactly one generic type parameter, typically named `T`, and any number of const generic parameters. Each const generic parameter must have a name that matches an associated constant in the `ExampleConfig` trait + +The fields of `ExampleColsRef` have the same names as the fields of `ExampleCols`, but their types are transformed as follows: +- type `T` becomes `&T` +- type `[T; LEN]` becomes `&ArrayView1` (see [ndarray](https://docs.rs/ndarray/latest/ndarray/index.html)) where `LEN` is an associated constant in `ExampleConfig` + - the `ExampleColsRef::from` method will correctly infer the length of the array from the config +- fields with names that end in `Cols` are assumed to be a columns struct that derives `ColsRef` and are transformed into the appropriate `ColsRef` type recursively + - one restriction is that any nested `ColsRef` type must have the same config as the outer `ColsRef` type +- fields that are annotated with `#[aligned_borrow]` are assumed to derive `AlignedBorrow` and are borrowed from the input slice. The new type is a reference to the `AlignedBorrow` type + - if a field whose name ends in `Cols` is annotated with `#[aligned_borrow]`, then the aligned borrow takes precedence, and the field is not transformed into an `ArrayView` +- nested arrays of `U` become `&ArrayViewX` where `X` is the number of dimensions in the nested array type + - `U` can be either the generic type `T` or a type that derives `AlignedBorrow`. In the latter case, the field must be annotated with `#[aligned_borrow]` + - the `ArrayViewX` type provides a `X`-dimensional view into the row slice + +The fields of `ExampleColsRefMut` are almost the same as the fields of `ExampleColsRef`, but they are mutable references. +- the `ArrayViewMutX` type is used instead of `ArrayViewX` for the array fields. +- fields that derive `ColsRef` are transformed into the appropriate `ColsRefMut` type recursively. + +Each of the `ExampleColsRef` and `ExampleColsRefMut` types has the following methods implemented: +```rust +// Takes a slice of the correct length and returns an instance of the columns struct. +pub const fn from(slice: &[T]) -> Self; +// Returns the number of cells in the struct +pub const fn width() -> usize; +``` +Note that the `width` method on both structs returns the same value. + +Additionally, the `ExampleColsRef` struct has a `from_mut` method that takes a `ExampleColsRefMut` and returns a `ExampleColsRef`. +This may be useful in trace generation to pass a `ExampleColsRefMut` to a function that expects a `ExampleColsRef`. + +See the [tests](../../tests/test_cols_ref.rs) for concrete examples of how the `ColsRef` macro handles each of the supported field types. \ No newline at end of file diff --git a/crates/circuits/primitives/derive/src/cols_ref/mod.rs b/crates/circuits/primitives/derive/src/cols_ref/mod.rs new file mode 100644 index 0000000000..63289ec5eb --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/mod.rs @@ -0,0 +1,697 @@ +extern crate proc_macro; + +use itertools::Itertools; +use quote::{format_ident, quote}; +use syn::{parse_quote, DeriveInput}; + +pub fn cols_ref_impl( + derive_input: DeriveInput, + config: proc_macro2::Ident, +) -> proc_macro2::TokenStream { + let DeriveInput { + ident, + generics, + data, + vis, + .. + } = derive_input; + + let generic_types = generics + .params + .iter() + .filter_map(|p| { + if let syn::GenericParam::Type(type_param) = p { + Some(type_param) + } else { + None + } + }) + .collect::>(); + + if generic_types.len() != 1 { + panic!("Struct must have exactly one generic type parameter"); + } + + let generic_type = generic_types[0]; + + let const_generics = generics.const_params().map(|p| &p.ident).collect_vec(); + + match data { + syn::Data::Struct(data_struct) => { + // Process the fields of the struct, transforming the types for use in ColsRef struct + let const_field_infos: Vec = data_struct + .fields + .iter() + .map(|f| get_const_cols_ref_fields(f, generic_type, &const_generics)) + .collect_vec(); + + // The ColsRef struct is named by appending `Ref` to the struct name + let const_cols_ref_name = syn::Ident::new(&format!("{}Ref", ident), ident.span()); + + // the args to the `from` method will be different for the ColsRef and ColsRefMut + // structs + let from_args = quote! { slice: &'a [#generic_type] }; + + // Package all the necessary information to generate the ColsRef struct + let struct_info = StructInfo { + name: const_cols_ref_name, + vis: vis.clone(), + generic_type: generic_type.clone(), + field_infos: const_field_infos, + fields: data_struct.fields.clone(), + from_args, + derive_clone: true, + }; + + // Generate the ColsRef struct + let const_cols_ref_struct = make_struct(struct_info.clone(), &config); + + // Generate the `from_mut` method for the ColsRef struct + let from_mut_impl = make_from_mut(struct_info, &config); + + // Process the fields of the struct, transforming the types for use in ColsRefMut struct + let mut_field_infos: Vec = data_struct + .fields + .iter() + .map(|f| get_mut_cols_ref_fields(f, generic_type, &const_generics)) + .collect_vec(); + + // The ColsRefMut struct is named by appending `RefMut` to the struct name + let mut_cols_ref_name = syn::Ident::new(&format!("{}RefMut", ident), ident.span()); + + // the args to the `from` method will be different for the ColsRef and ColsRefMut + // structs + let from_args = quote! { slice: &'a mut [#generic_type] }; + + // Package all the necessary information to generate the ColsRefMut struct + let struct_info = StructInfo { + name: mut_cols_ref_name, + vis, + generic_type: generic_type.clone(), + field_infos: mut_field_infos, + fields: data_struct.fields, + from_args, + derive_clone: false, + }; + + // Generate the ColsRefMut struct + let mut_cols_ref_struct = make_struct(struct_info, &config); + + quote! { + #const_cols_ref_struct + #from_mut_impl + #mut_cols_ref_struct + } + } + _ => panic!("ColsRef can only be derived for structs"), + } +} + +#[derive(Debug, Clone)] +struct StructInfo { + name: syn::Ident, + vis: syn::Visibility, + generic_type: syn::TypeParam, + field_infos: Vec, + fields: syn::Fields, + from_args: proc_macro2::TokenStream, + derive_clone: bool, +} + +// Generate the ColsRef and ColsRefMut structs, depending on the value of `struct_info` +// This function is meant to reduce code duplication between the code needed to generate the two +// structs Notable differences between the two structs are: +// - the types of the fields +// - ColsRef derives Clone, but ColsRefMut cannot (since it stores mutable references) +// - the `from` method parameter is a reference to a slice for ColsRef and a mutable reference to +// a slice for ColsRefMut +fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream { + let StructInfo { + name, + vis, + generic_type, + field_infos, + fields, + from_args, + derive_clone, + } = struct_info; + + let field_types = field_infos.iter().map(|f| &f.ty).collect_vec(); + let length_exprs = field_infos.iter().map(|f| &f.length_expr).collect_vec(); + let prepare_subslices = field_infos + .iter() + .map(|f| &f.prepare_subslice) + .collect_vec(); + let initializers = field_infos.iter().map(|f| &f.initializer).collect_vec(); + + let idents = fields.iter().map(|f| &f.ident).collect_vec(); + + let clone_impl = if derive_clone { + quote! { + #[derive(Clone)] + } + } else { + quote! {} + }; + + quote! { + #clone_impl + #[derive(Debug)] + #vis struct #name <'a, #generic_type> { + #( pub #idents: #field_types ),* + } + + impl<'a, #generic_type> #name<'a, #generic_type> { + pub fn from(#from_args) -> Self { + #( #prepare_subslices )* + Self { + #( #idents: #initializers ),* + } + } + + // returns number of cells in the struct (where each cell has type T) + pub const fn width() -> usize { + 0 #( + #length_exprs )* + } + } + } +} + +// Generate the `from_mut` method for the ColsRef struct +fn make_from_mut(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream { + let StructInfo { + name, + vis: _, + generic_type, + field_infos: _, + fields, + from_args: _, + derive_clone: _, + } = struct_info; + + let from_mut_impl = fields + .iter() + .map(|f| { + let ident = f.ident.clone().unwrap(); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + // calling view() on ArrayViewMut returns an ArrayView + quote! { + other.#ident.view() + } + } else if derives_aligned_borrow { + // implicitly converts a mutable reference to an immutable reference, so leave the + // field value unchanged + quote! { + other.#ident + } + } else if is_columns_struct(&f.ty) { + // lifetime 'b is used in from_mut to allow more flexible lifetime of return value + let cols_ref_type = + get_const_cols_ref_type(&f.ty, &generic_type, parse_quote! { 'b }); + // Recursively call `from_mut` on the ColsRef field + quote! { + <#cols_ref_type>::from_mut::(&other.#ident) + } + } else if is_generic_type(&f.ty, &generic_type) { + // implicitly converts a mutable reference to an immutable reference, so leave the + // field value unchanged + quote! { + &other.#ident + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } + }) + .collect_vec(); + + let field_idents = fields + .iter() + .map(|f| f.ident.clone().unwrap()) + .collect_vec(); + + let mut_struct_ident = format_ident!("{}Mut", name.to_string()); + let mut_struct_type: syn::Type = parse_quote! { + #mut_struct_ident<'a, #generic_type> + }; + + parse_quote! { + // lifetime 'b is used in from_mut to allow more flexible lifetime of return value + impl<'b, #generic_type> #name<'b, #generic_type> { + pub fn from_mut<'a, C: #config>(other: &'b #mut_struct_type) -> Self + { + Self { + #( #field_idents: #from_mut_impl ),* + } + } + } + } +} + +// Information about a field that is used to generate the ColsRef and ColsRefMut structs +// See the `make_struct` function to see how this information is used +#[derive(Debug, Clone)] +struct FieldInfo { + // type for struct definition + ty: syn::Type, + // an expr calculating the length of the field + length_expr: proc_macro2::TokenStream, + // prepare a subslice of the slice to be used in the 'from' method + prepare_subslice: proc_macro2::TokenStream, + // an expr used in the Self initializer in the 'from' method + // may refer to the subslice declared in prepare_subslice + initializer: proc_macro2::TokenStream, +} + +// Prepare the fields for the const ColsRef struct +fn get_const_cols_ref_fields( + f: &syn::Field, + generic_type: &syn::TypeParam, + const_generics: &[&syn::Ident], +) -> FieldInfo { + let length_var = format_ident!("{}_length", f.ident.clone().unwrap()); + let slice_var = format_ident!("{}_slice", f.ident.clone().unwrap()); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics); + debug_assert!( + !dims.is_empty(), + "Array field must have at least one dimension" + ); + + let ndarray_ident: syn::Ident = format_ident!("ArrayView{}", dims.len()); + let ndarray_type: syn::Type = parse_quote! { + ndarray::#ndarray_ident<'a, #elem_type> + }; + + // dimensions of the array in terms of number of cells + let dim_exprs = dims + .iter() + .map(|d| match d { + // need to prepend C:: for const generic array dimensions + Dimension::ConstGeneric(expr) => quote! { C::#expr }, + Dimension::Other(expr) => quote! { #expr }, + }) + .collect_vec(); + + if derives_aligned_borrow { + let length_expr = quote! { + <#elem_type>::width() #(* #dim_exprs)* + }; + + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(#length_expr); + let #slice_var: &[#elem_type] = unsafe { &*(#slice_var as *const [T] as *const [#elem_type]) }; + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_columns_struct(&elem_type) { + panic!("Arrays of columns structs are currently not supported"); + } else if is_generic_type(&elem_type, generic_type) { + let length_expr = quote! { + 1 #(* #dim_exprs)* + }; + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(#length_expr); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } + } else if derives_aligned_borrow { + // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) + let f_ty = &f.ty; + FieldInfo { + ty: parse_quote! { + &'a #f_ty + }, + length_expr: quote! { + <#f_ty>::width() + }, + prepare_subslice: quote! { + let #length_var = <#f_ty>::width(); + let (#slice_var, slice) = slice.split_at(#length_var); + }, + initializer: quote! { + { + use core::borrow::Borrow; + #slice_var.borrow() + } + }, + } + } else if is_columns_struct(&f.ty) { + let const_cols_ref_type = get_const_cols_ref_type(&f.ty, generic_type, parse_quote! { 'a }); + FieldInfo { + ty: parse_quote! { + #const_cols_ref_type + }, + length_expr: quote! { + <#const_cols_ref_type>::width::() + }, + prepare_subslice: quote! { + let #length_var = <#const_cols_ref_type>::width::(); + let (#slice_var, slice) = slice.split_at(#length_var); + let #slice_var = <#const_cols_ref_type>::from::(#slice_var); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_generic_type(&f.ty, generic_type) { + FieldInfo { + ty: parse_quote! { + &'a #generic_type + }, + length_expr: quote! { + 1 + }, + prepare_subslice: quote! { + let #length_var = 1; + let (#slice_var, slice) = slice.split_at(#length_var); + }, + initializer: quote! { + &#slice_var[0] + }, + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } +} + +// Prepare the fields for the mut ColsRef struct +fn get_mut_cols_ref_fields( + f: &syn::Field, + generic_type: &syn::TypeParam, + const_generics: &[&syn::Ident], +) -> FieldInfo { + let length_var = format_ident!("{}_length", f.ident.clone().unwrap()); + let slice_var = format_ident!("{}_slice", f.ident.clone().unwrap()); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics); + debug_assert!( + !dims.is_empty(), + "Array field must have at least one dimension" + ); + + let ndarray_ident: syn::Ident = format_ident!("ArrayViewMut{}", dims.len()); + let ndarray_type: syn::Type = parse_quote! { + ndarray::#ndarray_ident<'a, #elem_type> + }; + + // dimensions of the array in terms of number of cells + let dim_exprs = dims + .iter() + .map(|d| match d { + // need to prepend C:: for const generic array dimensions + Dimension::ConstGeneric(expr) => quote! { C::#expr }, + Dimension::Other(expr) => quote! { #expr }, + }) + .collect_vec(); + + if derives_aligned_borrow { + let length_expr = quote! { + <#elem_type>::width() #(* #dim_exprs)* + }; + + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut (#length_expr); + let #slice_var: &mut [#elem_type] = unsafe { &mut *(#slice_var as *mut [T] as *mut [#elem_type]) }; + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_columns_struct(&elem_type) { + panic!("Arrays of columns structs are currently not supported"); + } else if is_generic_type(&elem_type, generic_type) { + let length_expr = quote! { + 1 #(* #dim_exprs)* + }; + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut(#length_expr); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } + } else if derives_aligned_borrow { + // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) + let f_ty = &f.ty; + FieldInfo { + ty: parse_quote! { + &'a mut #f_ty + }, + length_expr: quote! { + <#f_ty>::width() + }, + prepare_subslice: quote! { + let #length_var = <#f_ty>::width(); + let (#slice_var, slice) = slice.split_at_mut(#length_var); + }, + initializer: quote! { + { + use core::borrow::BorrowMut; + #slice_var.borrow_mut() + } + }, + } + } else if is_columns_struct(&f.ty) { + let mut_cols_ref_type = get_mut_cols_ref_type(&f.ty, generic_type); + FieldInfo { + ty: parse_quote! { + #mut_cols_ref_type + }, + length_expr: quote! { + <#mut_cols_ref_type>::width::() + }, + prepare_subslice: quote! { + let #length_var = <#mut_cols_ref_type>::width::(); + let (#slice_var, slice) = slice.split_at_mut(#length_var); + let #slice_var = <#mut_cols_ref_type>::from::(#slice_var); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_generic_type(&f.ty, generic_type) { + FieldInfo { + ty: parse_quote! { + &'a mut #generic_type + }, + length_expr: quote! { + 1 + }, + prepare_subslice: quote! { + let #length_var = 1; + let (#slice_var, slice) = slice.split_at_mut(#length_var); + }, + initializer: quote! { + &mut #slice_var[0] + }, + } + } else { + panic!("Unsupported field type: {:?}", f.ty); + } +} + +// Helper functions + +fn is_columns_struct(ty: &syn::Type) -> bool { + if let syn::Type::Path(type_path) = ty { + type_path + .path + .segments + .iter() + .last() + .map(|s| s.ident.to_string().ends_with("Cols")) + .unwrap_or(false) + } else { + false + } +} + +// If 'ty' is a struct that derives ColsRef, return the ColsRef struct type +// Otherwise, return None +fn get_const_cols_ref_type( + ty: &syn::Type, + generic_type: &syn::TypeParam, + lifetime: syn::Lifetime, +) -> syn::TypePath { + if !is_columns_struct(ty) { + panic!("Expected a columns struct, got {:?}", ty); + } + + if let syn::Type::Path(type_path) = ty { + let s = type_path.path.segments.iter().last().unwrap(); + if s.ident.to_string().ends_with("Cols") { + let const_cols_ref_ident = format_ident!("{}Ref", s.ident); + let const_cols_ref_type = parse_quote! { + #const_cols_ref_ident<#lifetime, #generic_type> + }; + const_cols_ref_type + } else { + panic!("is_columns_struct returned true for type {:?} but the last segment is not a columns struct", ty); + } + } else { + panic!( + "is_columns_struct returned true but the type {:?} is not a path", + ty + ); + } +} + +// If 'ty' is a struct that derives ColsRef, return the ColsRefMut struct type +// Otherwise, return None +fn get_mut_cols_ref_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> syn::TypePath { + if !is_columns_struct(ty) { + panic!("Expected a columns struct, got {:?}", ty); + } + + if let syn::Type::Path(type_path) = ty { + let s = type_path.path.segments.iter().last().unwrap(); + if s.ident.to_string().ends_with("Cols") { + let mut_cols_ref_ident = format_ident!("{}RefMut", s.ident); + let mut_cols_ref_type = parse_quote! { + #mut_cols_ref_ident<'a, #generic_type> + }; + mut_cols_ref_type + } else { + panic!("is_columns_struct returned true for type {:?} but the last segment is not a columns struct", ty); + } + } else { + panic!( + "is_columns_struct returned true but the type {:?} is not a path", + ty + ); + } +} + +fn is_generic_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> bool { + if let syn::Type::Path(type_path) = ty { + if type_path.path.segments.len() == 1 { + type_path + .path + .segments + .iter() + .last() + .map(|s| s.ident == generic_type.ident) + .unwrap_or(false) + } else { + false + } + } else { + false + } +} + +// Type of array dimension +enum Dimension { + ConstGeneric(syn::Expr), + Other(syn::Expr), +} + +// Describes a nested array +struct ArrayInfo { + dims: Vec, + elem_type: syn::Type, +} + +fn get_array_info(ty: &syn::Type, const_generics: &[&syn::Ident]) -> ArrayInfo { + let dims = get_dims(ty, const_generics); + let elem_type = get_elem_type(ty); + ArrayInfo { dims, elem_type } +} + +fn get_elem_type(ty: &syn::Type) -> syn::Type { + match ty { + syn::Type::Array(array) => get_elem_type(array.elem.as_ref()), + syn::Type::Path(_) => ty.clone(), + _ => panic!("Unsupported type: {:?}", ty), + } +} + +// Get a vector of the dimensions of the array +// Each dimension is either a constant generic or a literal integer value +fn get_dims(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec { + get_dims_impl(ty, const_generics) + .into_iter() + .rev() + .collect() +} + +fn get_dims_impl(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec { + match ty { + syn::Type::Array(array) => { + let mut dims = get_dims_impl(array.elem.as_ref(), const_generics); + match &array.len { + syn::Expr::Path(syn::ExprPath { path, .. }) => { + let len_ident = path.get_ident(); + if len_ident.is_some() && const_generics.contains(&len_ident.unwrap()) { + dims.push(Dimension::ConstGeneric(array.len.clone())); + } else { + dims.push(Dimension::Other(array.len.clone())); + } + } + syn::Expr::Lit(expr_lit) => dims.push(Dimension::Other(expr_lit.clone().into())), + _ => panic!("Unsupported array length type"), + } + dims + } + syn::Type::Path(_) => Vec::new(), + _ => panic!("Unsupported field type"), + } +} diff --git a/crates/circuits/primitives/derive/src/lib.rs b/crates/circuits/primitives/derive/src/lib.rs index 47ff1e220a..2f5dab0a4c 100644 --- a/crates/circuits/primitives/derive/src/lib.rs +++ b/crates/circuits/primitives/derive/src/lib.rs @@ -7,6 +7,9 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta}; +mod cols_ref; +use cols_ref::cols_ref_impl; + #[proc_macro_derive(AlignedBorrow)] pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); @@ -73,6 +76,49 @@ pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { TokenStream::from(methods) } +/// `S` is the type the derive macro is being called on +/// Implements Borrow and BorrowMut for [u8] +/// [u8] has to have (checked via `debug_assert!`s) +/// - at least size_of(S) length +/// - at least align_of(S) alignment +#[proc_macro_derive(AlignedBytesBorrow)] +pub fn aligned_bytes_borrow_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as DeriveInput); + let name = &ast.ident; + + // Get impl generics, type generics, where clause + // Note, need to add the new type generic to the `impl_generics` + let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl(); + + let methods = quote! { + impl #impl_generics core::borrow::Borrow<#name #type_generics> for [u8] + where + #where_clause + { + fn borrow(&self) -> &#name #type_generics { + use core::mem::{align_of, size_of_val}; + debug_assert!(size_of_val(self) >= core::mem::size_of::<#name #type_generics>()); + debug_assert_eq!(self.as_ptr() as usize % align_of::<#name #type_generics>(), 0); + unsafe { &*(self.as_ptr() as *const #name #type_generics) } + } + } + + impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [u8] + where + #where_clause + { + fn borrow_mut(&mut self) -> &mut #name #type_generics { + use core::mem::{align_of, size_of_val}; + debug_assert!(size_of_val(self) >= core::mem::size_of::<#name #type_generics>()); + debug_assert_eq!(self.as_ptr() as usize % align_of::<#name #type_generics>(), 0); + unsafe { &mut *(self.as_mut_ptr() as *mut #name #type_generics) } + } + } + }; + + TokenStream::from(methods) +} + #[proc_macro_derive(Chip, attributes(chip))] pub fn chip_derive(input: TokenStream) -> TokenStream { // Parse the attributes from the struct or enum @@ -400,3 +446,25 @@ pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream { _ => unimplemented!(), } } + +#[proc_macro_derive(ColsRef, attributes(aligned_borrow, config))] +pub fn cols_ref_derive(input: TokenStream) -> TokenStream { + let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput); + + let config = derive_input + .attrs + .iter() + .find(|attr| attr.path().is_ident("config")); + if config.is_none() { + return syn::Error::new(derive_input.ident.span(), "Config attribute is required") + .to_compile_error() + .into(); + } + let config: proc_macro2::Ident = config + .unwrap() + .parse_args() + .expect("Failed to parse config"); + + let res = cols_ref_impl(derive_input, config); + res.into() +} diff --git a/crates/circuits/primitives/derive/tests/example.rs b/crates/circuits/primitives/derive/tests/example.rs new file mode 100644 index 0000000000..58bac9e26c --- /dev/null +++ b/crates/circuits/primitives/derive/tests/example.rs @@ -0,0 +1,87 @@ +use openvm_circuit_primitives_derive::ColsRef; + +pub trait ExampleConfig { + const N: usize; +} +pub struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} +pub struct ExampleConfigImplB; +impl ExampleConfig for ExampleConfigImplB { + const N: usize = 10; +} + +#[allow(dead_code)] +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + sum: T, +} + +#[test] +fn example() { + let input = [1, 2, 3, 4, 5, 15]; + let test: ExampleColsRef = ExampleColsRef::from::(&input); + println!("{}, {}", test.arr, test.sum); +} + +/* + * For reference, this is what the ColsRef macro expands to. + * The `cargo expand` tool is helpful for understanding how the ColsRef macro works. + * See https://github.com/dtolnay/cargo-expand + +#[derive(Debug, Clone)] +struct ExampleColsRef<'a, T> { + pub arr: ndarray::ArrayView1<'a, T>, + pub sum: &'a T, +} + +impl<'a, T> ExampleColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let (arr_slice, slice) = slice.split_at(1 * C::N); + let arr_slice = ndarray::ArrayView1::from_shape((C::N), arr_slice).unwrap(); + let sum_length = 1; + let (sum_slice, slice) = slice.split_at(sum_length); + Self { + arr: arr_slice, + sum: &sum_slice[0], + } + } + pub const fn width() -> usize { + 0 + 1 * C::N + 1 + } +} + +impl<'b, T> ExampleColsRef<'b, T> { + pub fn from_mut<'a, C: ExampleConfig>(other: &'b ExampleColsRefMut<'a, T>) -> Self { + Self { + arr: other.arr.view(), + sum: &other.sum, + } + } +} + +#[derive(Debug)] +struct ExampleColsRefMut<'a, T> { + pub arr: ndarray::ArrayViewMut1<'a, T>, + pub sum: &'a mut T, +} + +impl<'a, T> ExampleColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let (arr_slice, slice) = slice.split_at_mut(1 * C::N); + let arr_slice = ndarray::ArrayViewMut1::from_shape((C::N), arr_slice).unwrap(); + let sum_length = 1; + let (sum_slice, slice) = slice.split_at_mut(sum_length); + Self { + arr: arr_slice, + sum: &mut sum_slice[0], + } + } + pub const fn width() -> usize { + 0 + 1 * C::N + 1 + } +} +*/ diff --git a/crates/circuits/primitives/derive/tests/test_cols_ref.rs b/crates/circuits/primitives/derive/tests/test_cols_ref.rs new file mode 100644 index 0000000000..6bad0c4f9f --- /dev/null +++ b/crates/circuits/primitives/derive/tests/test_cols_ref.rs @@ -0,0 +1,299 @@ +use openvm_circuit_primitives_derive::{AlignedBorrow, ColsRef}; + +pub trait TestConfig { + const N: usize; + const M: usize; +} +pub struct TestConfigImpl; +impl TestConfig for TestConfigImpl { + const N: usize = 5; + const M: usize = 2; +} + +#[allow(dead_code)] // TestCols isn't actually used in the code. silence clippy warning +#[derive(ColsRef)] +#[config(TestConfig)] +struct TestCols { + single_field_element: T, + array_of_t: [T; N], + nested_array_of_t: [[T; N]; N], + cols_struct: TestSubCols, + #[aligned_borrow] + array_of_aligned_borrow: [TestAlignedBorrow; N], + #[aligned_borrow] + nested_array_of_aligned_borrow: [[TestAlignedBorrow; N]; N], +} + +#[allow(dead_code)] // TestSubCols isn't actually used in the code. silence clippy warning +#[derive(ColsRef, Debug)] +#[config(TestConfig)] +struct TestSubCols { + // TestSubCols can have fields of any type that TestCols can have + a: T, + b: [T; M], + #[aligned_borrow] + c: TestAlignedBorrow, +} + +#[derive(AlignedBorrow, Debug)] +struct TestAlignedBorrow { + a: T, + b: [T; 5], +} + +#[test] +fn test_cols_ref() { + assert_eq!( + TestColsRef::::width::(), + TestColsRefMut::::width::() + ); + const WIDTH: usize = TestColsRef::::width::(); + let mut input = vec![0; WIDTH]; + let mut cols: TestColsRefMut = TestColsRefMut::from::(&mut input); + + *cols.single_field_element = 1; + cols.array_of_t[0] = 2; + cols.nested_array_of_t[[0, 0]] = 3; + *cols.cols_struct.a = 4; + cols.cols_struct.b[0] = 5; + cols.cols_struct.c.a = 6; + cols.cols_struct.c.b[0] = 7; + cols.array_of_aligned_borrow[0].a = 8; + cols.array_of_aligned_borrow[0].b[0] = 9; + cols.nested_array_of_aligned_borrow[[0, 0]].a = 10; + cols.nested_array_of_aligned_borrow[[0, 0]].b[0] = 11; + + let cols: TestColsRef = TestColsRef::from::(&input); + println!("{:?}", cols); + assert_eq!(*cols.single_field_element, 1); + assert_eq!(cols.array_of_t[0], 2); + assert_eq!(cols.nested_array_of_t[[0, 0]], 3); + assert_eq!(*cols.cols_struct.a, 4); + assert_eq!(cols.cols_struct.b[0], 5); + assert_eq!(cols.cols_struct.c.a, 6); + assert_eq!(cols.cols_struct.c.b[0], 7); + assert_eq!(cols.array_of_aligned_borrow[0].a, 8); + assert_eq!(cols.array_of_aligned_borrow[0].b[0], 9); + assert_eq!(cols.nested_array_of_aligned_borrow[[0, 0]].a, 10); + assert_eq!(cols.nested_array_of_aligned_borrow[[0, 0]].b[0], 11); +} + +/* + * For reference, this is what the ColsRef macro expands to. + * The `cargo expand` tool is helpful for understanding how the ColsRef macro works. + * See https://github.com/dtolnay/cargo-expand + +#[derive(Debug, Clone)] +struct TestColsRef<'a, T> { + pub single_field_element: &'a T, + pub array_of_t: ndarray::ArrayView1<'a, T>, + pub nested_array_of_t: ndarray::ArrayView2<'a, T>, + pub cols_struct: TestSubColsRef<'a, T>, + pub array_of_aligned_borrow: ndarray::ArrayView1<'a, TestAlignedBorrow>, + pub nested_array_of_aligned_borrow: ndarray::ArrayView2<'a, TestAlignedBorrow>, +} + +impl<'a, T> TestColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let single_field_element_length = 1; + let (single_field_element_slice, slice) = slice + .split_at(single_field_element_length); + let (array_of_t_slice, slice) = slice.split_at(1 * C::N); + let array_of_t_slice = ndarray::ArrayView1::from_shape((C::N), array_of_t_slice) + .unwrap(); + let (nested_array_of_t_slice, slice) = slice.split_at(1 * C::N * C::N); + let nested_array_of_t_slice = ndarray::ArrayView2::from_shape( + (C::N, C::N), + nested_array_of_t_slice, + ) + .unwrap(); + let cols_struct_length = >::width::(); + let (cols_struct_slice, slice) = slice.split_at(cols_struct_length); + let cols_struct_slice = >::from::(cols_struct_slice); + let (array_of_aligned_borrow_slice, slice) = slice + .split_at(>::width() * C::N); + let array_of_aligned_borrow_slice: &[TestAlignedBorrow] = unsafe { + &*(array_of_aligned_borrow_slice as *const [T] + as *const [TestAlignedBorrow]) + }; + let array_of_aligned_borrow_slice = ndarray::ArrayView1::from_shape( + (C::N), + array_of_aligned_borrow_slice, + ) + .unwrap(); + let (nested_array_of_aligned_borrow_slice, slice) = slice + .split_at(>::width() * C::N * C::N); + let nested_array_of_aligned_borrow_slice: &[TestAlignedBorrow] = unsafe { + &*(nested_array_of_aligned_borrow_slice as *const [T] + as *const [TestAlignedBorrow]) + }; + let nested_array_of_aligned_borrow_slice = ndarray::ArrayView2::from_shape( + (C::N, C::N), + nested_array_of_aligned_borrow_slice, + ) + .unwrap(); + Self { + single_field_element: &single_field_element_slice[0], + array_of_t: array_of_t_slice, + nested_array_of_t: nested_array_of_t_slice, + cols_struct: cols_struct_slice, + array_of_aligned_borrow: array_of_aligned_borrow_slice, + nested_array_of_aligned_borrow: nested_array_of_aligned_borrow_slice, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::N + 1 * C::N * C::N + >::width::() + + >::width() * C::N + + >::width() * C::N * C::N + } +} + +impl<'b, T> TestColsRef<'b, T> { + pub fn from_mut<'a, C: TestConfig>(other: &'b TestColsRefMut<'a, T>) -> Self { + Self { + single_field_element: &other.single_field_element, + array_of_t: other.array_of_t.view(), + nested_array_of_t: other.nested_array_of_t.view(), + cols_struct: >::from_mut::(&other.cols_struct), + array_of_aligned_borrow: other.array_of_aligned_borrow.view(), + nested_array_of_aligned_borrow: other.nested_array_of_aligned_borrow.view(), + } + } +} + +#[derive(Debug)] +struct TestColsRefMut<'a, T> { + pub single_field_element: &'a mut T, + pub array_of_t: ndarray::ArrayViewMut1<'a, T>, + pub nested_array_of_t: ndarray::ArrayViewMut2<'a, T>, + pub cols_struct: TestSubColsRefMut<'a, T>, + pub array_of_aligned_borrow: ndarray::ArrayViewMut1<'a, TestAlignedBorrow>, + pub nested_array_of_aligned_borrow: ndarray::ArrayViewMut2<'a, TestAlignedBorrow>, +} + +impl<'a, T> TestColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let single_field_element_length = 1; + let (single_field_element_slice, slice) = slice + .split_at_mut(single_field_element_length); + let (array_of_t_slice, slice) = slice.split_at_mut(1 * C::N); + let array_of_t_slice = ndarray::ArrayViewMut1::from_shape( + (C::N), + array_of_t_slice, + ) + .unwrap(); + let (nested_array_of_t_slice, slice) = slice.split_at_mut(1 * C::N * C::N); + let nested_array_of_t_slice = ndarray::ArrayViewMut2::from_shape( + (C::N, C::N), + nested_array_of_t_slice, + ) + .unwrap(); + let cols_struct_length = >::width::(); + let (cols_struct_slice, slice) = slice.split_at_mut(cols_struct_length); + let cols_struct_slice = >::from::(cols_struct_slice); + let (array_of_aligned_borrow_slice, slice) = slice + .split_at_mut(>::width() * C::N); + let array_of_aligned_borrow_slice: &mut [TestAlignedBorrow] = unsafe { + &mut *(array_of_aligned_borrow_slice as *mut [T] + as *mut [TestAlignedBorrow]) + }; + let array_of_aligned_borrow_slice = ndarray::ArrayViewMut1::from_shape( + (C::N), + array_of_aligned_borrow_slice, + ) + .unwrap(); + let (nested_array_of_aligned_borrow_slice, slice) = slice + .split_at_mut(>::width() * C::N * C::N); + let nested_array_of_aligned_borrow_slice: &mut [TestAlignedBorrow] = unsafe { + &mut *(nested_array_of_aligned_borrow_slice as *mut [T] + as *mut [TestAlignedBorrow]) + }; + let nested_array_of_aligned_borrow_slice = ndarray::ArrayViewMut2::from_shape( + (C::N, C::N), + nested_array_of_aligned_borrow_slice, + ) + .unwrap(); + Self { + single_field_element: &mut single_field_element_slice[0], + array_of_t: array_of_t_slice, + nested_array_of_t: nested_array_of_t_slice, + cols_struct: cols_struct_slice, + array_of_aligned_borrow: array_of_aligned_borrow_slice, + nested_array_of_aligned_borrow: nested_array_of_aligned_borrow_slice, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::N + 1 * C::N * C::N + >::width::() + + >::width() * C::N + + >::width() * C::N * C::N + } +} + +#[derive(Debug, Clone)] +struct TestSubColsRef<'a, T> { + pub a: &'a T, + pub b: ndarray::ArrayView1<'a, T>, + pub c: &'a TestAlignedBorrow, +} + +impl<'a, T> TestSubColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let a_length = 1; + let (a_slice, slice) = slice.split_at(a_length); + let (b_slice, slice) = slice.split_at(1 * C::M); + let b_slice = ndarray::ArrayView1::from_shape((C::M), b_slice).unwrap(); + let c_length = >::width(); + let (c_slice, slice) = slice.split_at(c_length); + Self { + a: &a_slice[0], + b: b_slice, + c: { + use core::borrow::Borrow; + c_slice.borrow() + }, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::M + >::width() + } +} + +impl<'b, T> TestSubColsRef<'b, T> { + pub fn from_mut<'a, C: TestConfig>(other: &'b TestSubColsRefMut<'a, T>) -> Self { + Self { + a: &other.a, + b: other.b.view(), + c: other.c, + } + } +} + +#[derive(Debug)] +struct TestSubColsRefMut<'a, T> { + pub a: &'a mut T, + pub b: ndarray::ArrayViewMut1<'a, T>, + pub c: &'a mut TestAlignedBorrow, +} + +impl<'a, T> TestSubColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let a_length = 1; + let (a_slice, slice) = slice.split_at_mut(a_length); + let (b_slice, slice) = slice.split_at_mut(1 * C::M); + let b_slice = ndarray::ArrayViewMut1::from_shape((C::M), b_slice).unwrap(); + let c_length = >::width(); + let (c_slice, slice) = slice.split_at_mut(c_length); + Self { + a: &mut a_slice[0], + b: b_slice, + c: { + use core::borrow::BorrowMut; + c_slice.borrow_mut() + }, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::M + >::width() + } +} +*/ diff --git a/crates/circuits/sha256-air/Cargo.toml b/crates/circuits/sha2-air/Cargo.toml similarity index 77% rename from crates/circuits/sha256-air/Cargo.toml rename to crates/circuits/sha2-air/Cargo.toml index c376a1ffdd..9758a10e6e 100644 --- a/crates/circuits/sha256-air/Cargo.toml +++ b/crates/circuits/sha2-air/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "openvm-sha256-air" +name = "openvm-sha2-air" version.workspace = true authors.workspace = true edition.workspace = true @@ -7,8 +7,11 @@ edition.workspace = true [dependencies] openvm-circuit-primitives = { workspace = true } openvm-stark-backend = { workspace = true } +openvm-circuit-primitives-derive = { workspace = true } sha2 = { version = "0.10", features = ["compress"] } rand.workspace = true +ndarray.workspace = true +num_enum = { workspace = true } [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/crates/circuits/sha2-air/src/air.rs b/crates/circuits/sha2-air/src/air.rs new file mode 100644 index 0000000000..9f110480fd --- /dev/null +++ b/crates/circuits/sha2-air/src/air.rs @@ -0,0 +1,694 @@ +use std::{cmp::max, iter::once, marker::PhantomData}; + +use ndarray::s; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupBus, + encoder::Encoder, + utils::{not, select}, + SubAir, +}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, +}; + +use super::{ + big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, + small_sig1_field, +}; +use crate::{ + constraint_word_addition, word_into_u16_limbs, Sha2Config, ShaDigestColsRef, ShaRoundColsRef, +}; + +/// Expects the message to be padded to a multiple of C::BLOCK_WORDS * C::WORD_BITS bits +#[derive(Clone, Debug)] +pub struct Sha2Air { + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + pub row_idx_encoder: Encoder, + /// Internal bus for self-interactions in this AIR. + bus: PermutationCheckBus, + _phantom: PhantomData, +} + +impl Sha2Air { + pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self { + Self { + bitwise_lookup_bus, + row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), /* + 1 for dummy + * (padding) rows */ + bus: PermutationCheckBus::new(self_bus_idx), + _phantom: PhantomData, + } + } +} + +impl BaseAir for Sha2Air { + fn width(&self) -> usize { + max(C::ROUND_WIDTH, C::DIGEST_WIDTH) + } +} + +impl SubAir for Sha2Air { + /// The start column for the sub-air to use + type AirContext<'a> + = usize + where + Self: 'a, + AB: 'a, + ::Var: 'a, + ::Expr: 'a; + + fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>) + where + ::Var: 'a, + ::Expr: 'a, + { + self.eval_row(builder, start_col); + self.eval_transitions(builder, start_col); + } +} + +impl Sha2Air { + /// Implements the single row constraints (i.e. imposes constraints only on local) + /// Implements some sanity constraints on the row index, flags, and work variables + fn eval_row(&self, builder: &mut AB, start_col: usize) { + let main = builder.main(); + let local = main.row_slice(0); + + // Doesn't matter which column struct we use here as we are only interested in the common + // columns + let local_cols: ShaDigestColsRef = + ShaDigestColsRef::from::(&local[start_col..start_col + C::DIGEST_WIDTH]); + let flags = &local_cols.flags; + builder.assert_bool(*flags.is_round_row); + builder.assert_bool(*flags.is_first_4_rows); + builder.assert_bool(*flags.is_digest_row); + builder.assert_bool(*flags.is_round_row + *flags.is_digest_row); + builder.assert_bool(*flags.is_last_block); + + self.row_idx_encoder + .eval(builder, local_cols.flags.row_idx.to_slice().unwrap()); + builder.assert_one(self.row_idx_encoder.contains_flag_range::( + local_cols.flags.row_idx.to_slice().unwrap(), + 0..=C::ROWS_PER_BLOCK, + )); + builder.assert_eq( + self.row_idx_encoder + .contains_flag_range::(local_cols.flags.row_idx.to_slice().unwrap(), 0..=3), + *flags.is_first_4_rows, + ); + builder.assert_eq( + self.row_idx_encoder.contains_flag_range::( + local_cols.flags.row_idx.to_slice().unwrap(), + 0..=C::ROUND_ROWS - 1, + ), + *flags.is_round_row, + ); + builder.assert_eq( + self.row_idx_encoder.contains_flag::( + local_cols.flags.row_idx.to_slice().unwrap(), + &[C::ROUND_ROWS], + ), + *flags.is_digest_row, + ); + // If padding row we want the row_idx to be C::ROWS_PER_BLOCK + builder.assert_eq( + self.row_idx_encoder.contains_flag::( + local_cols.flags.row_idx.to_slice().unwrap(), + &[C::ROWS_PER_BLOCK], + ), + flags.is_padding_row(), + ); + + // Constrain a, e, being composed of bits: we make sure a and e are always in the same place + // in the trace matrix Note: this has to be true for every row, even padding rows + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_BITS { + builder.assert_bool(local_cols.hash.a[[i, j]]); + builder.assert_bool(local_cols.hash.e[[i, j]]); + } + } + } + + /// Implements constraints for a digest row that ensure proper state transitions between blocks + /// This validates that: + /// The work variables are correctly initialized for the next message block + /// For the last message block, the initial state matches SHA_H constants + fn eval_digest_row( + &self, + builder: &mut AB, + local: ShaRoundColsRef, + next: ShaDigestColsRef, + ) { + // Check that if this is the last row of a message or an inpadding row, the hash should be + // the [SHA_H] + for i in 0..C::ROUNDS_PER_ROW { + let a = next.hash.a.row(i).mapv(|x| x.into()).to_vec(); + let e = next.hash.e.row(i).mapv(|x| x.into()).to_vec(); + + for j in 0..C::WORD_U16S { + let a_limb = compose::(&a[j * 16..(j + 1) * 16], 1); + let e_limb = compose::(&e[j * 16..(j + 1) * 16], 1); + + // If it is a padding row or the last row of a message, the `hash` should be the + // [SHA_H] + builder + .when( + next.flags.is_padding_row() + + *next.flags.is_last_block * *next.flags.is_digest_row, + ) + .assert_eq( + a_limb, + AB::Expr::from_canonical_u32( + word_into_u16_limbs::(C::get_h()[C::ROUNDS_PER_ROW - i - 1])[j], + ), + ); + + builder + .when( + next.flags.is_padding_row() + + *next.flags.is_last_block * *next.flags.is_digest_row, + ) + .assert_eq( + e_limb, + AB::Expr::from_canonical_u32( + word_into_u16_limbs::(C::get_h()[C::ROUNDS_PER_ROW - i + 3])[j], + ), + ); + } + } + + // Check if last row of a non-last block, the `hash` should be equal to the final hash of + // the current block + for i in 0..C::ROUNDS_PER_ROW { + let prev_a = next.hash.a.row(i).mapv(|x| x.into()).to_vec(); + let prev_e = next.hash.e.row(i).mapv(|x| x.into()).to_vec(); + let cur_a = next + .final_hash + .row(C::ROUNDS_PER_ROW - i - 1) + .mapv(|x| x.into()); + + let cur_e = next + .final_hash + .row(C::ROUNDS_PER_ROW - i + 3) + .mapv(|x| x.into()); + for j in 0..C::WORD_U8S { + let prev_a_limb = compose::(&prev_a[j * 8..(j + 1) * 8], 1); + let prev_e_limb = compose::(&prev_e[j * 8..(j + 1) * 8], 1); + + builder + .when(not(*next.flags.is_last_block) * *next.flags.is_digest_row) + .assert_eq(prev_a_limb, cur_a[j].clone()); + + builder + .when(not(*next.flags.is_last_block) * *next.flags.is_digest_row) + .assert_eq(prev_e_limb, cur_e[j].clone()); + } + } + + // Assert that the previous hash + work vars == final hash. + // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]` + // where addition is done modulo 2^32 + for i in 0..C::HASH_WORDS { + let mut carry = AB::Expr::ZERO; + for j in 0..C::WORD_U16S { + let work_var_limb = if i < C::ROUNDS_PER_ROW { + compose::( + local + .work_vars + .a + .slice(s![C::ROUNDS_PER_ROW - 1 - i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) + } else { + compose::( + local + .work_vars + .e + .slice(s![C::ROUNDS_PER_ROW + 3 - i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) + }; + let final_hash_limb = compose::( + next.final_hash + .slice(s![i, j * 2..(j + 1) * 2]) + .as_slice() + .unwrap(), + 8, + ); + + carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse()) + * (next.prev_hash[[i, j]] + work_var_limb + carry - final_hash_limb); + builder + .when(*next.flags.is_digest_row) + .assert_bool(carry.clone()); + } + // constrain the final hash limbs two at a time since we can do two checks per + // interaction + for chunk in next.final_hash.row(i).as_slice().unwrap().chunks(2) { + self.bitwise_lookup_bus + .send_range(chunk[0], chunk[1]) + .eval(builder, *next.flags.is_digest_row); + } + } + } + + fn eval_transitions(&self, builder: &mut AB, start_col: usize) { + let main = builder.main(); + let local = main.row_slice(0); + let next = main.row_slice(1); + + // Doesn't matter what column structs we use here + let local_cols: ShaRoundColsRef = + ShaRoundColsRef::from::(&local[start_col..start_col + C::ROUND_WIDTH]); + let next_cols: ShaRoundColsRef = + ShaRoundColsRef::from::(&next[start_col..start_col + C::ROUND_WIDTH]); + + let local_is_padding_row = local_cols.flags.is_padding_row(); + // Note that there will always be a padding row in the trace since the unpadded height is a + // multiple of 17 (SHA-256) or 21 (SHA-512, SHA-384). So the next row is padding iff the + // current block is the last block in the trace. + let next_is_padding_row = next_cols.flags.is_padding_row(); + + // We check that the very last block has `is_last_block` set to true, which guarantees that + // there is at least one complete message. If other digest rows have `is_last_block` set to + // true, then the trace will be interpreted as containing multiple messages. + builder + .when(next_is_padding_row.clone()) + .when(*next_cols.flags.is_digest_row) + .assert_one(*next_cols.flags.is_last_block); + // If we are in a round row, the next row cannot be a padding row + builder + .when(*local_cols.flags.is_round_row) + .assert_zero(next_is_padding_row.clone()); + // The first row must be a round row + builder + .when_first_row() + .assert_one(*local_cols.flags.is_round_row); + // If we are in a padding row, the next row must also be a padding row + builder + .when_transition() + .when(local_is_padding_row.clone()) + .assert_one(next_is_padding_row.clone()); + // If we are in a digest row, the next row cannot be a digest row + builder + .when(*local_cols.flags.is_digest_row) + .assert_zero(*next_cols.flags.is_digest_row); + // Constrain how much the row index changes by + // round->round: 1 + // round->digest: 1 + // digest->round: -C::ROUND_ROWS + // digest->padding: 1 + // padding->padding: 0 + // Other transitions are not allowed by the above constraints + let delta = *local_cols.flags.is_round_row * AB::Expr::ONE + + *local_cols.flags.is_digest_row + * *next_cols.flags.is_round_row + * AB::Expr::from_canonical_usize(C::ROUND_ROWS) + * AB::Expr::NEG_ONE + + *local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE; + + let local_row_idx = self.row_idx_encoder.flag_with_val::( + local_cols.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + let next_row_idx = self.row_idx_encoder.flag_with_val::( + next_cols.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + + builder + .when_transition() + .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone()); + builder.when_first_row().assert_zero(local_row_idx); + + // Constrain the global block index + // We set the global block index to 0 for padding rows + // Starting with 1 so it is not the same as the padding rows + + // Global block index is 1 on first row + builder + .when_first_row() + .assert_one(*local_cols.flags.global_block_idx); + + // Global block index is constant on all rows in a block + builder.when(*local_cols.flags.is_round_row).assert_eq( + *local_cols.flags.global_block_idx, + *next_cols.flags.global_block_idx, + ); + // Global block index increases by 1 between blocks + builder + .when_transition() + .when(*local_cols.flags.is_digest_row) + .when(*next_cols.flags.is_round_row) + .assert_eq( + *local_cols.flags.global_block_idx + AB::Expr::ONE, + *next_cols.flags.global_block_idx, + ); + // Global block index is 0 on padding rows + builder + .when(local_is_padding_row.clone()) + .assert_zero(*local_cols.flags.global_block_idx); + + // Constrain the local block index + // We set the local block index to 0 for padding rows + + // Local block index is constant on all rows in a block + // and its value on padding rows is equal to its value on the first block + builder + .when(not(*local_cols.flags.is_digest_row)) + .assert_eq( + *local_cols.flags.local_block_idx, + *next_cols.flags.local_block_idx, + ); + // Local block index increases by 1 between blocks in the same message + builder + .when(*local_cols.flags.is_digest_row) + .when(not(*local_cols.flags.is_last_block)) + .assert_eq( + *local_cols.flags.local_block_idx + AB::Expr::ONE, + *next_cols.flags.local_block_idx, + ); + // Local block index is 0 on padding rows + // Combined with the above, this means that the local block index is 0 in the first block + builder + .when(*local_cols.flags.is_digest_row) + .when(*local_cols.flags.is_last_block) + .assert_zero(*next_cols.flags.local_block_idx); + + self.eval_message_schedule(builder, local_cols.clone(), next_cols.clone()); + self.eval_work_vars(builder, local_cols.clone(), next_cols); + let next_cols: ShaDigestColsRef = + ShaDigestColsRef::from::(&next[start_col..start_col + C::DIGEST_WIDTH]); + self.eval_digest_row(builder, local_cols, next_cols); + let local_cols: ShaDigestColsRef = + ShaDigestColsRef::from::(&local[start_col..start_col + C::DIGEST_WIDTH]); + self.eval_prev_hash(builder, local_cols, next_is_padding_row); + } + + /// Constrains that the next block's `prev_hash` is equal to the current block's `hash` + /// Note: the constraining is done by interactions with the chip itself on every digest row + fn eval_prev_hash( + &self, + builder: &mut AB, + local: ShaDigestColsRef, + is_last_block_of_trace: AB::Expr, /* note this indicates the last block of the trace, + * not the last block of the message */ + ) { + // Constrain that next block's `prev_hash` is equal to the current block's `hash` + let composed_hash = (0..C::HASH_WORDS) + .map(|i| { + let hash_bits = if i < C::ROUNDS_PER_ROW { + local + .hash + .a + .row(C::ROUNDS_PER_ROW - 1 - i) + .mapv(|x| x.into()) + .to_vec() + } else { + local + .hash + .e + .row(C::ROUNDS_PER_ROW + 3 - i) + .mapv(|x| x.into()) + .to_vec() + }; + (0..C::WORD_U16S) + .map(|j| compose::(&hash_bits[j * 16..(j + 1) * 16], 1)) + .collect::>() + }) + .collect::>(); + // Need to handle the case if this is the very last block of the trace matrix + let next_global_block_idx = select( + is_last_block_of_trace, + AB::Expr::ONE, + *local.flags.global_block_idx + AB::Expr::ONE, + ); + // The following interactions constrain certain values from block to block + self.bus.send( + builder, + composed_hash + .into_iter() + .flatten() + .chain(once(next_global_block_idx)), + *local.flags.is_digest_row, + ); + + self.bus.receive( + builder, + local + .prev_hash + .flatten() + .mapv(|x| x.into()) + .into_iter() + .chain(once((*local.flags.global_block_idx).into())), + *local.flags.is_digest_row, + ); + } + + /// Constrain the message schedule additions for `next` row + /// Note: For every addition we need to constrain the following for each of [WORD_U16S] limbs + /// sig_1(w_{t-2})[i] + w_{t-7}[i] + sig_0(w_{t-15})[i] + w_{t-16}[i] + carry_w[t][i-1] - + /// carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] + fn eval_message_schedule<'a, AB: InteractionBuilder>( + &self, + builder: &mut AB, + local: ShaRoundColsRef<'a, AB::Var>, + next: ShaRoundColsRef<'a, AB::Var>, + ) { + // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx + let w = ndarray::concatenate( + ndarray::Axis(0), + &[local.message_schedule.w, next.message_schedule.w], + ) + .unwrap(); + + // Constrain `w_3` for `next` row + for i in 0..C::ROUNDS_PER_ROW - 1 { + // here we constrain the w_3 of the i_th word of the next row + // w_3 of next is w[i+4-3] = w[i+1] + let w_3 = w.row(i + 1).mapv(|x| x.into()).to_vec(); + let expected_w_3 = next.schedule_helper.w_3.row(i); + for j in 0..C::WORD_U16S { + let w_3_limb = compose::(&w_3[j * 16..(j + 1) * 16], 1); + builder + .when(*local.flags.is_round_row) + .assert_eq(w_3_limb, expected_w_3[j].into()); + } + } + + // Constrain intermed for `next` row + // We will only constrain intermed_12 for rows [3, C::ROUND_ROWS - 2], and let it + // unconstrained for other rows Other rows should put the needed value in + // intermed_12 to make the below summation constraint hold + let is_row_intermed_12 = self.row_idx_encoder.contains_flag_range::( + next.flags.row_idx.to_slice().unwrap(), + 3..=C::ROUND_ROWS - 2, + ); + // We will only constrain intermed_8 for rows [2, C::ROUND_ROWS - 3], and let it + // unconstrained for other rows + let is_row_intermed_8 = self.row_idx_encoder.contains_flag_range::( + next.flags.row_idx.to_slice().unwrap(), + 2..=C::ROUND_ROWS - 3, + ); + for i in 0..C::ROUNDS_PER_ROW { + // w_idx + let w_idx = w.row(i).mapv(|x| x.into()).to_vec(); + // sig_0(w_{idx+1}) + let sig_w = small_sig0_field::(w.row(i + 1).as_slice().unwrap()); + for j in 0..C::WORD_U16S { + let w_idx_limb = compose::(&w_idx[j * 16..(j + 1) * 16], 1); + let sig_w_limb = compose::(&sig_w[j * 16..(j + 1) * 16], 1); + + // We would like to constrain this only on rows 0..16, but we can't do a conditional + // check because the degree is already 3. So we must fill in + // `intermed_4` with dummy values on rows 0 and 16 to ensure the constraint holds on + // these rows. + builder.when_transition().assert_eq( + next.schedule_helper.intermed_4[[i, j]], + w_idx_limb + sig_w_limb, + ); + + builder.when(is_row_intermed_8.clone()).assert_eq( + next.schedule_helper.intermed_8[[i, j]], + local.schedule_helper.intermed_4[[i, j]], + ); + + builder.when(is_row_intermed_12.clone()).assert_eq( + next.schedule_helper.intermed_12[[i, j]], + local.schedule_helper.intermed_8[[i, j]], + ); + } + } + + // Constrain the message schedule additions for `next` row + for i in 0..C::ROUNDS_PER_ROW { + // Note, here by w_{t} we mean the i_th word of the `next` row + // w_{t-7} + let w_7 = if i < 3 { + local.schedule_helper.w_3.row(i).mapv(|x| x.into()).to_vec() + } else { + let w_3 = w.row(i - 3).mapv(|x| x.into()).to_vec(); + (0..C::WORD_U16S) + .map(|j| compose::(&w_3[j * 16..(j + 1) * 16], 1)) + .collect::>() + }; + // sig_0(w_{t-15}) + w_{t-16} + let intermed_16 = local.schedule_helper.intermed_12.row(i).mapv(|x| x.into()); + + let carries = (0..C::WORD_U16S) + .map(|j| { + next.message_schedule.carry_or_buffer[[i, j * 2]] + + AB::Expr::TWO * next.message_schedule.carry_or_buffer[[i, j * 2 + 1]] + }) + .collect::>(); + + // Constrain `W_{idx} = sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}` + // We would like to constrain this only on rows 4..C::ROUND_ROWS, but we can't do a + // conditional check because the degree of sum is already 3 So we must fill + // in `intermed_12` with dummy values on rows 0..3 and C::ROUND_ROWS-1 and C::ROUND_ROWS + // to ensure the constraint holds on rows 0..4 and C::ROUND_ROWS. Note that + // the dummy value goes in the previous row to make the current row's constraint hold. + constraint_word_addition::<_, C>( + // Note: here we can't do a conditional check because the degree of sum is already + // 3 + &mut builder.when_transition(), + &[&small_sig1_field::( + w.row(i + 2).as_slice().unwrap(), + )], + &[&w_7, intermed_16.as_slice().unwrap()], + w.row(i + 4).as_slice().unwrap(), + &carries, + ); + + for j in 0..C::WORD_U16S { + // When on rows 4..C::ROUND_ROWS message schedule carries should be 0 or 1 + let is_row_4_or_more = *next.flags.is_round_row - *next.flags.is_first_4_rows; + builder + .when(is_row_4_or_more.clone()) + .assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2]]); + builder + .when(is_row_4_or_more) + .assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2 + 1]]); + } + // Constrain w being composed of bits + for j in 0..C::WORD_BITS { + builder + .when(*next.flags.is_round_row) + .assert_bool(next.message_schedule.w[[i, j]]); + } + } + } + + /// Constrain the work vars on `next` row according to the sha documentation + /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] + fn eval_work_vars<'a, AB: InteractionBuilder>( + &self, + builder: &mut AB, + local: ShaRoundColsRef<'a, AB::Var>, + next: ShaRoundColsRef<'a, AB::Var>, + ) { + let a = + ndarray::concatenate(ndarray::Axis(0), &[local.work_vars.a, next.work_vars.a]).unwrap(); + let e = + ndarray::concatenate(ndarray::Axis(0), &[local.work_vars.e, next.work_vars.e]).unwrap(); + + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_U16S { + // Although we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in + // [0, 2^8) is enough to prevent overflow and ensure the soundness + // of the addition we want to check + self.bitwise_lookup_bus + .send_range( + local.work_vars.carry_a[[i, j]], + local.work_vars.carry_e[[i, j]], + ) + .eval(builder, *local.flags.is_round_row); + } + + let w_limbs = (0..C::WORD_U16S) + .map(|j| { + compose::( + next.message_schedule + .w + .slice(s![i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) * *next.flags.is_round_row + }) + .collect::>(); + + let k_limbs = (0..C::WORD_U16S) + .map(|j| { + self.row_idx_encoder.flag_with_val::( + next.flags.row_idx.to_slice().unwrap(), + &(0..C::ROUND_ROWS) + .map(|rw_idx| { + ( + rw_idx, + word_into_u16_limbs::( + C::get_k()[rw_idx * C::ROUNDS_PER_ROW + i], + )[j] as usize, + ) + }) + .collect::>(), + ) + }) + .collect::>(); + + // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)` + // We have to enforce this constraint on all rows since the degree of the constraint is + // already 3. So, we must fill in `carry_a` with dummy values on digest rows + // to ensure the constraint holds. + constraint_word_addition::<_, C>( + builder, + &[ + e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h` + &big_sig1_field::(e.row(i + 3).as_slice().unwrap()), /* sig_1 of previous `e` */ + &ch_field::( + e.row(i + 3).as_slice().unwrap(), + e.row(i + 2).as_slice().unwrap(), + e.row(i + 1).as_slice().unwrap(), + ), /* Ch of previous `e`, `f`, `g` */ + &big_sig0_field::(a.row(i + 3).as_slice().unwrap()), /* sig_0 of previous `a` */ + &maj_field::( + a.row(i + 3).as_slice().unwrap(), + a.row(i + 2).as_slice().unwrap(), + a.row(i + 1).as_slice().unwrap(), + ), /* Maj of previous a, b, c */ + ], + &[&w_limbs, &k_limbs], // K and W + a.row(i + 4).as_slice().unwrap(), // new `a` + next.work_vars.carry_a.row(i).as_slice().unwrap(), // carries of addition + ); + + // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W` + // We have to enforce this constraint on all rows since the degree of the constraint is + // already 3. So, we must fill in `carry_e` with dummy values on digest rows + // to ensure the constraint holds. + constraint_word_addition::<_, C>( + builder, + &[ + a.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `d` + e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h` + &big_sig1_field::(e.row(i + 3).as_slice().unwrap()), /* sig_1 of previous `e` */ + &ch_field::( + e.row(i + 3).as_slice().unwrap(), + e.row(i + 2).as_slice().unwrap(), + e.row(i + 1).as_slice().unwrap(), + ), /* Ch of previous `e`, `f`, `g` */ + ], + &[&w_limbs, &k_limbs], // K and W + e.row(i + 4).as_slice().unwrap(), // new `e` + next.work_vars.carry_e.row(i).as_slice().unwrap(), // carries of addition + ); + } + } +} diff --git a/crates/circuits/sha2-air/src/columns.rs b/crates/circuits/sha2-air/src/columns.rs new file mode 100644 index 0000000000..da1e334e97 --- /dev/null +++ b/crates/circuits/sha2-air/src/columns.rs @@ -0,0 +1,187 @@ +//! WARNING: the order of fields in the structs is important, do not change it + +use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives_derive::ColsRef; +use openvm_stark_backend::p3_field::FieldAlgebra; + +use crate::Sha2Config; + +/// In each SHA block: +/// - First C::ROUND_ROWS rows use ShaRoundCols +/// - Final row uses ShaDigestCols +/// +/// Note that for soundness, we require that there is always a padding row after the last digest row +/// in the trace. Right now, this is true because the unpadded height is a multiple of 17 (SHA-256) +/// or 21 (SHA-512), and thus not a power of 2. +/// +/// ShaRoundCols and ShaDigestCols share the same first 3 fields: +/// - flags +/// - work_vars/hash (same type, different name) +/// - schedule_helper +/// +/// This design allows for: +/// 1. Common constraints to work on either struct type by accessing these shared fields +/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional +/// constraints +/// +/// Note that the `ShaWorkVarsCols` field is used for different purposes in the two structs. +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct ShaRoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub flags: Sha2FlagsCols, + pub work_vars: ShaWorkVarsCols, + pub schedule_helper: + Sha2MessageHelperCols, + pub message_schedule: ShaMessageScheduleCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct ShaDigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub flags: Sha2FlagsCols, + /// Will serve as previous hash values for the next block + pub hash: ShaWorkVarsCols, + pub schedule_helper: + Sha2MessageHelperCols, + /// The actual final hash values of the given block + /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block + pub final_hash: [[T; WORD_U8S]; HASH_WORDS], + /// The final hash of the previous block + /// Note: will be constrained using interactions with the chip itself + pub prev_hash: [[T; WORD_U16S]; HASH_WORDS], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct ShaMessageScheduleCols< + T, + const WORD_BITS: usize, + const ROUNDS_PER_ROW: usize, + const WORD_U8S: usize, +> { + /// The message schedule words as bits + /// The first 16 words will be the message data + pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW], + /// Will be message schedule carries for rows 4..C::ROUND_ROWS and a buffer for rows 0..4 to be + /// used freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells + /// as individual bits + pub carry_or_buffer: [[T; WORD_U8S]; ROUNDS_PER_ROW], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct ShaWorkVarsCols< + T, + const WORD_BITS: usize, + const ROUNDS_PER_ROW: usize, + const WORD_U16S: usize, +> { + /// `a` and `e` after each iteration as 32-bits + pub a: [[T; WORD_BITS]; ROUNDS_PER_ROW], + pub e: [[T; WORD_BITS]; ROUNDS_PER_ROW], + /// The carry's used for addition during each iteration when computing `a` and `e` + pub carry_a: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub carry_e: [[T; WORD_U16S]; ROUNDS_PER_ROW], +} + +/// These are the columns that are used to help with the message schedule additions +/// Note: these need to be correctly assigned for every row even on padding rows +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct Sha2MessageHelperCols< + T, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, +> { + /// The following are used to move data forward to constrain the message schedule additions + /// The value of `w` from 3 rounds ago + pub w_3: [[T; WORD_U16S]; ROUNDS_PER_ROW_MINUS_ONE], + /// Here intermediate(i) = w_i + sig_0(w_{i+1}) + /// Intermed_t represents the intermediate t rounds ago + /// This is needed to constrain the message schedule, since we can only constrain on two rows + /// at a time + pub intermed_4: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub intermed_8: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub intermed_12: [[T; WORD_U16S]; ROUNDS_PER_ROW], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2Config)] +pub struct Sha2FlagsCols { + pub is_round_row: T, + /// A flag that indicates if the current row is among the first 4 rows of a block (the message + /// rows) + pub is_first_4_rows: T, + pub is_digest_row: T, + pub is_last_block: T, + /// We will encode the row index [0..17) using 5 cells + pub row_idx: [T; ROW_VAR_CNT], + /// The global index of the current block + pub global_block_idx: T, + /// Will store the index of the current block in the current message starting from 0 + pub local_block_idx: T, +} + +impl, const ROW_VAR_CNT: usize> + Sha2FlagsCols +{ + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_not_padding_row(&self) -> O { + self.is_round_row + self.is_digest_row + } + + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_padding_row(&self) -> O + where + O: FieldAlgebra, + { + not(self.is_not_padding_row()) + } +} + +impl> Sha2FlagsColsRef<'_, T> { + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_not_padding_row(&self) -> O { + *self.is_round_row + *self.is_digest_row + } + + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_padding_row(&self) -> O + where + O: FieldAlgebra, + { + not(self.is_not_padding_row()) + } +} diff --git a/crates/circuits/sha2-air/src/config.rs b/crates/circuits/sha2-air/src/config.rs new file mode 100644 index 0000000000..e6e6b54202 --- /dev/null +++ b/crates/circuits/sha2-air/src/config.rs @@ -0,0 +1,388 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not, Shl, Shr}; + +use crate::{ShaDigestColsRef, ShaRoundColsRef}; + +#[repr(u32)] +#[derive(num_enum::TryFromPrimitive, num_enum::IntoPrimitive)] +pub enum Sha2Variant { + Sha256, + Sha512, + Sha384, +} + +pub trait Sha2Config: Send + Sync + Clone { + type Word: 'static + + Shr + + Shl + + BitAnd + + Not + + BitXor + + BitOr + + RotateRight + + WrappingAdd + + PartialEq + + From + + TryInto + + Copy + + Send + + Sync; + // Differentiate between the SHA-2 variants + const VARIANT: Sha2Variant; + /// Number of bits in a SHA word + const WORD_BITS: usize; + /// Number of 16-bit limbs in a SHA word + const WORD_U16S: usize = Self::WORD_BITS / 16; + /// Number of 8-bit limbs in a SHA word + const WORD_U8S: usize = Self::WORD_BITS / 8; + /// Number of words in a SHA block + const BLOCK_WORDS: usize; + /// Number of cells in a SHA block + const BLOCK_U8S: usize = Self::BLOCK_WORDS * Self::WORD_U8S; + /// Number of bits in a SHA block + const BLOCK_BITS: usize = Self::BLOCK_WORDS * Self::WORD_BITS; + /// Number of rows per block + const ROWS_PER_BLOCK: usize; + /// Number of rounds per row. Must divide Self::ROUNDS_PER_BLOCK + const ROUNDS_PER_ROW: usize; + /// Number of rows used for the sha rounds + const ROUND_ROWS: usize = Self::ROUNDS_PER_BLOCK / Self::ROUNDS_PER_ROW; + /// Number of rows used for the message + const MESSAGE_ROWS: usize = Self::BLOCK_WORDS / Self::ROUNDS_PER_ROW; + /// Number of rounds per row minus one (needed for one of the column structs) + const ROUNDS_PER_ROW_MINUS_ONE: usize = Self::ROUNDS_PER_ROW - 1; + /// Number of rounds per block. Must be a multiple of Self::ROUNDS_PER_ROW + const ROUNDS_PER_BLOCK: usize; + /// Number of words in a SHA hash + const HASH_WORDS: usize; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize; + /// Width of the ShaRoundCols + const ROUND_WIDTH: usize = ShaRoundColsRef::::width::(); + /// Width of the ShaDigestCols + const DIGEST_WIDTH: usize = ShaDigestColsRef::::width::(); + /// Width of the ShaCols + const WIDTH: usize = if Self::ROUND_WIDTH > Self::DIGEST_WIDTH { + Self::ROUND_WIDTH + } else { + Self::DIGEST_WIDTH + }; + /// Number of cells used in each message row to store the message + const CELLS_PER_ROW: usize = Self::ROUNDS_PER_ROW * Self::WORD_U8S; + + /// To optimize the trace generation of invalid rows, we precompute those values. + // these should be appropriately sized for the config + fn get_invalid_carry_a(round_num: usize) -> &'static [u32]; + fn get_invalid_carry_e(round_num: usize) -> &'static [u32]; + + /// We also store the SHA constants K and H + fn get_k() -> &'static [Self::Word]; + fn get_h() -> &'static [Self::Word]; +} + +#[derive(Clone)] +pub struct Sha256Config; + +impl Sha2Config for Sha256Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha256; + type Word = u32; + /// Number of bits in a SHA256 word + const WORD_BITS: usize = 32; + /// Number of words in a SHA256 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 17; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 64; + /// Number of words in a SHA256 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 5; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA256_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA256_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u32] { + &SHA256_K + } + fn get_h() -> &'static [u32] { + &SHA256_H + } +} + +pub const SHA256_INVALID_CARRY_A: [[u32; Sha256Config::WORD_U16S]; Sha256Config::ROUNDS_PER_ROW] = [ + [1230919683, 1162494304], + [266373122, 1282901987], + [1519718403, 1008990871], + [923381762, 330807052], +]; +pub const SHA256_INVALID_CARRY_E: [[u32; Sha256Config::WORD_U16S]; Sha256Config::ROUNDS_PER_ROW] = [ + [204933122, 1994683449], + [443873282, 1544639095], + [719953922, 1888246508], + [194580482, 1075725211], +]; + +/// SHA256 constant K's +pub const SHA256_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; +/// SHA256 initial hash values +pub const SHA256_H: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +#[derive(Clone)] +pub struct Sha512Config; + +impl Sha2Config for Sha512Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha512; + type Word = u64; + /// Number of bits in a SHA512 word + const WORD_BITS: usize = 64; + /// Number of words in a SHA512 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 21; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 80; + /// Number of words in a SHA512 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 6; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA512_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA512_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u64] { + &SHA512_K + } + fn get_h() -> &'static [u64] { + &SHA512_H + } +} + +pub(crate) const SHA512_INVALID_CARRY_A: [[u32; Sha512Config::WORD_U16S]; + Sha512Config::ROUNDS_PER_ROW] = [ + [55971842, 827997017, 993005918, 512731953], + [227512322, 1697529235, 1936430385, 940122990], + [1939875843, 1173318562, 826201586, 1513494849], + [891955202, 1732283693, 1736658755, 223514501], +]; + +pub(crate) const SHA512_INVALID_CARRY_E: [[u32; Sha512Config::WORD_U16S]; + Sha512Config::ROUNDS_PER_ROW] = [ + [1384427522, 1509509767, 153131516, 102514978], + [1527552003, 1041677071, 837289497, 843522538], + [775188482, 1620184630, 744892564, 892058728], + [1801267202, 1393118048, 1846108940, 830635531], +]; + +/// SHA512 constant K's +pub const SHA512_K: [u64; 80] = [ + 0x428a2f98d728ae22, + 0x7137449123ef65cd, + 0xb5c0fbcfec4d3b2f, + 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, + 0x59f111f1b605d019, + 0x923f82a4af194f9b, + 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, + 0x12835b0145706fbe, + 0x243185be4ee4b28c, + 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, + 0x80deb1fe3b1696b1, + 0x9bdc06a725c71235, + 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, + 0xefbe4786384f25e3, + 0x0fc19dc68b8cd5b5, + 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, + 0x4a7484aa6ea6e483, + 0x5cb0a9dcbd41fbd4, + 0x76f988da831153b5, + 0x983e5152ee66dfab, + 0xa831c66d2db43210, + 0xb00327c898fb213f, + 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, + 0xd5a79147930aa725, + 0x06ca6351e003826f, + 0x142929670a0e6e70, + 0x27b70a8546d22ffc, + 0x2e1b21385c26c926, + 0x4d2c6dfc5ac42aed, + 0x53380d139d95b3df, + 0x650a73548baf63de, + 0x766a0abb3c77b2a8, + 0x81c2c92e47edaee6, + 0x92722c851482353b, + 0xa2bfe8a14cf10364, + 0xa81a664bbc423001, + 0xc24b8b70d0f89791, + 0xc76c51a30654be30, + 0xd192e819d6ef5218, + 0xd69906245565a910, + 0xf40e35855771202a, + 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, + 0x1e376c085141ab53, + 0x2748774cdf8eeb99, + 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, + 0x4ed8aa4ae3418acb, + 0x5b9cca4f7763e373, + 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, + 0x78a5636f43172f60, + 0x84c87814a1f0ab72, + 0x8cc702081a6439ec, + 0x90befffa23631e28, + 0xa4506cebde82bde9, + 0xbef9a3f7b2c67915, + 0xc67178f2e372532b, + 0xca273eceea26619c, + 0xd186b8c721c0c207, + 0xeada7dd6cde0eb1e, + 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, + 0x0a637dc5a2c898a6, + 0x113f9804bef90dae, + 0x1b710b35131c471b, + 0x28db77f523047d84, + 0x32caab7b40c72493, + 0x3c9ebe0a15c9bebc, + 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, + 0x597f299cfc657e2a, + 0x5fcb6fab3ad6faec, + 0x6c44198c4a475817, +]; +/// SHA512 initial hash values +pub const SHA512_H: [u64; 8] = [ + 0x6a09e667f3bcc908, + 0xbb67ae8584caa73b, + 0x3c6ef372fe94f82b, + 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, + 0x9b05688c2b3e6c1f, + 0x1f83d9abfb41bd6b, + 0x5be0cd19137e2179, +]; + +#[derive(Clone)] +pub struct Sha384Config; + +impl Sha2Config for Sha384Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha384; + type Word = u64; + /// Number of bits in a SHA384 word + const WORD_BITS: usize = 64; + /// Number of words in a SHA384 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 21; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 80; + /// Number of words in a SHA384 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 6; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA384_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA384_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u64] { + &SHA384_K + } + fn get_h() -> &'static [u64] { + &SHA384_H + } +} + +pub(crate) const SHA384_INVALID_CARRY_A: [[u32; Sha384Config::WORD_U16S]; + Sha384Config::ROUNDS_PER_ROW] = [ + [1571481603, 1428841901, 1050676523, 793575075], + [1233315842, 1822329223, 112923808, 1874228927], + [1245603842, 927240770, 1579759431, 70557227], + [195532801, 594312107, 1429379950, 220407092], +]; + +pub(crate) const SHA384_INVALID_CARRY_E: [[u32; Sha384Config::WORD_U16S]; + Sha384Config::ROUNDS_PER_ROW] = [ + [1067980802, 1508061099, 1418826213, 1232569491], + [1453086722, 1702524575, 152427899, 238512408], + [1623674882, 701393097, 1002035664, 4776891], + [1888911362, 184963225, 1151849224, 1034237098], +]; + +/// SHA384 constant K's +pub const SHA384_K: [u64; 80] = SHA512_K; + +/// SHA384 initial hash values +pub const SHA384_H: [u64; 8] = [ + 0xcbbb9d5dc1059ed8, + 0x629a292a367cd507, + 0x9159015a3070dd17, + 0x152fecd8f70e5939, + 0x67332667ffc00b31, + 0x8eb44a8768581511, + 0xdb0c2e0d64f98fa7, + 0x47b5481dbefa4fa4, +]; + +// Needed to avoid compile errors in utils.rs +// not sure why this doesn't inf loop +pub trait RotateRight { + fn rotate_right(self, n: u32) -> Self; +} +impl RotateRight for u32 { + fn rotate_right(self, n: u32) -> Self { + self.rotate_right(n) + } +} +impl RotateRight for u64 { + fn rotate_right(self, n: u32) -> Self { + self.rotate_right(n) + } +} +pub trait WrappingAdd { + fn wrapping_add(self, n: Self) -> Self; +} +impl WrappingAdd for u32 { + fn wrapping_add(self, n: u32) -> Self { + self.wrapping_add(n) + } +} +impl WrappingAdd for u64 { + fn wrapping_add(self, n: u64) -> Self { + self.wrapping_add(n) + } +} diff --git a/crates/circuits/sha256-air/src/lib.rs b/crates/circuits/sha2-air/src/lib.rs similarity index 65% rename from crates/circuits/sha256-air/src/lib.rs rename to crates/circuits/sha2-air/src/lib.rs index 48bdaee5f9..7c7d095938 100644 --- a/crates/circuits/sha256-air/src/lib.rs +++ b/crates/circuits/sha2-air/src/lib.rs @@ -1,13 +1,15 @@ -//! Implementation of the SHA256 compression function without padding +//! Implementation of the SHA256/SHA512 compression function without padding //! This this AIR doesn't constrain any of the message padding mod air; mod columns; +mod config; mod trace; mod utils; pub use air::*; pub use columns::*; +pub use config::*; pub use trace::*; pub use utils::*; diff --git a/crates/circuits/sha2-air/src/tests.rs b/crates/circuits/sha2-air/src/tests.rs new file mode 100644 index 0000000000..f376b0b246 --- /dev/null +++ b/crates/circuits/sha2-air/src/tests.rs @@ -0,0 +1,187 @@ +use std::{cmp::max, sync::Arc}; + +use openvm_circuit::arch::{ + instructions::riscv::RV32_CELL_BITS, + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + SubAir, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + interaction::{BusIndex, InteractionBuilder}, + p3_air::{Air, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::types::AirProofInput, + rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + utils::disable_debug_builder, + verifier::VerificationError, + AirRef, Chip, ChipUsageGetter, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::Rng; + +use crate::{ + Sha256Config, Sha2Air, Sha2Config, Sha2StepHelper, Sha384Config, Sha512Config, + ShaDigestColsRefMut, +}; + +// A wrapper AIR purely for testing purposes +#[derive(Clone, Debug)] +pub struct Sha2TestAir { + pub sub_air: Sha2Air, +} + +impl BaseAirWithPublicValues for Sha2TestAir {} +impl PartitionedBaseAir for Sha2TestAir {} +impl BaseAir for Sha2TestAir { + fn width(&self) -> usize { + as BaseAir>::width(&self.sub_air) + } +} + +impl Air for Sha2TestAir { + fn eval(&self, builder: &mut AB) { + self.sub_air.eval(builder, 0); + } +} + +// A wrapper Chip purely for testing purposes +pub struct Sha2TestChip { + pub air: Sha2TestAir, + pub step: Sha2StepHelper, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + pub records: Vec<(Vec, bool)>, // length of inner vec is C::BLOCK_U8S +} + +impl Chip for Sha2TestChip +where + Val: PrimeField32, +{ + fn air(&self) -> AirRef { + Arc::new(self.air.clone()) + } + + fn generate_air_proof_input(self) -> AirProofInput { + let trace = crate::generate_trace::, C>( + &self.step, + self.bitwise_lookup_chip.clone(), + as BaseAir>>::width(&self.air.sub_air), + self.records, + ); + AirProofInput::simple_no_pis(trace) + } +} + +impl ChipUsageGetter for Sha2TestChip { + fn air_name(&self) -> String { + get_air_name(&self.air) + } + fn current_trace_height(&self) -> usize { + self.records.len() * C::ROWS_PER_BLOCK + } + + fn trace_width(&self) -> usize { + max(C::ROUND_WIDTH, C::DIGEST_WIDTH) + } +} + +const SELF_BUS_IDX: BusIndex = 28; +type F = BabyBear; + +fn create_chip_with_rand_records( +) -> (Sha2TestChip, SharedBitwiseOperationLookupChip<8>) { + let mut rng = create_seeded_rng(); + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let len = rng.gen_range(1..100); + let random_records: Vec<_> = (0..len) + .map(|i| { + ( + (0..C::BLOCK_U8S) + .map(|_| rng.gen::()) + .collect::>(), + rng.gen::() || i == len - 1, + ) + }) + .collect(); + let chip = Sha2TestChip { + air: Sha2TestAir { + sub_air: Sha2Air::::new(bitwise_bus, SELF_BUS_IDX), + }, + step: Sha2StepHelper::::new(), + bitwise_lookup_chip: bitwise_chip.clone(), + records: random_records, + }; + + (chip, bitwise_chip) +} + +fn rand_sha2_test() { + let tester = VmChipTestBuilder::default(); + let (chip, bitwise_chip) = create_chip_with_rand_records::(); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn rand_sha256_test() { + rand_sha2_test::(); +} + +#[test] +fn rand_sha512_test() { + rand_sha2_test::(); +} + +#[test] +fn rand_sha384_test() { + rand_sha2_test::(); +} + +fn negative_sha2_test_bad_final_hash() { + let tester = VmChipTestBuilder::default(); + let (chip, bitwise_chip) = create_chip_with_rand_records::(); + + // Set the final_hash to all zeros + let modify_trace = |trace: &mut RowMajorMatrix| { + trace.row_chunks_exact_mut(1).for_each(|row| { + let mut row_slice = row.row_slice(0).to_vec(); + let mut cols: ShaDigestColsRefMut = + ShaDigestColsRefMut::from::(&mut row_slice[..C::DIGEST_WIDTH]); + if cols.flags.is_last_block.is_one() && cols.flags.is_digest_row.is_one() { + for i in 0..C::HASH_WORDS { + for j in 0..C::WORD_U8S { + cols.final_hash[[i, j]] = F::ZERO; + } + } + row.values.copy_from_slice(&row_slice); + } + }); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .load(bitwise_chip) + .finalize(); + tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); +} + +#[test] +fn negative_sha256_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +#[test] +fn negative_sha512_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +#[test] +fn negative_sha384_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} diff --git a/crates/circuits/sha2-air/src/trace.rs b/crates/circuits/sha2-air/src/trace.rs new file mode 100644 index 0000000000..d2c8e8f8d8 --- /dev/null +++ b/crates/circuits/sha2-air/src/trace.rs @@ -0,0 +1,864 @@ +use std::{marker::PhantomData, ops::Range}; + +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, + utils::next_power_of_two_or_zero, +}; +use openvm_stark_backend::{ + p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, +}; +use sha2::{compress256, compress512, digest::generic_array::GenericArray}; + +use super::{ + big_sig0_field, big_sig1_field, ch_field, compose, get_flag_pt_array, maj_field, + small_sig0_field, small_sig1_field, ShaRoundColsRefMut, +}; +use crate::{ + big_sig0, big_sig1, ch, le_limbs_into_word, maj, small_sig0, small_sig1, word_into_bits, + word_into_u16_limbs, word_into_u8_limbs, Sha2Config, Sha2Variant, ShaDigestColsRefMut, + ShaRoundColsRef, WrappingAdd, +}; + +/// A helper struct for the SHA256 trace generation. +/// Also, separates the inner AIR from the trace generation. +pub struct Sha2StepHelper { + pub row_idx_encoder: Encoder, + _phantom: PhantomData, +} + +impl Default for Sha2StepHelper { + fn default() -> Self { + Self::new() + } +} + +/// The trace generation of SHA should be done in two passes. +/// The first pass should do `get_block_trace` for every block and generate the invalid rows through +/// `get_default_row` The second pass should go through all the blocks and call +/// `generate_missing_cells` +impl Sha2StepHelper { + pub fn new() -> Self { + Self { + // +1 for dummy (padding) rows + row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), + _phantom: PhantomData, + } + } + + /// This function takes the input_message (padding not handled), the previous hash, + /// and returns the new hash after processing the block input + pub fn get_block_hash(prev_hash: &[C::Word], input: Vec) -> Vec { + debug_assert!(prev_hash.len() == C::HASH_WORDS); + debug_assert!(input.len() == C::BLOCK_U8S); + let mut new_hash: [C::Word; 8] = prev_hash.try_into().unwrap(); + match C::VARIANT { + Sha2Variant::Sha256 => { + let input_array = [*GenericArray::::from_slice( + &input, + )]; + let hash_ptr: &mut [u32; 8] = unsafe { std::mem::transmute(&mut new_hash) }; + compress256(hash_ptr, &input_array); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let hash_ptr: &mut [u64; 8] = unsafe { std::mem::transmute(&mut new_hash) }; + let input_array = [*GenericArray::::from_slice( + &input, + )]; + compress512(hash_ptr, &input_array); + } + } + new_hash.to_vec() + } + + /// This function takes a C::BLOCK_BITS-bit chunk of the input message (padding not handled), + /// the previous hash, a flag indicating if it's the last block, the global block index, the + /// local block index, and the buffer values that will be put in rows 0..4. + /// Will populate the given `trace` with the trace of the block, where the width of the trace is + /// `trace_width` and the starting column for the `Sha2Air` is `trace_start_col`. + /// **Note**: this function only generates some of the required trace. Another pass is required, + /// refer to [`Self::generate_missing_cells`] for details. + #[allow(clippy::too_many_arguments)] + pub fn generate_block_trace( + &self, + trace: &mut [F], + trace_width: usize, + trace_start_col: usize, + input: &[C::Word], + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + prev_hash: &[C::Word], + is_last_block: bool, + global_block_idx: u32, + local_block_idx: u32, + ) { + #[cfg(debug_assertions)] + { + assert!(input.len() == C::BLOCK_WORDS); + assert!(prev_hash.len() == C::HASH_WORDS); + assert!(trace.len() == trace_width * C::ROWS_PER_BLOCK); + assert!(trace_start_col + C::WIDTH <= trace_width); + if local_block_idx == 0 { + assert!(*prev_hash == *C::get_h()); + } + } + let get_range = |start: usize, len: usize| -> Range { start..start + len }; + let mut message_schedule = vec![C::Word::from(0); C::ROUNDS_PER_BLOCK]; + message_schedule[..input.len()].copy_from_slice(input); + let mut work_vars = prev_hash.to_vec(); + for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() { + // do the rounds + if i < C::ROUND_ROWS { + let mut cols: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut row[get_range(trace_start_col, C::ROUND_WIDTH)], + ); + *cols.flags.is_round_row = F::ONE; + *cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO }; + *cols.flags.is_digest_row = F::ZERO; + *cols.flags.is_last_block = F::from_bool(is_last_block); + cols.flags + .row_idx + .iter_mut() + .zip( + get_flag_pt_array(&self.row_idx_encoder, i) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + *cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); + *cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); + + // W_idx = M_idx + if i < C::MESSAGE_ROWS { + for j in 0..C::ROUNDS_PER_ROW { + cols.message_schedule + .w + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(input[i * C::ROUNDS_PER_ROW + j]) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } + // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} + else { + for j in 0..C::ROUNDS_PER_ROW { + let idx = i * C::ROUNDS_PER_ROW + j; + let nums: [C::Word; 4] = [ + small_sig1::(message_schedule[idx - 2]), + message_schedule[idx - 7], + small_sig0::(message_schedule[idx - 15]), + message_schedule[idx - 16], + ]; + let w: C::Word = nums + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + cols.message_schedule + .w + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(w) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + let nums_limbs = nums + .iter() + .map(|x| word_into_u16_limbs::(*x)) + .collect::>(); + let w_limbs = word_into_u16_limbs::(w); + + // fill in the carrys + for k in 0..C::WORD_U16S { + let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]); + if k > 0 { + sum += (cols.message_schedule.carry_or_buffer[[j, k * 2 - 2]] + + F::TWO + * cols.message_schedule.carry_or_buffer[[j, k * 2 - 1]]) + .as_canonical_u32(); + } + let carry = (sum - w_limbs[k]) >> 16; + cols.message_schedule.carry_or_buffer[[j, k * 2]] = + F::from_canonical_u32(carry & 1); + cols.message_schedule.carry_or_buffer[[j, k * 2 + 1]] = + F::from_canonical_u32(carry >> 1); + } + // update the message schedule + message_schedule[idx] = w; + } + } + // fill in the work variables + for j in 0..C::ROUNDS_PER_ROW { + // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx + let t1 = [ + work_vars[7], + big_sig1::(work_vars[4]), + ch::(work_vars[4], work_vars[5], work_vars[6]), + C::get_k()[i * C::ROUNDS_PER_ROW + j], + le_limbs_into_word::( + cols.message_schedule + .w + .row(j) + .map(|f| f.as_canonical_u32()) + .as_slice() + .unwrap(), + ), + ]; + let t1_sum: C::Word = t1 + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + + // t2 = SIG0(a) + maj(a, b, c) + let t2 = [ + big_sig0::(work_vars[0]), + maj::(work_vars[0], work_vars[1], work_vars[2]), + ]; + + let t2_sum: C::Word = t2 + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + + // e = d + t1 + let e = work_vars[3].wrapping_add(t1_sum); + cols.work_vars + .e + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(e) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + let e_limbs = word_into_u16_limbs::(e); + // a = t1 + t2 + let a = t1_sum.wrapping_add(t2_sum); + cols.work_vars + .a + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(a) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + let a_limbs = word_into_u16_limbs::(a); + // fill in the carrys + for k in 0..C::WORD_U16S { + let t1_limb = t1 + .iter() + .fold(0, |acc, &num| acc + word_into_u16_limbs::(num)[k]); + let t2_limb = t2 + .iter() + .fold(0, |acc, &num| acc + word_into_u16_limbs::(num)[k]); + + let mut e_limb = t1_limb + word_into_u16_limbs::(work_vars[3])[k]; + let mut a_limb = t1_limb + t2_limb; + if k > 0 { + a_limb += cols.work_vars.carry_a[[j, k - 1]].as_canonical_u32(); + e_limb += cols.work_vars.carry_e[[j, k - 1]].as_canonical_u32(); + } + let carry_a = (a_limb - a_limbs[k]) >> 16; + let carry_e = (e_limb - e_limbs[k]) >> 16; + cols.work_vars.carry_a[[j, k]] = F::from_canonical_u32(carry_a); + cols.work_vars.carry_e[[j, k]] = F::from_canonical_u32(carry_e); + bitwise_lookup_chip.request_range(carry_a, carry_e); + } + + // update working variables + work_vars[7] = work_vars[6]; + work_vars[6] = work_vars[5]; + work_vars[5] = work_vars[4]; + work_vars[4] = e; + work_vars[3] = work_vars[2]; + work_vars[2] = work_vars[1]; + work_vars[1] = work_vars[0]; + work_vars[0] = a; + } + + // filling w_3 and intermed_4 here and the rest later + if i > 0 { + for j in 0..C::ROUNDS_PER_ROW { + let idx = i * C::ROUNDS_PER_ROW + j; + let w_4 = word_into_u16_limbs::(message_schedule[idx - 4]); + let sig_0_w_3 = + word_into_u16_limbs::(small_sig0::(message_schedule[idx - 3])); + cols.schedule_helper + .intermed_4 + .row_mut(j) + .iter_mut() + .zip( + (0..C::WORD_U16S) + .map(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])) + .collect::>(), + ) + .for_each(|(x, y)| *x = y); + if j < C::ROUNDS_PER_ROW - 1 { + let w_3 = message_schedule[idx - 3]; + cols.schedule_helper + .w_3 + .row_mut(j) + .iter_mut() + .zip( + word_into_u16_limbs::(w_3) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } + } + } + // generate the digest row + else { + let mut cols: ShaDigestColsRefMut = ShaDigestColsRefMut::from::( + &mut row[get_range(trace_start_col, C::DIGEST_WIDTH)], + ); + for j in 0..C::ROUNDS_PER_ROW - 1 { + let w_3 = message_schedule[i * C::ROUNDS_PER_ROW + j - 3]; + cols.schedule_helper + .w_3 + .row_mut(j) + .iter_mut() + .zip( + word_into_u16_limbs::(w_3) + .into_iter() + .map(F::from_canonical_u32) + .collect::>(), + ) + .for_each(|(x, y)| *x = y); + } + *cols.flags.is_round_row = F::ZERO; + *cols.flags.is_first_4_rows = F::ZERO; + *cols.flags.is_digest_row = F::ONE; + *cols.flags.is_last_block = F::from_bool(is_last_block); + cols.flags + .row_idx + .iter_mut() + .zip( + get_flag_pt_array(&self.row_idx_encoder, C::ROUND_ROWS) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + *cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); + + *cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); + let final_hash: Vec = (0..C::HASH_WORDS) + .map(|i| work_vars[i].wrapping_add(prev_hash[i])) + .collect(); + let final_hash_limbs: Vec> = final_hash + .iter() + .map(|word| word_into_u8_limbs::(*word)) + .collect(); + // need to ensure final hash limbs are bytes, in order for + // prev_hash[i] + work_vars[i] == final_hash[i] + // to be constrained correctly + for word in final_hash_limbs.iter() { + for chunk in word.chunks(2) { + bitwise_lookup_chip.request_range(chunk[0], chunk[1]); + } + } + cols.final_hash + .iter_mut() + .zip((0..C::HASH_WORDS).flat_map(|i| { + word_into_u8_limbs::(final_hash[i]) + .into_iter() + .map(F::from_canonical_u32) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + cols.prev_hash + .iter_mut() + .zip(prev_hash.iter().flat_map(|f| { + word_into_u16_limbs::(*f) + .into_iter() + .map(F::from_canonical_u32) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + + let hash = if is_last_block { + C::get_h() + .iter() + .map(|x| word_into_bits::(*x)) + .collect::>() + } else { + cols.final_hash + .rows_mut() + .into_iter() + .map(|f| { + le_limbs_into_word::( + f.map(|x| x.as_canonical_u32()).as_slice().unwrap(), + ) + }) + .map(word_into_bits::) + .collect() + } + .into_iter() + .map(|x| x.into_iter().map(F::from_canonical_u32)) + .collect::>(); + + for i in 0..C::ROUNDS_PER_ROW { + cols.hash + .a + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i - 1].clone()) + .for_each(|(x, y)| *x = y); + cols.hash + .e + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i + 3].clone()) + .for_each(|(x, y)| *x = y); + } + } + } + + for i in 0..C::ROWS_PER_BLOCK - 1 { + let rows = &mut trace[i * trace_width..(i + 2) * trace_width]; + let (local, next) = rows.split_at_mut(trace_width); + let mut local_cols: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut local[get_range(trace_start_col, C::ROUND_WIDTH)], + ); + let mut next_cols: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut next[get_range(trace_start_col, C::ROUND_WIDTH)], + ); + if i > 0 { + for j in 0..C::ROUNDS_PER_ROW { + next_cols + .schedule_helper + .intermed_8 + .row_mut(j) + .assign(&local_cols.schedule_helper.intermed_4.row(j)); + if (2..C::ROWS_PER_BLOCK - 3).contains(&i) { + next_cols + .schedule_helper + .intermed_12 + .row_mut(j) + .assign(&local_cols.schedule_helper.intermed_8.row(j)); + } + } + } + if i == C::ROWS_PER_BLOCK - 2 { + // `next` is a digest row. + // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and + // `e` hold. + let const_local_cols = ShaRoundColsRef::::from_mut::(&local_cols); + Self::generate_carry_ae(const_local_cols.clone(), &mut next_cols); + // Fill in row 16's `intermed_4` with dummy values so the message schedule + // constraints holds on that row + Self::generate_intermed_4(const_local_cols, &mut next_cols); + } + if i <= 2 { + // i is in 0..3. + // Fill in `local.intermed_12` with dummy values so the message schedule constraints + // hold on rows 1..4. + Self::generate_intermed_12( + &mut local_cols, + ShaRoundColsRef::::from_mut::(&next_cols), + ); + } + } + } + + /// This function will fill in the cells that we couldn't do during the first pass. + /// This function should be called only after `generate_block_trace` was called for all blocks + /// And [`Self::generate_default_row`] is called for all invalid rows + /// Will populate the missing values of `trace`, where the width of the trace is `trace_width` + /// and the starting column for the `ShaAir` is `trace_start_col`. + /// Note: `trace` needs to be the rows 1..C::ROWS_PER_BLOCK of a block and the first row of the + /// next block + pub fn generate_missing_cells( + &self, + trace: &mut [F], + trace_width: usize, + trace_start_col: usize, + ) { + let rows = &mut trace[(C::ROUND_ROWS - 2) * trace_width..(C::ROUND_ROWS + 1) * trace_width]; + let (last_round_row, rows) = rows.split_at_mut(trace_width); + let (digest_row, next_block_first_row) = rows.split_at_mut(trace_width); + let mut cols_last_round_row: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut last_round_row[trace_start_col..trace_start_col + C::ROUND_WIDTH], + ); + let mut cols_digest_row: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut digest_row[trace_start_col..trace_start_col + C::ROUND_WIDTH], + ); + let mut cols_next_block_first_row: ShaRoundColsRefMut = ShaRoundColsRefMut::from::( + &mut next_block_first_row[trace_start_col..trace_start_col + C::ROUND_WIDTH], + ); + // Fill in the last round row's `intermed_12` with dummy values so the message schedule + // constraints holds on row 16 + Self::generate_intermed_12( + &mut cols_last_round_row, + ShaRoundColsRef::from_mut::(&cols_digest_row), + ); + // Fill in the digest row's `intermed_12` with dummy values so the message schedule + // constraints holds on the next block's row 0 + Self::generate_intermed_12( + &mut cols_digest_row, + ShaRoundColsRef::from_mut::(&cols_next_block_first_row), + ); + // Fill in the next block's first row's `intermed_4` with dummy values so the message + // schedule constraints holds on that row + Self::generate_intermed_4( + ShaRoundColsRef::from_mut::(&cols_digest_row), + &mut cols_next_block_first_row, + ); + } + + /// Fills the `cols` as a padding row + /// Note: we still need to correctly fill in the hash values, carries and intermeds + pub fn generate_default_row(&self, mut cols: ShaRoundColsRefMut) { + *cols.flags.is_round_row = F::ZERO; + *cols.flags.is_first_4_rows = F::ZERO; + *cols.flags.is_digest_row = F::ZERO; + + *cols.flags.is_last_block = F::ZERO; + *cols.flags.global_block_idx = F::ZERO; + cols.flags + .row_idx + .iter_mut() + .zip( + get_flag_pt_array(&self.row_idx_encoder, C::ROWS_PER_BLOCK) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + *cols.flags.local_block_idx = F::ZERO; + + cols.message_schedule + .w + .iter_mut() + .for_each(|x| *x = F::ZERO); + cols.message_schedule + .carry_or_buffer + .iter_mut() + .for_each(|x| *x = F::ZERO); + + let hash = C::get_h() + .iter() + .map(|x| word_into_bits::(*x)) + .map(|x| x.into_iter().map(F::from_canonical_u32).collect::>()) + .collect::>(); + + for i in 0..C::ROUNDS_PER_ROW { + cols.work_vars + .a + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i - 1].clone()) + .for_each(|(x, y)| *x = y); + cols.work_vars + .e + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i + 3].clone()) + .for_each(|(x, y)| *x = y); + } + + cols.work_vars + .carry_a + .iter_mut() + .zip((0..C::ROUNDS_PER_ROW).flat_map(|i| { + (0..C::WORD_U16S) + .map(|j| F::from_canonical_u32(C::get_invalid_carry_a(i)[j])) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + cols.work_vars + .carry_e + .iter_mut() + .zip((0..C::ROUNDS_PER_ROW).flat_map(|i| { + (0..C::WORD_U16S) + .map(|j| F::from_canonical_u32(C::get_invalid_carry_e(i)[j])) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + } + + /// The following functions do the calculations in native field since they will be called on + /// padding rows which can overflow and we need to make sure it matches the AIR constraints + /// Puts the correct carries in the `next_row`, the resulting carries can be out of bounds + pub fn generate_carry_ae( + local_cols: ShaRoundColsRef, + next_cols: &mut ShaRoundColsRefMut, + ) { + let a = [ + local_cols + .work_vars + .a + .rows() + .into_iter() + .collect::>(), + next_cols.work_vars.a.rows().into_iter().collect::>(), + ] + .concat(); + let e = [ + local_cols + .work_vars + .e + .rows() + .into_iter() + .collect::>(), + next_cols.work_vars.e.rows().into_iter().collect::>(), + ] + .concat(); + for i in 0..C::ROUNDS_PER_ROW { + let cur_a = a[i + 4]; + let sig_a = big_sig0_field::(a[i + 3].as_slice().unwrap()); + let maj_abc = maj_field::( + a[i + 3].as_slice().unwrap(), + a[i + 2].as_slice().unwrap(), + a[i + 1].as_slice().unwrap(), + ); + let d = a[i]; + let cur_e = e[i + 4]; + let sig_e = big_sig1_field::(e[i + 3].as_slice().unwrap()); + let ch_efg = ch_field::( + e[i + 3].as_slice().unwrap(), + e[i + 2].as_slice().unwrap(), + e[i + 1].as_slice().unwrap(), + ); + let h = e[i]; + + let t1 = [h.to_vec(), sig_e, ch_efg.to_vec()]; + let t2 = [sig_a, maj_abc]; + for j in 0..C::WORD_U16S { + let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| { + acc + compose::(&x[j * 16..(j + 1) * 16], 1) + }); + let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| { + acc + compose::(&x[j * 16..(j + 1) * 16], 1) + }); + let d_limb = compose::(&d.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let cur_a_limb = compose::(&cur_a.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let cur_e_limb = compose::(&cur_e.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let sum = d_limb + + t1_limb_sum + + if j == 0 { + F::ZERO + } else { + next_cols.work_vars.carry_e[[i, j - 1]] + } + - cur_e_limb; + let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse()); + + let sum = t1_limb_sum + + t2_limb_sum + + if j == 0 { + F::ZERO + } else { + next_cols.work_vars.carry_a[[i, j - 1]] + } + - cur_a_limb; + let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse()); + next_cols.work_vars.carry_e[[i, j]] = carry_e; + next_cols.work_vars.carry_a[[i, j]] = carry_a; + } + } + } + + /// Puts the correct intermed_4 in the `next_row` + fn generate_intermed_4( + local_cols: ShaRoundColsRef, + next_cols: &mut ShaRoundColsRefMut, + ) { + let w = [ + local_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + next_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + ] + .concat(); + let w_limbs: Vec> = w + .iter() + .map(|x| { + (0..C::WORD_U16S) + .map(|i| compose::(&x.as_slice().unwrap()[i * 16..(i + 1) * 16], 1)) + .collect::>() + }) + .collect(); + for i in 0..C::ROUNDS_PER_ROW { + let sig_w = small_sig0_field::(w[i + 1].as_slice().unwrap()); + let sig_w_limbs: Vec = (0..C::WORD_U16S) + .map(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)) + .collect(); + for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { + next_cols.schedule_helper.intermed_4[[i, j]] = w_limbs[i][j] + *sig_w_limb; + } + } + } + + /// Puts the needed intermed_12 in the `local_row` + fn generate_intermed_12( + local_cols: &mut ShaRoundColsRefMut, + next_cols: ShaRoundColsRef, + ) { + let w = [ + local_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + next_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + ] + .concat(); + let w_limbs: Vec> = w + .iter() + .map(|x| { + (0..C::WORD_U16S) + .map(|i| compose::(&x.as_slice().unwrap()[i * 16..(i + 1) * 16], 1)) + .collect::>() + }) + .collect(); + for i in 0..C::ROUNDS_PER_ROW { + // sig_1(w_{t-2}) + let sig_w_2: Vec = (0..C::WORD_U16S) + .map(|j| { + compose::( + &small_sig1_field::(w[i + 2].as_slice().unwrap()) + [j * 16..(j + 1) * 16], + 1, + ) + }) + .collect(); + // w_{t-7} + let w_7 = if i < 3 { + local_cols.schedule_helper.w_3.row(i).to_slice().unwrap() + } else { + w_limbs[i - 3].as_slice() + }; + // w_t + let w_cur = w_limbs[i + 4].as_slice(); + for j in 0..C::WORD_U16S { + let carry = next_cols.message_schedule.carry_or_buffer[[i, j * 2]] + + F::TWO * next_cols.message_schedule.carry_or_buffer[[i, j * 2 + 1]]; + let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j] + + if j > 0 { + next_cols.message_schedule.carry_or_buffer[[i, j * 2 - 2]] + + F::from_canonical_u32(2) + * next_cols.message_schedule.carry_or_buffer[[i, j * 2 - 1]] + } else { + F::ZERO + }; + local_cols.schedule_helper.intermed_12[[i, j]] = -sum; + } + } + } +} + +/// `records` consists of pairs of `(input_block, is_last_block)`. +pub fn generate_trace( + step: &Sha2StepHelper, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + width: usize, + records: Vec<(Vec, bool)>, +) -> RowMajorMatrix { + for (input, _) in &records { + debug_assert!(input.len() == C::BLOCK_U8S); + } + + let non_padded_height = records.len() * C::ROWS_PER_BLOCK; + let height = next_power_of_two_or_zero(non_padded_height); + let mut values = F::zero_vec(height * width); + + struct BlockContext { + prev_hash: Vec, // len is C::HASH_WORDS + local_block_idx: u32, + global_block_idx: u32, + input: Vec, // len is C::BLOCK_U8S + is_last_block: bool, + } + let mut block_ctx: Vec> = Vec::with_capacity(records.len()); + let mut prev_hash = C::get_h().to_vec(); + let mut local_block_idx = 0; + let mut global_block_idx = 1; + for (input, is_last_block) in records { + block_ctx.push(BlockContext { + prev_hash: prev_hash.clone(), + local_block_idx, + global_block_idx, + input: input.clone(), + is_last_block, + }); + global_block_idx += 1; + if is_last_block { + local_block_idx = 0; + prev_hash = C::get_h().to_vec(); + } else { + local_block_idx += 1; + prev_hash = Sha2StepHelper::::get_block_hash(&prev_hash, input); + } + } + // first pass + values + .par_chunks_exact_mut(width * C::ROWS_PER_BLOCK) + .zip(block_ctx) + .for_each(|(block, ctx)| { + let BlockContext { + prev_hash, + local_block_idx, + global_block_idx, + input, + is_last_block, + } = ctx; + let input_words = (0..C::BLOCK_WORDS) + .map(|i| { + le_limbs_into_word::( + &(0..C::WORD_U8S) + .map(|j| input[(i + 1) * C::WORD_U8S - j - 1] as u32) + .collect::>(), + ) + }) + .collect::>(); + step.generate_block_trace( + block, + width, + 0, + &input_words, + bitwise_lookup_chip.clone(), + &prev_hash, + is_last_block, + global_block_idx, + local_block_idx, + ); + }); + // second pass: padding rows + values[width * non_padded_height..] + .par_chunks_mut(width) + .for_each(|row| { + let cols: ShaRoundColsRefMut = ShaRoundColsRefMut::from::(row); + step.generate_default_row(cols); + }); + + // second pass: non-padding rows + values[width..] + .par_chunks_mut(width * C::ROWS_PER_BLOCK) + .take(non_padded_height / C::ROWS_PER_BLOCK) + .for_each(|chunk| { + step.generate_missing_cells(chunk, width, 0); + }); + RowMajorMatrix::new(values, width) +} diff --git a/crates/circuits/sha2-air/src/utils.rs b/crates/circuits/sha2-air/src/utils.rs new file mode 100644 index 0000000000..35d4446318 --- /dev/null +++ b/crates/circuits/sha2-air/src/utils.rs @@ -0,0 +1,289 @@ +pub use openvm_circuit_primitives::utils::compose; +use openvm_circuit_primitives::{ + encoder::Encoder, + utils::{not, select}, +}; +use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra}; +use rand::{rngs::StdRng, Rng}; + +use crate::{RotateRight, Sha2Config}; + +/// Convert a word into a list of 8-bit limbs in little endian +pub fn word_into_u8_limbs(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_U8S) +} + +/// Convert a word into a list of 16-bit limbs in little endian +pub fn word_into_u16_limbs(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_U16S) +} + +/// Convert a word into a list of 1-bit limbs in little endian +pub fn word_into_bits(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_BITS) +} + +/// Convert a word into a list of limbs in little endian +pub fn word_into_limbs(num: C::Word, num_limbs: usize) -> Vec { + let limb_bits = std::mem::size_of::() * 8 / num_limbs; + (0..num_limbs) + .map(|i| { + let shifted = num >> (limb_bits * i); + let mask: C::Word = ((1u32 << limb_bits) - 1).into(); + let masked = shifted & mask; + masked.try_into().unwrap() + }) + .collect() +} + +/// Convert a u32 into a list of 1-bit limbs in little endian +pub fn u32_into_bits(num: u32) -> Vec { + let limb_bits = 32 / C::WORD_BITS; + (0..C::WORD_BITS) + .map(|i| (num >> (limb_bits * i)) & ((1 << limb_bits) - 1)) + .collect() +} + +/// Convert a list of limbs in little endian into a Word +pub fn le_limbs_into_word(limbs: &[u32]) -> C::Word { + let mut limbs = limbs.to_vec(); + limbs.reverse(); + be_limbs_into_word::(&limbs) +} + +/// Convert a list of limbs in big endian into a Word +pub fn be_limbs_into_word(limbs: &[u32]) -> C::Word { + let limb_bits = C::WORD_BITS / limbs.len(); + limbs.iter().fold(C::Word::from(0), |acc, &limb| { + (acc << limb_bits) | limb.into() + }) +} + +/// Convert a list of limbs in little endian into a u32 +pub fn limbs_into_u32(limbs: &[u32]) -> u32 { + let limb_bits = 32 / limbs.len(); + limbs + .iter() + .rev() + .fold(0, |acc, &limb| (acc << limb_bits) | limb) +} + +/// Rotates `bits` right by `n` bits, assumes `bits` is in little-endian +#[inline] +pub(crate) fn rotr(bits: &[impl Into + Clone], n: usize) -> Vec { + (0..bits.len()) + .map(|i| bits[(i + n) % bits.len()].clone().into()) + .collect() +} + +/// Shifts `bits` right by `n` bits, assumes `bits` is in little-endian +#[inline] +pub(crate) fn shr(bits: &[impl Into + Clone], n: usize) -> Vec { + (0..bits.len()) + .map(|i| { + if i + n < bits.len() { + bits[i + n].clone().into() + } else { + F::ZERO + } + }) + .collect() +} + +/// Computes x ^ y ^ z, where x, y, z are assumed to be boolean +#[inline] +pub(crate) fn xor_bit( + x: impl Into, + y: impl Into, + z: impl Into, +) -> F { + let (x, y, z) = (x.into(), y.into(), z.into()); + (x.clone() * y.clone() * z.clone()) + + (x.clone() * not::(y.clone()) * not::(z.clone())) + + (not::(x.clone()) * y.clone() * not::(z.clone())) + + (not::(x) * not::(y) * z) +} + +/// Computes x ^ y ^ z, where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn xor( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| xor_bit(x[i].clone(), y[i].clone(), z[i].clone())) + .collect() +} + +/// Choose function from the SHA spec +#[inline] +pub fn ch(x: C::Word, y: C::Word, z: C::Word) -> C::Word { + (x & y) ^ ((!x) & z) +} + +/// Computes Ch(x,y,z), where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn ch_field( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| select(x[i].clone(), y[i].clone(), z[i].clone())) + .collect() +} + +/// Majority function from the SHA spec +pub fn maj(x: C::Word, y: C::Word, z: C::Word) -> C::Word { + (x & y) ^ (x & z) ^ (y & z) +} + +/// Computes Maj(x,y,z), where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn maj_field( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| { + let (x, y, z) = ( + x[i].clone().into(), + y[i].clone().into(), + z[i].clone().into(), + ); + x.clone() * y.clone() + x.clone() * z.clone() + y.clone() * z.clone() + - F::TWO * x * y * z + }) + .collect() +} + +/// Big sigma_0 function from the SHA spec +pub fn big_sig0(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) + } else { + x.rotate_right(28) ^ x.rotate_right(34) ^ x.rotate_right(39) + } +} + +/// Computes BigSigma0(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn big_sig0_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 2), &rotr::(x, 13), &rotr::(x, 22)) + } else { + xor(&rotr::(x, 28), &rotr::(x, 34), &rotr::(x, 39)) + } +} + +/// Big sigma_1 function from the SHA spec +pub fn big_sig1(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) + } else { + x.rotate_right(14) ^ x.rotate_right(18) ^ x.rotate_right(41) + } +} + +/// Computes BigSigma1(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn big_sig1_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 6), &rotr::(x, 11), &rotr::(x, 25)) + } else { + xor(&rotr::(x, 14), &rotr::(x, 18), &rotr::(x, 41)) + } +} + +/// Small sigma_0 function from the SHA spec +pub fn small_sig0(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) + } else { + x.rotate_right(1) ^ x.rotate_right(8) ^ (x >> 7) + } +} + +/// Computes SmallSigma0(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn small_sig0_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 7), &rotr::(x, 18), &shr::(x, 3)) + } else { + xor(&rotr::(x, 1), &rotr::(x, 8), &shr::(x, 7)) + } +} + +/// Small sigma_1 function from the SHA spec +pub fn small_sig1(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) + } else { + x.rotate_right(19) ^ x.rotate_right(61) ^ (x >> 6) + } +} + +/// Computes SmallSigma1(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn small_sig1_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 17), &rotr::(x, 19), &shr::(x, 10)) + } else { + xor(&rotr::(x, 19), &rotr::(x, 61), &shr::(x, 6)) + } +} + +/// Generate a random message of a given length +pub fn get_random_message(rng: &mut StdRng, len: usize) -> Vec { + let mut random_message: Vec = vec![0u8; len]; + rng.fill(&mut random_message[..]); + random_message +} + +/// Wrapper of `get_flag_pt` to get the flag pointer as an array +pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> Vec { + encoder.get_flag_pt(flag_idx) +} + +/// Constrain the addition of [C::WORD_BITS] bit words in 16-bit limbs +/// It takes in the terms some in bits some in 16-bit limbs, +/// the expected sum in bits and the carries +pub fn constraint_word_addition( + builder: &mut AB, + terms_bits: &[&[impl Into + Clone]], + terms_limb: &[&[impl Into + Clone]], + expected_sum: &[impl Into + Clone], + carries: &[impl Into + Clone], +) { + debug_assert!(terms_bits.iter().all(|x| x.len() == C::WORD_BITS)); + debug_assert!(terms_limb.iter().all(|x| x.len() == C::WORD_U16S)); + assert_eq!(expected_sum.len(), C::WORD_BITS); + assert_eq!(carries.len(), C::WORD_U16S); + + for i in 0..C::WORD_U16S { + let mut limb_sum = if i == 0 { + AB::Expr::ZERO + } else { + carries[i - 1].clone().into() + }; + for term in terms_bits { + limb_sum += compose::(&term[i * 16..(i + 1) * 16], 1); + } + for term in terms_limb { + limb_sum += term[i].clone().into(); + } + let expected_sum_limb = compose::(&expected_sum[i * 16..(i + 1) * 16], 1) + + carries[i].clone().into() * AB::Expr::from_canonical_u32(1 << 16); + builder.assert_eq(limb_sum, expected_sum_limb); + } +} diff --git a/crates/circuits/sha256-air/src/air.rs b/crates/circuits/sha256-air/src/air.rs deleted file mode 100644 index 96578984d0..0000000000 --- a/crates/circuits/sha256-air/src/air.rs +++ /dev/null @@ -1,613 +0,0 @@ -use std::{array, borrow::Borrow, cmp::max, iter::once}; - -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupBus, - encoder::Encoder, - utils::{not, select}, - SubAir, -}; -use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, - p3_air::{AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra}, - p3_matrix::Matrix, -}; - -use super::{ - big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, - small_sig1_field, u32_into_limbs, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH, - SHA256_H, SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, - SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S, -}; -use crate::constraint_word_addition; - -/// Expects the message to be padded to a multiple of 512 bits -#[derive(Clone, Debug)] -pub struct Sha256Air { - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - pub row_idx_encoder: Encoder, - /// Internal bus for self-interactions in this AIR. - bus: PermutationCheckBus, -} - -impl Sha256Air { - pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self { - Self { - bitwise_lookup_bus, - row_idx_encoder: Encoder::new(18, 2, false), - bus: PermutationCheckBus::new(self_bus_idx), - } - } -} - -impl BaseAir for Sha256Air { - fn width(&self) -> usize { - max( - Sha256RoundCols::::width(), - Sha256DigestCols::::width(), - ) - } -} - -impl SubAir for Sha256Air { - /// The start column for the sub-air to use - type AirContext<'a> - = usize - where - Self: 'a, - AB: 'a, - ::Var: 'a, - ::Expr: 'a; - - fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>) - where - ::Var: 'a, - ::Expr: 'a, - { - self.eval_row(builder, start_col); - self.eval_transitions(builder, start_col); - } -} - -impl Sha256Air { - /// Implements the single row constraints (i.e. imposes constraints only on local) - /// Implements some sanity constraints on the row index, flags, and work variables - fn eval_row(&self, builder: &mut AB, start_col: usize) { - let main = builder.main(); - let local = main.row_slice(0); - - // Doesn't matter which column struct we use here as we are only interested in the common - // columns - let local_cols: &Sha256DigestCols = - local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - let flags = &local_cols.flags; - builder.assert_bool(flags.is_round_row); - builder.assert_bool(flags.is_first_4_rows); - builder.assert_bool(flags.is_digest_row); - builder.assert_bool(flags.is_round_row + flags.is_digest_row); - builder.assert_bool(flags.is_last_block); - - self.row_idx_encoder - .eval(builder, &local_cols.flags.row_idx); - builder.assert_one( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=17), - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=3), - flags.is_first_4_rows, - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=15), - flags.is_round_row, - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag::(&local_cols.flags.row_idx, &[16]), - flags.is_digest_row, - ); - // If padding row we want the row_idx to be 17 - builder.assert_eq( - self.row_idx_encoder - .contains_flag::(&local_cols.flags.row_idx, &[17]), - flags.is_padding_row(), - ); - - // Constrain a, e, being composed of bits: we make sure a and e are always in the same place - // in the trace matrix Note: this has to be true for every row, even padding rows - for i in 0..SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_WORD_BITS { - builder.assert_bool(local_cols.hash.a[i][j]); - builder.assert_bool(local_cols.hash.e[i][j]); - } - } - } - - /// Implements constraints for a digest row that ensure proper state transitions between blocks - /// This validates that: - /// The work variables are correctly initialized for the next message block - /// For the last message block, the initial state matches SHA256_H constants - fn eval_digest_row( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256DigestCols, - ) { - // Check that if this is the last row of a message or an inpadding row, the hash should be - // the [SHA256_H] - for i in 0..SHA256_ROUNDS_PER_ROW { - let a = next.hash.a[i].map(|x| x.into()); - let e = next.hash.e[i].map(|x| x.into()); - for j in 0..SHA256_WORD_U16S { - let a_limb = compose::(&a[j * 16..(j + 1) * 16], 1); - let e_limb = compose::(&e[j * 16..(j + 1) * 16], 1); - - // If it is a padding row or the last row of a message, the `hash` should be the - // [SHA256_H] - builder - .when( - next.flags.is_padding_row() - + next.flags.is_last_block * next.flags.is_digest_row, - ) - .assert_eq( - a_limb, - AB::Expr::from_canonical_u32( - u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j], - ), - ); - - builder - .when( - next.flags.is_padding_row() - + next.flags.is_last_block * next.flags.is_digest_row, - ) - .assert_eq( - e_limb, - AB::Expr::from_canonical_u32( - u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j], - ), - ); - } - } - - // Check if last row of a non-last block, the `hash` should be equal to the final hash of - // the current block - for i in 0..SHA256_ROUNDS_PER_ROW { - let prev_a = next.hash.a[i].map(|x| x.into()); - let prev_e = next.hash.e[i].map(|x| x.into()); - let cur_a = next.final_hash[SHA256_ROUNDS_PER_ROW - i - 1].map(|x| x.into()); - - let cur_e = next.final_hash[SHA256_ROUNDS_PER_ROW - i + 3].map(|x| x.into()); - for j in 0..SHA256_WORD_U8S { - let prev_a_limb = compose::(&prev_a[j * 8..(j + 1) * 8], 1); - let prev_e_limb = compose::(&prev_e[j * 8..(j + 1) * 8], 1); - - builder - .when(not(next.flags.is_last_block) * next.flags.is_digest_row) - .assert_eq(prev_a_limb, cur_a[j].clone()); - - builder - .when(not(next.flags.is_last_block) * next.flags.is_digest_row) - .assert_eq(prev_e_limb, cur_e[j].clone()); - } - } - - // Assert that the previous hash + work vars == final hash. - // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]` - // where addition is done modulo 2^32 - for i in 0..SHA256_HASH_WORDS { - let mut carry = AB::Expr::ZERO; - for j in 0..SHA256_WORD_U16S { - let work_var_limb = if i < SHA256_ROUNDS_PER_ROW { - compose::( - &local.work_vars.a[SHA256_ROUNDS_PER_ROW - 1 - i][j * 16..(j + 1) * 16], - 1, - ) - } else { - compose::( - &local.work_vars.e[SHA256_ROUNDS_PER_ROW + 3 - i][j * 16..(j + 1) * 16], - 1, - ) - }; - let final_hash_limb = - compose::(&next.final_hash[i][j * 2..(j + 1) * 2], 8); - - carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse()) - * (next.prev_hash[i][j] + work_var_limb + carry - final_hash_limb); - builder - .when(next.flags.is_digest_row) - .assert_bool(carry.clone()); - } - // constrain the final hash limbs two at a time since we can do two checks per - // interaction - for chunk in next.final_hash[i].chunks(2) { - self.bitwise_lookup_bus - .send_range(chunk[0], chunk[1]) - .eval(builder, next.flags.is_digest_row); - } - } - } - - fn eval_transitions(&self, builder: &mut AB, start_col: usize) { - let main = builder.main(); - let local = main.row_slice(0); - let next = main.row_slice(1); - - // Doesn't matter what column structs we use here - let local_cols: &Sha256RoundCols = - local[start_col..start_col + SHA256_ROUND_WIDTH].borrow(); - let next_cols: &Sha256RoundCols = - next[start_col..start_col + SHA256_ROUND_WIDTH].borrow(); - - let local_is_padding_row = local_cols.flags.is_padding_row(); - // Note that there will always be a padding row in the trace since the unpadded height is a - // multiple of 17. So the next row is padding iff the current block is the last - // block in the trace. - let next_is_padding_row = next_cols.flags.is_padding_row(); - - // We check that the very last block has `is_last_block` set to true, which guarantees that - // there is at least one complete message. If other digest rows have `is_last_block` set to - // true, then the trace will be interpreted as containing multiple messages. - builder - .when(next_is_padding_row.clone()) - .when(local_cols.flags.is_digest_row) - .assert_one(local_cols.flags.is_last_block); - // If we are in a round row, the next row cannot be a padding row - builder - .when(local_cols.flags.is_round_row) - .assert_zero(next_is_padding_row.clone()); - // The first row must be a round row - builder - .when_first_row() - .assert_one(local_cols.flags.is_round_row); - // If we are in a padding row, the next row must also be a padding row - builder - .when_transition() - .when(local_is_padding_row.clone()) - .assert_one(next_is_padding_row.clone()); - // If we are in a digest row, the next row cannot be a digest row - builder - .when(local_cols.flags.is_digest_row) - .assert_zero(next_cols.flags.is_digest_row); - // Constrain how much the row index changes by - // round->round: 1 - // round->digest: 1 - // digest->round: -16 - // digest->padding: 1 - // padding->padding: 0 - // Other transitions are not allowed by the above constraints - let delta = local_cols.flags.is_round_row * AB::Expr::ONE - + local_cols.flags.is_digest_row - * next_cols.flags.is_round_row - * AB::Expr::from_canonical_u32(16) - * AB::Expr::NEG_ONE - + local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE; - - let local_row_idx = self.row_idx_encoder.flag_with_val::( - &local_cols.flags.row_idx, - &(0..18).map(|i| (i, i)).collect::>(), - ); - let next_row_idx = self.row_idx_encoder.flag_with_val::( - &next_cols.flags.row_idx, - &(0..18).map(|i| (i, i)).collect::>(), - ); - - builder - .when_transition() - .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone()); - builder.when_first_row().assert_zero(local_row_idx); - - // Constrain the global block index - // We set the global block index to 0 for padding rows - // Starting with 1 so it is not the same as the padding rows - - // Global block index is 1 on first row - builder - .when_first_row() - .assert_one(local_cols.flags.global_block_idx); - - // Global block index is constant on all rows in a block - builder.when(local_cols.flags.is_round_row).assert_eq( - local_cols.flags.global_block_idx, - next_cols.flags.global_block_idx, - ); - // Global block index increases by 1 between blocks - builder - .when_transition() - .when(local_cols.flags.is_digest_row) - .when(next_cols.flags.is_round_row) - .assert_eq( - local_cols.flags.global_block_idx + AB::Expr::ONE, - next_cols.flags.global_block_idx, - ); - // Global block index is 0 on padding rows - builder - .when(local_is_padding_row.clone()) - .assert_zero(local_cols.flags.global_block_idx); - - // Constrain the local block index - // We set the local block index to 0 for padding rows - - // Local block index is constant on all rows in a block - // and its value on padding rows is equal to its value on the first block - builder.when(not(local_cols.flags.is_digest_row)).assert_eq( - local_cols.flags.local_block_idx, - next_cols.flags.local_block_idx, - ); - // Local block index increases by 1 between blocks in the same message - builder - .when(local_cols.flags.is_digest_row) - .when(not(local_cols.flags.is_last_block)) - .assert_eq( - local_cols.flags.local_block_idx + AB::Expr::ONE, - next_cols.flags.local_block_idx, - ); - // Local block index is 0 on padding rows - // Combined with the above, this means that the local block index is 0 in the first block - builder - .when(local_cols.flags.is_digest_row) - .when(local_cols.flags.is_last_block) - .assert_zero(next_cols.flags.local_block_idx); - - self.eval_message_schedule::(builder, local_cols, next_cols); - self.eval_work_vars::(builder, local_cols, next_cols); - let next_cols: &Sha256DigestCols = - next[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - self.eval_digest_row(builder, local_cols, next_cols); - let local_cols: &Sha256DigestCols = - local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - self.eval_prev_hash::(builder, local_cols, next_is_padding_row); - } - - /// Constrains that the next block's `prev_hash` is equal to the current block's `hash` - /// Note: the constraining is done by interactions with the chip itself on every digest row - fn eval_prev_hash( - &self, - builder: &mut AB, - local: &Sha256DigestCols, - is_last_block_of_trace: AB::Expr, /* note this indicates the last block of the trace, - * not the last block of the message */ - ) { - // Constrain that next block's `prev_hash` is equal to the current block's `hash` - let composed_hash: [[::Expr; SHA256_WORD_U16S]; SHA256_HASH_WORDS] = - array::from_fn(|i| { - let hash_bits = if i < SHA256_ROUNDS_PER_ROW { - local.hash.a[SHA256_ROUNDS_PER_ROW - 1 - i].map(|x| x.into()) - } else { - local.hash.e[SHA256_ROUNDS_PER_ROW + 3 - i].map(|x| x.into()) - }; - array::from_fn(|j| compose::(&hash_bits[j * 16..(j + 1) * 16], 1)) - }); - // Need to handle the case if this is the very last block of the trace matrix - let next_global_block_idx = select( - is_last_block_of_trace, - AB::Expr::ONE, - local.flags.global_block_idx + AB::Expr::ONE, - ); - // The following interactions constrain certain values from block to block - self.bus.send( - builder, - composed_hash - .into_iter() - .flatten() - .chain(once(next_global_block_idx)), - local.flags.is_digest_row, - ); - - self.bus.receive( - builder, - local - .prev_hash - .into_iter() - .flatten() - .map(|x| x.into()) - .chain(once(local.flags.global_block_idx.into())), - local.flags.is_digest_row, - ); - } - - /// Constrain the message schedule additions for `next` row - /// Note: For every addition we need to constrain the following for each of [SHA256_WORD_U16S] - /// limbs sig_1(w_{t-2})[i] + w_{t-7}[i] + sig_0(w_{t-15})[i] + w_{t-16}[i] + - /// carry_w[t][i-1] - carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] - fn eval_message_schedule( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256RoundCols, - ) { - // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx - let w = [local.message_schedule.w, next.message_schedule.w].concat(); - - // Constrain `w_3` for `next` row - for i in 0..SHA256_ROUNDS_PER_ROW - 1 { - // here we constrain the w_3 of the i_th word of the next row - // w_3 of next is w[i+4-3] = w[i+1] - let w_3 = w[i + 1].map(|x| x.into()); - let expected_w_3 = next.schedule_helper.w_3[i]; - for j in 0..SHA256_WORD_U16S { - let w_3_limb = compose::(&w_3[j * 16..(j + 1) * 16], 1); - builder - .when(local.flags.is_round_row) - .assert_eq(w_3_limb, expected_w_3[j].into()); - } - } - - // Constrain intermed for `next` row - // We will only constrain intermed_12 for rows [3, 14], and let it be unconstrained for - // other rows Other rows should put the needed value in intermed_12 to make the - // below summation constraint hold - let is_row_3_14 = self - .row_idx_encoder - .contains_flag_range::(&next.flags.row_idx, 3..=14); - // We will only constrain intermed_8 for rows [2, 13], and let it unconstrained for other - // rows - let is_row_2_13 = self - .row_idx_encoder - .contains_flag_range::(&next.flags.row_idx, 2..=13); - for i in 0..SHA256_ROUNDS_PER_ROW { - // w_idx - let w_idx = w[i].map(|x| x.into()); - // sig_0(w_{idx+1}) - let sig_w = small_sig0_field::(&w[i + 1]); - for j in 0..SHA256_WORD_U16S { - let w_idx_limb = compose::(&w_idx[j * 16..(j + 1) * 16], 1); - let sig_w_limb = compose::(&sig_w[j * 16..(j + 1) * 16], 1); - - // We would like to constrain this only on rows 0..16, but we can't do a conditional - // check because the degree is already 3. So we must fill in - // `intermed_4` with dummy values on rows 0 and 16 to ensure the constraint holds on - // these rows. - builder.when_transition().assert_eq( - next.schedule_helper.intermed_4[i][j], - w_idx_limb + sig_w_limb, - ); - - builder.when(is_row_2_13.clone()).assert_eq( - next.schedule_helper.intermed_8[i][j], - local.schedule_helper.intermed_4[i][j], - ); - - builder.when(is_row_3_14.clone()).assert_eq( - next.schedule_helper.intermed_12[i][j], - local.schedule_helper.intermed_8[i][j], - ); - } - } - - // Constrain the message schedule additions for `next` row - for i in 0..SHA256_ROUNDS_PER_ROW { - // Note, here by w_{t} we mean the i_th word of the `next` row - // w_{t-7} - let w_7 = if i < 3 { - local.schedule_helper.w_3[i].map(|x| x.into()) - } else { - let w_3 = w[i - 3].map(|x| x.into()); - array::from_fn(|j| compose::(&w_3[j * 16..(j + 1) * 16], 1)) - }; - // sig_0(w_{t-15}) + w_{t-16} - let intermed_16 = local.schedule_helper.intermed_12[i].map(|x| x.into()); - - let carries = array::from_fn(|j| { - next.message_schedule.carry_or_buffer[i][j * 2] - + AB::Expr::TWO * next.message_schedule.carry_or_buffer[i][j * 2 + 1] - }); - - // Constrain `W_{idx} = sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}` - // We would like to constrain this only on rows 4..16, but we can't do a conditional - // check because the degree of sum is already 3 So we must fill in - // `intermed_12` with dummy values on rows 0..3 and 15 and 16 to ensure the constraint - // holds on rows 0..4 and 16. Note that the dummy value goes in the previous - // row to make the current row's constraint hold. - constraint_word_addition( - // Note: here we can't do a conditional check because the degree of sum is already - // 3 - &mut builder.when_transition(), - &[&small_sig1_field::(&w[i + 2])], - &[&w_7, &intermed_16], - &w[i + 4], - &carries, - ); - - for j in 0..SHA256_WORD_U16S { - // When on rows 4..16 message schedule carries should be 0 or 1 - let is_row_4_15 = next.flags.is_round_row - next.flags.is_first_4_rows; - builder - .when(is_row_4_15.clone()) - .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2]); - builder - .when(is_row_4_15) - .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2 + 1]); - } - // Constrain w being composed of bits - for j in 0..SHA256_WORD_BITS { - builder - .when(next.flags.is_round_row) - .assert_bool(next.message_schedule.w[i][j]); - } - } - } - - /// Constrain the work vars on `next` row according to the sha256 documentation - /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] - fn eval_work_vars( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256RoundCols, - ) { - let a = [local.work_vars.a, next.work_vars.a].concat(); - let e = [local.work_vars.e, next.work_vars.e].concat(); - for i in 0..SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_WORD_U16S { - // Although we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in - // [0, 2^8) is enough to prevent overflow and ensure the soundness - // of the addition we want to check - self.bitwise_lookup_bus - .send_range(local.work_vars.carry_a[i][j], local.work_vars.carry_e[i][j]) - .eval(builder, local.flags.is_round_row); - } - - let w_limbs = array::from_fn(|j| { - compose::(&next.message_schedule.w[i][j * 16..(j + 1) * 16], 1) - * next.flags.is_round_row - }); - let k_limbs = array::from_fn(|j| { - self.row_idx_encoder.flag_with_val::( - &next.flags.row_idx, - &(0..16) - .map(|rw_idx| { - ( - rw_idx, - u32_into_limbs::( - SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i], - )[j] as usize, - ) - }) - .collect::>(), - ) - }); - - // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)` - // We have to enforce this constraint on all rows since the degree of the constraint is - // already 3. So, we must fill in `carry_a` with dummy values on digest rows - // to ensure the constraint holds. - constraint_word_addition( - builder, - &[ - &e[i].map(|x| x.into()), // previous `h` - &big_sig1_field::(&e[i + 3]), // sig_1 of previous `e` - &ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous - * `e`, `f`, `g` */ - &big_sig0_field::(&a[i + 3]), // sig_0 of previous `a` - &maj_field::(&a[i + 3], &a[i + 2], &a[i + 1]), /* Maj of previous - * a, b, c */ - ], - &[&w_limbs, &k_limbs], // K and W - &a[i + 4], // new `a` - &next.work_vars.carry_a[i], // carries of addition - ); - - // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W` - // We have to enforce this constraint on all rows since the degree of the constraint is - // already 3. So, we must fill in `carry_e` with dummy values on digest rows - // to ensure the constraint holds. - constraint_word_addition( - builder, - &[ - &a[i].map(|x| x.into()), // previous `d` - &e[i].map(|x| x.into()), // previous `h` - &big_sig1_field::(&e[i + 3]), /* sig_1 of previous - * `e` */ - &ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous - * `e`, `f`, `g` */ - ], - &[&w_limbs, &k_limbs], // K and W - &e[i + 4], // new `e` - &next.work_vars.carry_e[i], // carries of addition - ); - } - } -} diff --git a/crates/circuits/sha256-air/src/columns.rs b/crates/circuits/sha256-air/src/columns.rs deleted file mode 100644 index 1c735394c3..0000000000 --- a/crates/circuits/sha256-air/src/columns.rs +++ /dev/null @@ -1,140 +0,0 @@ -//! WARNING: the order of fields in the structs is important, do not change it - -use openvm_circuit_primitives::{utils::not, AlignedBorrow}; -use openvm_stark_backend::p3_field::FieldAlgebra; - -use super::{ - SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROW_VAR_CNT, SHA256_WORD_BITS, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -/// In each SHA256 block: -/// - First 16 rows use Sha256RoundCols -/// - Final row uses Sha256DigestCols -/// -/// Note that for soundness, we require that there is always a padding row after the last digest row -/// in the trace. Right now, this is true because the unpadded height is a multiple of 17, and thus -/// not a power of 2. -/// -/// Sha256RoundCols and Sha256DigestCols share the same first 3 fields: -/// - flags -/// - work_vars/hash (same type, different name) -/// - schedule_helper -/// -/// This design allows for: -/// 1. Common constraints to work on either struct type by accessing these shared fields -/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional -/// constraints -/// -/// Note that the `Sha256WorkVarsCols` field it is used for different purposes in the two structs. -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256RoundCols { - pub flags: Sha256FlagsCols, - /// Stores the current state of the working variables - pub work_vars: Sha256WorkVarsCols, - pub schedule_helper: Sha256MessageHelperCols, - pub message_schedule: Sha256MessageScheduleCols, -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256DigestCols { - pub flags: Sha256FlagsCols, - /// Will serve as previous hash values for the next block. - /// - on non-last blocks, this is the final hash of the current block - /// - on last blocks, this is the initial state constants, SHA256_H. - /// The work variables constraints are applied on all rows, so `carry_a` and `carry_e` - /// must be filled in with dummy values to ensure these constraints hold. - pub hash: Sha256WorkVarsCols, - pub schedule_helper: Sha256MessageHelperCols, - /// The actual final hash values of the given block - /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block - pub final_hash: [[T; SHA256_WORD_U8S]; SHA256_HASH_WORDS], - /// The final hash of the previous block - /// Note: will be constrained using interactions with the chip itself - pub prev_hash: [[T; SHA256_WORD_U16S]; SHA256_HASH_WORDS], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256MessageScheduleCols { - /// The message schedule words as 32-bit integers - /// The first 16 words will be the message data - pub w: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - /// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used - /// freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells as - /// individual bits - pub carry_or_buffer: [[T; SHA256_WORD_U8S]; SHA256_ROUNDS_PER_ROW], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256WorkVarsCols { - /// `a` and `e` after each iteration as 32-bits - pub a: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - pub e: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - /// The carry's used for addition during each iteration when computing `a` and `e` - pub carry_a: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub carry_e: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], -} - -/// These are the columns that are used to help with the message schedule additions -/// Note: these need to be correctly assigned for every row even on padding rows -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256MessageHelperCols { - /// The following are used to move data forward to constrain the message schedule additions - /// The value of `w` (message schedule word) from 3 rounds ago - /// In general, `w_i` means `w` from `i` rounds ago - pub w_3: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW - 1], - /// Here intermediate(i) = w_i + sig_0(w_{i+1}) - /// Intermed_t represents the intermediate t rounds ago - /// This is needed to constrain the message schedule, since we can only constrain on two rows - /// at a time - pub intermed_4: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub intermed_8: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub intermed_12: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256FlagsCols { - /// A flag that indicates if the current row is among the first 16 rows of a block. - pub is_round_row: T, - /// A flag that indicates if the current row is among the first 4 rows of a block. - pub is_first_4_rows: T, - /// A flag that indicates if the current row is the last (17th) row of a block. - pub is_digest_row: T, - // A flag that indicates if the current row is the last block of the message. - // This flag is only used in digest rows. - pub is_last_block: T, - /// We will encode the row index [0..17) using 5 cells - pub row_idx: [T; SHA256_ROW_VAR_CNT], - /// The index of the current block in the trace starting at 1. - /// Set to 0 on padding rows. - pub global_block_idx: T, - /// The index of the current block in the current message starting at 0. - /// Resets after every message. - /// Set to 0 on padding rows. - pub local_block_idx: T, -} - -impl> Sha256FlagsCols { - // This refers to the padding rows that are added to the air to make the trace length a power of - // 2. Not to be confused with the padding added to messages as part of the SHA hash - // function. - pub fn is_not_padding_row(&self) -> O { - self.is_round_row + self.is_digest_row - } - - // This refers to the padding rows that are added to the air to make the trace length a power of - // 2. Not to be confused with the padding added to messages as part of the SHA hash - // function. - pub fn is_padding_row(&self) -> O - where - O: FieldAlgebra, - { - not(self.is_not_padding_row()) - } -} diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha256-air/src/tests.rs deleted file mode 100644 index 903b7b0695..0000000000 --- a/crates/circuits/sha256-air/src/tests.rs +++ /dev/null @@ -1,233 +0,0 @@ -use std::{array, borrow::BorrowMut, cmp::max, sync::Arc}; - -use openvm_circuit::arch::{ - instructions::riscv::RV32_CELL_BITS, - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - SubAir, -}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - interaction::{BusIndex, InteractionBuilder}, - p3_air::{Air, BaseAir}, - p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut}, - prover::types::AirProofInput, - rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, -}; -use openvm_stark_sdk::utils::create_seeded_rng; -use rand::Rng; - -use crate::{ - compose, small_sig0_field, Sha256Air, Sha256RoundCols, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -// A wrapper AIR purely for testing purposes -#[derive(Clone, Debug)] -pub struct Sha256TestAir { - pub sub_air: Sha256Air, -} - -impl BaseAirWithPublicValues for Sha256TestAir {} -impl PartitionedBaseAir for Sha256TestAir {} -impl BaseAir for Sha256TestAir { - fn width(&self) -> usize { - >::width(&self.sub_air) - } -} - -impl Air for Sha256TestAir { - fn eval(&self, builder: &mut AB) { - self.sub_air.eval(builder, 0); - } -} - -// A wrapper Chip purely for testing purposes -pub struct Sha256TestChip { - pub air: Sha256TestAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, -} - -impl Chip for Sha256TestChip -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let trace = crate::generate_trace::>( - &self.air.sub_air, - self.bitwise_lookup_chip.clone(), - self.records, - ); - AirProofInput::simple_no_pis(trace) - } -} - -impl ChipUsageGetter for Sha256TestChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.records.len() * SHA256_ROWS_PER_BLOCK - } - - fn trace_width(&self) -> usize { - max(SHA256_ROUND_WIDTH, SHA256_DIGEST_WIDTH) - } -} - -const SELF_BUS_IDX: BusIndex = 28; -#[test] -fn rand_sha256_test() { - let mut rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let len = rng.gen_range(1..100); - let random_records: Vec<_> = (0..len) - .map(|i| { - ( - array::from_fn(|_| rng.gen::()), - rng.gen::() || i == len - 1, - ) - }) - .collect(); - let chip = Sha256TestChip { - air: Sha256TestAir { - sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), - }, - bitwise_lookup_chip: bitwise_chip.clone(), - records: random_records, - }; - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -// A wrapper Chip to test that the final_hash is properly constrained. -// This chip implements a malicious trace gen that violates the final_hash constraints. -pub struct Sha256TestBadFinalHashChip { - pub air: Sha256TestAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, -} - -impl Chip for Sha256TestBadFinalHashChip -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let mut trace = crate::generate_trace::>( - &self.air.sub_air, - self.bitwise_lookup_chip.clone(), - self.records.clone(), - ); - - // Set the final_hash in the digest row of the last block of each hash to zero. - // That is, every hash that this chip does will result in a final_hash of zero. - for (i, row) in self.records.iter().enumerate() { - if row.1 { - let last_digest_row_idx = (i + 1) * SHA256_ROWS_PER_BLOCK - 1; - let last_digest_row: &mut crate::Sha256DigestCols> = - trace.row_mut(last_digest_row_idx)[..SHA256_DIGEST_WIDTH].borrow_mut(); - // Set the final_hash to all zeros - for i in 0..SHA256_HASH_WORDS { - for j in 0..SHA256_WORD_U8S { - last_digest_row.final_hash[i][j] = Val::::ZERO; - } - } - - let (last_round_row, last_digest_row) = - trace.row_pair_mut(last_digest_row_idx - 1, last_digest_row_idx); - let last_round_row: &mut crate::Sha256RoundCols> = - last_round_row.borrow_mut(); - let last_digest_row: &mut crate::Sha256RoundCols> = - last_digest_row.borrow_mut(); - // fix the intermed_4 for the digest row - generate_intermed_4(last_round_row, last_digest_row); - } - } - - let non_padded_height = self.records.len() * SHA256_ROWS_PER_BLOCK; - let width = >>::width(&self.air.sub_air); - // recalculate the missing cells (second pass of generate_trace) - trace.values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - self.air.sub_air.generate_missing_cells(chunk, width, 0); - }); - - AirProofInput::simple_no_pis(trace) - } -} - -// Copy of private method in Sha256Air used for testing -/// Puts the correct intermed_4 in the `next_row` -fn generate_intermed_4( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, -) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let sig_w = small_sig0_field::(&w[i + 1]); - let sig_w_limbs: [F; SHA256_WORD_U16S] = - array::from_fn(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)); - for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { - next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb; - } - } -} - -impl ChipUsageGetter for Sha256TestBadFinalHashChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.records.len() * SHA256_ROWS_PER_BLOCK - } - - fn trace_width(&self) -> usize { - max(SHA256_ROUND_WIDTH, SHA256_DIGEST_WIDTH) - } -} - -#[test] -#[should_panic] -fn test_sha256_final_hash_constraints() { - let mut rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let len = rng.gen_range(1..100); - let random_records: Vec<_> = (0..len) - .map(|_| (array::from_fn(|_| rng.gen::()), true)) - .collect(); - let chip = Sha256TestBadFinalHashChip { - air: Sha256TestAir { - sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), - }, - bitwise_lookup_chip: bitwise_chip.clone(), - records: random_records, - }; - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} diff --git a/crates/circuits/sha256-air/src/trace.rs b/crates/circuits/sha256-air/src/trace.rs deleted file mode 100644 index eaf9174f50..0000000000 --- a/crates/circuits/sha256-air/src/trace.rs +++ /dev/null @@ -1,573 +0,0 @@ -use std::{array, borrow::BorrowMut, ops::Range}; - -use openvm_circuit_primitives::{ - bitwise_op_lookup::SharedBitwiseOperationLookupChip, utils::next_power_of_two_or_zero, -}; -use openvm_stark_backend::{ - p3_air::BaseAir, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, -}; -use sha2::{compress256, digest::generic_array::GenericArray}; - -use super::{ - air::Sha256Air, big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, - get_flag_pt_array, maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, - SHA256_DIGEST_WIDTH, SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, -}; -use crate::{ - big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1, - u32_into_limbs, SHA256_BLOCK_U8S, SHA256_BUFFER_SIZE, SHA256_H, SHA256_INVALID_CARRY_A, - SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK, - SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -/// The trace generation of SHA256 should be done in two passes. -/// The first pass should do `get_block_trace` for every block and generate the invalid rows through -/// `get_default_row` The second pass should go through all the blocks and call -/// `generate_missing_cells` -impl Sha256Air { - /// This function takes the input_message (padding not handled), the previous hash, - /// and returns the new hash after processing the block input - pub fn get_block_hash( - prev_hash: &[u32; SHA256_HASH_WORDS], - input: [u8; SHA256_BLOCK_U8S], - ) -> [u32; SHA256_HASH_WORDS] { - let mut new_hash = *prev_hash; - let input_array = [GenericArray::from(input)]; - compress256(&mut new_hash, &input_array); - new_hash - } - - /// This function takes a 512-bit chunk of the input message (padding not handled), the previous - /// hash, a flag indicating if it's the last block, the global block index, the local block - /// index, and the buffer values that will be put in rows 0..4. - /// Will populate the given `trace` with the trace of the block, where the width of the trace is - /// `trace_width` and the starting column for the `Sha256Air` is `trace_start_col`. - /// **Note**: this function only generates some of the required trace. Another pass is required, - /// refer to [`Self::generate_missing_cells`] for details. - #[allow(clippy::too_many_arguments)] - pub fn generate_block_trace( - &self, - trace: &mut [F], - trace_width: usize, - trace_start_col: usize, - input: &[u32; SHA256_BLOCK_WORDS], - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - prev_hash: &[u32; SHA256_HASH_WORDS], - is_last_block: bool, - global_block_idx: u32, - local_block_idx: u32, - buffer_vals: &[[F; SHA256_BUFFER_SIZE]; 4], - ) { - #[cfg(debug_assertions)] - { - assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK); - assert!(trace_start_col + super::SHA256_WIDTH <= trace_width); - assert!(self.bitwise_lookup_bus == bitwise_lookup_chip.bus()); - if local_block_idx == 0 { - assert!(*prev_hash == SHA256_H); - } - } - let get_range = |start: usize, len: usize| -> Range { start..start + len }; - let mut message_schedule = [0u32; 64]; - message_schedule[..input.len()].copy_from_slice(input); - let mut work_vars = *prev_hash; - for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() { - // doing the 64 rounds in 16 rows - if i < 16 { - let cols: &mut Sha256RoundCols = - row[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - cols.flags.is_round_row = F::ONE; - cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO }; - cols.flags.is_digest_row = F::ZERO; - cols.flags.is_last_block = F::from_bool(is_last_block); - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, i).map(F::from_canonical_u32); - cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); - cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); - - // W_idx = M_idx - if i < SHA256_ROWS_PER_BLOCK / SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_ROUNDS_PER_ROW { - cols.message_schedule.w[j] = u32_into_limbs::( - input[i * SHA256_ROUNDS_PER_ROW + j], - ) - .map(F::from_canonical_u32); - cols.message_schedule.carry_or_buffer[j] = - array::from_fn(|k| buffer_vals[i][j * SHA256_WORD_U16S * 2 + k]); - } - } - // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} - else { - for j in 0..SHA256_ROUNDS_PER_ROW { - let idx = i * SHA256_ROUNDS_PER_ROW + j; - let nums: [u32; 4] = [ - small_sig1(message_schedule[idx - 2]), - message_schedule[idx - 7], - small_sig0(message_schedule[idx - 15]), - message_schedule[idx - 16], - ]; - let w: u32 = nums.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - cols.message_schedule.w[j] = - u32_into_limbs::(w).map(F::from_canonical_u32); - - let nums_limbs = nums - .iter() - .map(|x| u32_into_limbs::(*x)) - .collect::>(); - let w_limbs = u32_into_limbs::(w); - - // fill in the carrys - for k in 0..SHA256_WORD_U16S { - let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]); - if k > 0 { - sum += (cols.message_schedule.carry_or_buffer[j][k * 2 - 2] - + F::TWO * cols.message_schedule.carry_or_buffer[j][k * 2 - 1]) - .as_canonical_u32(); - } - let carry = (sum - w_limbs[k]) >> 16; - cols.message_schedule.carry_or_buffer[j][k * 2] = - F::from_canonical_u32(carry & 1); - cols.message_schedule.carry_or_buffer[j][k * 2 + 1] = - F::from_canonical_u32(carry >> 1); - } - // update the message schedule - message_schedule[idx] = w; - } - } - // fill in the work variables - for j in 0..SHA256_ROUNDS_PER_ROW { - // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx - let t1 = [ - work_vars[7], - big_sig1(work_vars[4]), - ch(work_vars[4], work_vars[5], work_vars[6]), - SHA256_K[i * SHA256_ROUNDS_PER_ROW + j], - limbs_into_u32(cols.message_schedule.w[j].map(|f| f.as_canonical_u32())), - ]; - let t1_sum: u32 = t1.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - - // t2 = SIG0(a) + maj(a, b, c) - let t2 = [ - big_sig0(work_vars[0]), - maj(work_vars[0], work_vars[1], work_vars[2]), - ]; - - let t2_sum: u32 = t2.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - - // e = d + t1 - let e = work_vars[3].wrapping_add(t1_sum); - cols.work_vars.e[j] = - u32_into_limbs::(e).map(F::from_canonical_u32); - let e_limbs = u32_into_limbs::(e); - // a = t1 + t2 - let a = t1_sum.wrapping_add(t2_sum); - cols.work_vars.a[j] = - u32_into_limbs::(a).map(F::from_canonical_u32); - let a_limbs = u32_into_limbs::(a); - // fill in the carrys - for k in 0..SHA256_WORD_U16S { - let t1_limb = t1.iter().fold(0, |acc, &num| { - acc + u32_into_limbs::(num)[k] - }); - let t2_limb = t2.iter().fold(0, |acc, &num| { - acc + u32_into_limbs::(num)[k] - }); - - let mut e_limb = - t1_limb + u32_into_limbs::(work_vars[3])[k]; - let mut a_limb = t1_limb + t2_limb; - if k > 0 { - a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32(); - e_limb += cols.work_vars.carry_e[j][k - 1].as_canonical_u32(); - } - let carry_a = (a_limb - a_limbs[k]) >> 16; - let carry_e = (e_limb - e_limbs[k]) >> 16; - cols.work_vars.carry_a[j][k] = F::from_canonical_u32(carry_a); - cols.work_vars.carry_e[j][k] = F::from_canonical_u32(carry_e); - bitwise_lookup_chip.request_range(carry_a, carry_e); - } - - // update working variables - work_vars[7] = work_vars[6]; - work_vars[6] = work_vars[5]; - work_vars[5] = work_vars[4]; - work_vars[4] = e; - work_vars[3] = work_vars[2]; - work_vars[2] = work_vars[1]; - work_vars[1] = work_vars[0]; - work_vars[0] = a; - } - - // filling w_3 and intermed_4 here and the rest later - if i > 0 { - for j in 0..SHA256_ROUNDS_PER_ROW { - let idx = i * SHA256_ROUNDS_PER_ROW + j; - let w_4 = u32_into_limbs::(message_schedule[idx - 4]); - let sig_0_w_3 = u32_into_limbs::(small_sig0( - message_schedule[idx - 3], - )); - cols.schedule_helper.intermed_4[j] = - array::from_fn(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])); - if j < SHA256_ROUNDS_PER_ROW - 1 { - let w_3 = message_schedule[idx - 3]; - cols.schedule_helper.w_3[j] = - u32_into_limbs::(w_3).map(F::from_canonical_u32); - } - } - } - } - // generate the digest row - else { - let cols: &mut Sha256DigestCols = - row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut(); - for j in 0..SHA256_ROUNDS_PER_ROW - 1 { - let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3]; - cols.schedule_helper.w_3[j] = - u32_into_limbs::(w_3).map(F::from_canonical_u32); - } - cols.flags.is_round_row = F::ZERO; - cols.flags.is_first_4_rows = F::ZERO; - cols.flags.is_digest_row = F::ONE; - cols.flags.is_last_block = F::from_bool(is_last_block); - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, 16).map(F::from_canonical_u32); - cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); - - cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); - let final_hash: [u32; SHA256_HASH_WORDS] = - array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i])); - let final_hash_limbs: [[u32; SHA256_WORD_U8S]; SHA256_HASH_WORDS] = - array::from_fn(|i| u32_into_limbs::(final_hash[i])); - // need to ensure final hash limbs are bytes, in order for - // prev_hash[i] + work_vars[i] == final_hash[i] - // to be constrained correctly - for word in final_hash_limbs.iter() { - for chunk in word.chunks(2) { - bitwise_lookup_chip.request_range(chunk[0], chunk[1]); - } - } - cols.final_hash = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(final_hash_limbs[i][j])) - }); - cols.prev_hash = prev_hash - .map(|f| u32_into_limbs::(f).map(F::from_canonical_u32)); - let hash = if is_last_block { - SHA256_H.map(u32_into_limbs::) - } else { - cols.final_hash - .map(|f| limbs_into_u32(f.map(|x| x.as_canonical_u32()))) - .map(u32_into_limbs::) - } - .map(|x| x.map(F::from_canonical_u32)); - - for i in 0..SHA256_ROUNDS_PER_ROW { - cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; - cols.hash.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3]; - } - } - } - - for i in 0..SHA256_ROWS_PER_BLOCK - 1 { - let rows = &mut trace[i * trace_width..(i + 2) * trace_width]; - let (local, next) = rows.split_at_mut(trace_width); - let local_cols: &mut Sha256RoundCols = - local[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - let next_cols: &mut Sha256RoundCols = - next[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - if i > 0 { - for j in 0..SHA256_ROUNDS_PER_ROW { - next_cols.schedule_helper.intermed_8[j] = - local_cols.schedule_helper.intermed_4[j]; - if (2..SHA256_ROWS_PER_BLOCK - 3).contains(&i) { - next_cols.schedule_helper.intermed_12[j] = - local_cols.schedule_helper.intermed_8[j]; - } - } - } - if i == SHA256_ROWS_PER_BLOCK - 2 { - // `next` is a digest row. - // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and - // `e` hold. - Self::generate_carry_ae(local_cols, next_cols); - // Fill in row 16's `intermed_4` with dummy values so the message schedule - // constraints holds on that row - Self::generate_intermed_4(local_cols, next_cols); - } - if i <= 2 { - // i is in 0..3. - // Fill in `local.intermed_12` with dummy values so the message schedule constraints - // hold on rows 1..4. - Self::generate_intermed_12(local_cols, next_cols); - } - } - } - - /// This function will fill in the cells that we couldn't do during the first pass. - /// This function should be called only after `generate_block_trace` was called for all blocks - /// And [`Self::generate_default_row`] is called for all invalid rows - /// Will populate the missing values of `trace`, where the width of the trace is `trace_width` - /// and the starting column for the `Sha256Air` is `trace_start_col`. - /// Note: `trace` needs to be the rows 1..17 of a block and the first row of the next block - pub fn generate_missing_cells( - &self, - trace: &mut [F], - trace_width: usize, - trace_start_col: usize, - ) { - // Here row_17 = next blocks row 0 - let rows_15_17 = &mut trace[14 * trace_width..17 * trace_width]; - let (row_15, row_16_17) = rows_15_17.split_at_mut(trace_width); - let (row_16, row_17) = row_16_17.split_at_mut(trace_width); - let cols_15: &mut Sha256RoundCols = - row_15[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - let cols_16: &mut Sha256RoundCols = - row_16[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - let cols_17: &mut Sha256RoundCols = - row_17[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - // Fill in row 15's `intermed_12` with dummy values so the message schedule constraints - // holds on row 16 - Self::generate_intermed_12(cols_15, cols_16); - // Fill in row 16's `intermed_12` with dummy values so the message schedule constraints - // holds on the next block's row 0 - Self::generate_intermed_12(cols_16, cols_17); - // Fill in row 0's `intermed_4` with dummy values so the message schedule constraints holds - // on that row - Self::generate_intermed_4(cols_16, cols_17); - } - - /// Fills the `cols` as a padding row - /// Note: we still need to correctly fill in the hash values, carries and intermeds - pub fn generate_default_row(self: &Sha256Air, cols: &mut Sha256RoundCols) { - cols.flags.is_round_row = F::ZERO; - cols.flags.is_first_4_rows = F::ZERO; - cols.flags.is_digest_row = F::ZERO; - - cols.flags.is_last_block = F::ZERO; - cols.flags.global_block_idx = F::ZERO; - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, 17).map(F::from_canonical_u32); - cols.flags.local_block_idx = F::ZERO; - - cols.message_schedule.w = [[F::ZERO; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW]; - cols.message_schedule.carry_or_buffer = - [[F::ZERO; SHA256_WORD_U16S * 2]; SHA256_ROUNDS_PER_ROW]; - - let hash = SHA256_H - .map(u32_into_limbs::) - .map(|x| x.map(F::from_canonical_u32)); - - for i in 0..SHA256_ROUNDS_PER_ROW { - cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; - cols.work_vars.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3]; - } - - cols.work_vars.carry_a = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_A[i][j])) - }); - cols.work_vars.carry_e = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_E[i][j])) - }); - } - - /// The following functions do the calculations in native field since they will be called on - /// padding rows which can overflow and we need to make sure it matches the AIR constraints - /// Puts the correct carrys in the `next_row`, the resulting carrys can be out of bound - fn generate_carry_ae( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, - ) { - let a = [local_cols.work_vars.a, next_cols.work_vars.a].concat(); - let e = [local_cols.work_vars.e, next_cols.work_vars.e].concat(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let cur_a = a[i + 4]; - let sig_a = big_sig0_field::(&a[i + 3]); - let maj_abc = maj_field::(&a[i + 3], &a[i + 2], &a[i + 1]); - let d = a[i]; - let cur_e = e[i + 4]; - let sig_e = big_sig1_field::(&e[i + 3]); - let ch_efg = ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]); - let h = e[i]; - - let t1 = [h, sig_e, ch_efg]; - let t2 = [sig_a, maj_abc]; - for j in 0..SHA256_WORD_U16S { - let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| { - acc + compose::(&x[j * 16..(j + 1) * 16], 1) - }); - let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| { - acc + compose::(&x[j * 16..(j + 1) * 16], 1) - }); - let d_limb = compose::(&d[j * 16..(j + 1) * 16], 1); - let cur_a_limb = compose::(&cur_a[j * 16..(j + 1) * 16], 1); - let cur_e_limb = compose::(&cur_e[j * 16..(j + 1) * 16], 1); - let sum = d_limb - + t1_limb_sum - + if j == 0 { - F::ZERO - } else { - next_cols.work_vars.carry_e[i][j - 1] - } - - cur_e_limb; - let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse()); - - let sum = t1_limb_sum - + t2_limb_sum - + if j == 0 { - F::ZERO - } else { - next_cols.work_vars.carry_a[i][j - 1] - } - - cur_a_limb; - let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse()); - next_cols.work_vars.carry_e[i][j] = carry_e; - next_cols.work_vars.carry_a[i][j] = carry_a; - } - } - } - - /// Puts the correct intermed_4 in the `next_row` - fn generate_intermed_4( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, - ) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let sig_w = small_sig0_field::(&w[i + 1]); - let sig_w_limbs: [F; SHA256_WORD_U16S] = - array::from_fn(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)); - for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { - next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb; - } - } - } - - /// Puts the needed intermed_12 in the `local_row` - fn generate_intermed_12( - local_cols: &mut Sha256RoundCols, - next_cols: &Sha256RoundCols, - ) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - // sig_1(w_{t-2}) - let sig_w_2: [F; SHA256_WORD_U16S] = array::from_fn(|j| { - compose::(&small_sig1_field::(&w[i + 2])[j * 16..(j + 1) * 16], 1) - }); - // w_{t-7} - let w_7 = if i < 3 { - local_cols.schedule_helper.w_3[i] - } else { - w_limbs[i - 3] - }; - // w_t - let w_cur = w_limbs[i + 4]; - for j in 0..SHA256_WORD_U16S { - let carry = next_cols.message_schedule.carry_or_buffer[i][j * 2] - + F::TWO * next_cols.message_schedule.carry_or_buffer[i][j * 2 + 1]; - let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j] - + if j > 0 { - next_cols.message_schedule.carry_or_buffer[i][j * 2 - 2] - + F::from_canonical_u32(2) - * next_cols.message_schedule.carry_or_buffer[i][j * 2 - 1] - } else { - F::ZERO - }; - local_cols.schedule_helper.intermed_12[i][j] = -sum; - } - } - } -} - -/// `records` consists of pairs of `(input_block, is_last_block)`. -pub fn generate_trace( - sub_air: &Sha256Air, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, -) -> RowMajorMatrix { - let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK; - let height = next_power_of_two_or_zero(non_padded_height); - let width = >::width(sub_air); - let mut values = F::zero_vec(height * width); - - struct BlockContext { - prev_hash: [u32; 8], - local_block_idx: u32, - global_block_idx: u32, - input: [u8; SHA256_BLOCK_U8S], - is_last_block: bool, - } - let mut block_ctx: Vec = Vec::with_capacity(records.len()); - let mut prev_hash = SHA256_H; - let mut local_block_idx = 0; - let mut global_block_idx = 1; - for (input, is_last_block) in records { - block_ctx.push(BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - }); - global_block_idx += 1; - if is_last_block { - local_block_idx = 0; - prev_hash = SHA256_H; - } else { - local_block_idx += 1; - prev_hash = Sha256Air::get_block_hash(&prev_hash, input); - } - } - // first pass - values - .par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK) - .zip(block_ctx) - .for_each(|(block, ctx)| { - let BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - } = ctx; - let input_words = array::from_fn(|i| { - limbs_into_u32::(array::from_fn(|j| { - input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32 - })) - }); - sub_air.generate_block_trace( - block, - width, - 0, - &input_words, - bitwise_lookup_chip.clone(), - &prev_hash, - is_last_block, - global_block_idx, - local_block_idx, - &[[F::ZERO; 16]; 4], - ); - }); - // second pass: padding rows - values[width * non_padded_height..] - .par_chunks_mut(width) - .for_each(|row| { - let cols: &mut Sha256RoundCols = row.borrow_mut(); - sub_air.generate_default_row(cols); - }); - // second pass: non-padding rows - values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - sub_air.generate_missing_cells(chunk, width, 0); - }); - RowMajorMatrix::new(values, width) -} diff --git a/crates/circuits/sha256-air/src/utils.rs b/crates/circuits/sha256-air/src/utils.rs deleted file mode 100644 index abf8b6e7f2..0000000000 --- a/crates/circuits/sha256-air/src/utils.rs +++ /dev/null @@ -1,268 +0,0 @@ -use std::array; - -pub use openvm_circuit_primitives::utils::compose; -use openvm_circuit_primitives::{ - encoder::Encoder, - utils::{not, select}, -}; -use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra}; -use rand::{rngs::StdRng, Rng}; - -use super::{Sha256DigestCols, Sha256RoundCols}; - -// ==== Do not change these constants! ==== -/// Number of bits in a SHA256 word -pub const SHA256_WORD_BITS: usize = 32; -/// Number of 16-bit limbs in a SHA256 word -pub const SHA256_WORD_U16S: usize = SHA256_WORD_BITS / 16; -/// Number of 8-bit limbs in a SHA256 word -pub const SHA256_WORD_U8S: usize = SHA256_WORD_BITS / 8; -/// Number of words in a SHA256 block -pub const SHA256_BLOCK_WORDS: usize = 16; -/// Number of cells in a SHA256 block -pub const SHA256_BLOCK_U8S: usize = SHA256_BLOCK_WORDS * SHA256_WORD_U8S; -/// Number of bits in a SHA256 block -pub const SHA256_BLOCK_BITS: usize = SHA256_BLOCK_WORDS * SHA256_WORD_BITS; -/// Number of rows per block -pub const SHA256_ROWS_PER_BLOCK: usize = 17; -/// Number of rounds per row -pub const SHA256_ROUNDS_PER_ROW: usize = 4; -/// Number of words in a SHA256 hash -pub const SHA256_HASH_WORDS: usize = 8; -/// Number of vars needed to encode the row index with [Encoder] -pub const SHA256_ROW_VAR_CNT: usize = 5; -/// Width of the Sha256RoundCols -pub const SHA256_ROUND_WIDTH: usize = Sha256RoundCols::::width(); -/// Width of the Sha256DigestCols -pub const SHA256_DIGEST_WIDTH: usize = Sha256DigestCols::::width(); -/// Size of the buffer of the first 4 rows of a block (each row's size) -pub const SHA256_BUFFER_SIZE: usize = SHA256_ROUNDS_PER_ROW * SHA256_WORD_U16S * 2; -/// Width of the Sha256Cols -pub const SHA256_WIDTH: usize = if SHA256_ROUND_WIDTH > SHA256_DIGEST_WIDTH { - SHA256_ROUND_WIDTH -} else { - SHA256_DIGEST_WIDTH -}; -/// We can notice that `carry_a`'s and `carry_e`'s are always the same on invalid rows -/// To optimize the trace generation of invalid rows, we have those values precomputed here -pub(crate) const SHA256_INVALID_CARRY_A: [[u32; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW] = [ - [1230919683, 1162494304], - [266373122, 1282901987], - [1519718403, 1008990871], - [923381762, 330807052], -]; -pub(crate) const SHA256_INVALID_CARRY_E: [[u32; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW] = [ - [204933122, 1994683449], - [443873282, 1544639095], - [719953922, 1888246508], - [194580482, 1075725211], -]; -/// SHA256 constant K's -pub const SHA256_K: [u32; 64] = [ - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, - 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, - 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, - 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, - 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, -]; - -/// SHA256 initial hash values -pub const SHA256_H: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, -]; - -/// Convert a u32 into a list of limbs in little endian -pub fn u32_into_limbs(num: u32) -> [u32; NUM_LIMBS] { - let limb_bits = 32 / NUM_LIMBS; - array::from_fn(|i| (num >> (limb_bits * i)) & ((1 << limb_bits) - 1)) -} - -/// Convert a list of limbs in little endian into a u32 -pub fn limbs_into_u32(limbs: [u32; NUM_LIMBS]) -> u32 { - let limb_bits = 32 / NUM_LIMBS; - limbs - .iter() - .rev() - .fold(0, |acc, &limb| (acc << limb_bits) | limb) -} - -/// Rotates `bits` right by `n` bits, assumes `bits` is in little-endian -#[inline] -pub(crate) fn rotr( - bits: &[impl Into + Clone; SHA256_WORD_BITS], - n: usize, -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| bits[(i + n) % SHA256_WORD_BITS].clone().into()) -} - -/// Shifts `bits` right by `n` bits, assumes `bits` is in little-endian -#[inline] -pub(crate) fn shr( - bits: &[impl Into + Clone; SHA256_WORD_BITS], - n: usize, -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| { - if i + n < SHA256_WORD_BITS { - bits[i + n].clone().into() - } else { - F::ZERO - } - }) -} - -/// Computes x ^ y ^ z, where x, y, z are assumed to be boolean -#[inline] -pub(crate) fn xor_bit( - x: impl Into, - y: impl Into, - z: impl Into, -) -> F { - let (x, y, z) = (x.into(), y.into(), z.into()); - (x.clone() * y.clone() * z.clone()) - + (x.clone() * not::(y.clone()) * not::(z.clone())) - + (not::(x.clone()) * y.clone() * not::(z.clone())) - + (not::(x) * not::(y) * z) -} - -/// Computes x ^ y ^ z, where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn xor( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| xor_bit(x[i].clone(), y[i].clone(), z[i].clone())) -} - -/// Choose function from SHA256 -#[inline] -pub fn ch(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ ((!x) & z) -} - -/// Computes Ch(x,y,z), where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn ch_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| select(x[i].clone(), y[i].clone(), z[i].clone())) -} - -/// Majority function from SHA256 -pub fn maj(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (x & z) ^ (y & z) -} - -/// Computes Maj(x,y,z), where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn maj_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| { - let (x, y, z) = ( - x[i].clone().into(), - y[i].clone().into(), - z[i].clone().into(), - ); - x.clone() * y.clone() + x.clone() * z.clone() + y.clone() * z.clone() - F::TWO * x * y * z - }) -} - -/// Big sigma_0 function from SHA256 -pub fn big_sig0(x: u32) -> u32 { - x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) -} - -/// Computes BigSigma0(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn big_sig0_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 2), &rotr::(x, 13), &rotr::(x, 22)) -} - -/// Big sigma_1 function from SHA256 -pub fn big_sig1(x: u32) -> u32 { - x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) -} - -/// Computes BigSigma1(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn big_sig1_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 6), &rotr::(x, 11), &rotr::(x, 25)) -} - -/// Small sigma_0 function from SHA256 -pub fn small_sig0(x: u32) -> u32 { - x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) -} - -/// Computes SmallSigma0(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn small_sig0_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 7), &rotr::(x, 18), &shr::(x, 3)) -} - -/// Small sigma_1 function from SHA256 -pub fn small_sig1(x: u32) -> u32 { - x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) -} - -/// Computes SmallSigma1(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn small_sig1_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 17), &rotr::(x, 19), &shr::(x, 10)) -} - -/// Generate a random message of a given length -pub fn get_random_message(rng: &mut StdRng, len: usize) -> Vec { - let mut random_message: Vec = vec![0u8; len]; - rng.fill(&mut random_message[..]); - random_message -} - -/// Wrapper of `get_flag_pt` to get the flag pointer as an array -pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> [u32; N] { - encoder.get_flag_pt(flag_idx).try_into().unwrap() -} - -/// Constrain the addition of [SHA256_WORD_BITS] bit words in 16-bit limbs -/// It takes in the terms some in bits some in 16-bit limbs, -/// the expected sum in bits and the carries -pub fn constraint_word_addition( - builder: &mut AB, - terms_bits: &[&[impl Into + Clone; SHA256_WORD_BITS]], - terms_limb: &[&[impl Into + Clone; SHA256_WORD_U16S]], - expected_sum: &[impl Into + Clone; SHA256_WORD_BITS], - carries: &[impl Into + Clone; SHA256_WORD_U16S], -) { - for i in 0..SHA256_WORD_U16S { - let mut limb_sum = if i == 0 { - AB::Expr::ZERO - } else { - carries[i - 1].clone().into() - }; - for term in terms_bits { - limb_sum += compose::(&term[i * 16..(i + 1) * 16], 1); - } - for term in terms_limb { - limb_sum += term[i].clone().into(); - } - let expected_sum_limb = compose::(&expected_sum[i * 16..(i + 1) * 16], 1) - + carries[i].clone().into() * AB::Expr::from_canonical_u32(1 << 16); - builder.assert_eq(limb_sum, expected_sum_limb); - } -} diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index f105c588a3..6667cb52ae 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -55,3 +55,4 @@ mimalloc = ["openvm-sdk/mimalloc"] jemalloc = ["openvm-sdk/jemalloc"] jemalloc-prof = ["openvm-sdk/jemalloc-prof"] nightly-features = ["openvm-sdk/nightly-features"] +ci = [] diff --git a/crates/cli/tests/app_e2e.rs b/crates/cli/tests/app_e2e.rs index 779a28d72a..3fbf679774 100644 --- a/crates/cli/tests/app_e2e.rs +++ b/crates/cli/tests/app_e2e.rs @@ -3,16 +3,27 @@ use std::{ fs::{self, read_to_string}, path::Path, process::Command, + sync::OnceLock, }; use eyre::Result; use itertools::Itertools; use tempfile::tempdir; +fn install_cli() { + static FORCE_INSTALL: OnceLock = OnceLock::new(); + FORCE_INSTALL.get_or_init(|| { + if !matches!(env::var("SKIP_INSTALL"), Ok(x) if !x.is_empty()) { + run_cmd("cargo", &["install", "--path", ".", "--force"]).unwrap(); + } + true + }); +} + #[test] fn test_cli_app_e2e() -> Result<()> { let temp_dir = tempdir()?; - run_cmd("cargo", &["install", "--path", ".", "--force"])?; + install_cli(); let exe_path = "tests/programs/fibonacci/target/openvm/release/openvm-cli-example-test.vmexe"; let temp_pk = temp_dir.path().join("app.pk"); let temp_vk = temp_dir.path().join("app.vk"); @@ -87,7 +98,7 @@ fn test_cli_app_e2e() -> Result<()> { #[test] fn test_cli_app_e2e_simplified() -> Result<()> { - run_cmd("cargo", &["install", "--path", ".", "--force"])?; + install_cli(); run_cmd( "cargo", &[ @@ -128,7 +139,7 @@ fn test_cli_init_build() -> Result<()> { let temp_path = temp_dir.path(); let config_path = temp_path.join("openvm.toml"); let manifest_path = temp_path.join("Cargo.toml"); - run_cmd("cargo", &["install", "--path", ".", "--force"])?; + install_cli(); // Cargo will not respect patches if run within a workspace run_cmd( diff --git a/crates/continuations/src/verifier/leaf/mod.rs b/crates/continuations/src/verifier/leaf/mod.rs index 969733ba41..7ab08cdb0b 100644 --- a/crates/continuations/src/verifier/leaf/mod.rs +++ b/crates/continuations/src/verifier/leaf/mod.rs @@ -1,6 +1,6 @@ use openvm_circuit::{ - arch::{instructions::program::Program, SystemConfig}, - system::memory::tree::public_values::PUBLIC_VALUES_ADDRESS_SPACE_OFFSET, + arch::{instructions::program::Program, SystemConfig, ADDR_SPACE_OFFSET}, + system::memory::merkle::public_values::PUBLIC_VALUES_ADDRESS_SPACE_OFFSET, }; use openvm_native_compiler::{conversion::CompilerOptions, prelude::*}; use openvm_native_recursion::{ @@ -113,7 +113,7 @@ impl LeafVmVerifierConfig { builder: &mut Builder, ) -> ([Felt; DIGEST_SIZE], [Felt; DIGEST_SIZE]) { let memory_dimensions = self.app_system_config.memory_config.memory_dimensions(); - let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; + let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + ADDR_SPACE_OFFSET; let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0)); let pv_height = log2_strict_usize(self.app_system_config.num_public_values / DIGEST_SIZE); let proof_len = memory_dimensions.overall_height() - pv_height; diff --git a/crates/continuations/src/verifier/leaf/types.rs b/crates/continuations/src/verifier/leaf/types.rs index 16aca7a169..d47b36f248 100644 --- a/crates/continuations/src/verifier/leaf/types.rs +++ b/crates/continuations/src/verifier/leaf/types.rs @@ -1,6 +1,6 @@ use derivative::Derivative; use openvm_circuit::{ - arch::ContinuationVmProof, system::memory::tree::public_values::UserPublicValuesProof, + arch::ContinuationVmProof, system::memory::merkle::public_values::UserPublicValuesProof, }; use openvm_native_compiler::ir::DIGEST_SIZE; use openvm_stark_sdk::{ diff --git a/crates/prof/src/aggregate.rs b/crates/prof/src/aggregate.rs index 047d16b30a..a712917a6a 100644 --- a/crates/prof/src/aggregate.rs +++ b/crates/prof/src/aggregate.rs @@ -165,11 +165,14 @@ impl AggregateMetrics { let mut total_par_proof_time = MdTableCell::new(0.0, Some(0.0)); for (group_name, metrics) in &self.by_group { let stats = metrics.get(PROOF_TIME_LABEL); - let execute_stats = metrics.get(EXECUTE_TIME_LABEL); + let execute_metered_stats = metrics.get(EXECUTE_METERED_TIME_LABEL); + let execute_e1_stats = metrics.get(EXECUTE_E1_TIME_LABEL); if stats.is_none() { continue; } - let stats = stats.unwrap(); + let stats = stats.unwrap_or_else(|| { + panic!("Missing proof time statistics for group '{}'", group_name) + }); let mut sum = stats.sum; let mut max = stats.max; // convert ms to s @@ -184,26 +187,61 @@ impl AggregateMetrics { if !group_name.contains("keygen") { // Proving time in keygen group is dummy and not part of total. total_proof_time.val += sum.val; - *total_proof_time.diff.as_mut().unwrap() += sum.diff.unwrap_or(0.0); + *total_proof_time + .diff + .as_mut() + .expect("total_proof_time.diff should be initialized") += + sum.diff.unwrap_or(0.0); total_par_proof_time.val += max.val; - *total_par_proof_time.diff.as_mut().unwrap() += max.diff.unwrap_or(0.0); + *total_par_proof_time + .diff + .as_mut() + .expect("total_par_proof_time.diff should be initialized") += + max.diff.unwrap_or(0.0); - // Account for the fact that execution is serial - // Add total execution time for the app proofs, and subtract the max segment - // execution time + // Account for the serial execute_metered and execute_e1 for app outside of segments if group_name != "leaf" && group_name != "root" && group_name != "halo2_outer" && group_name != "halo2_wrapper" && !group_name.starts_with("internal") { - let execute_stats = execute_stats.unwrap(); - total_par_proof_time.val += - (execute_stats.sum.val - execute_stats.max.val) / 1000.0; - *total_par_proof_time.diff.as_mut().unwrap() += - (execute_stats.sum.diff.unwrap_or(0.0) - - execute_stats.max.diff.unwrap_or(0.0)) - / 1000.0; + if let Some(execute_metered_stats) = execute_metered_stats { + // For metered metrics without segment labels, we just use the value + // directly Count is 1, so avg = sum = max = min = + // value + total_proof_time.val += execute_metered_stats.avg.val / 1000.0; + total_par_proof_time.val += execute_metered_stats.avg.val / 1000.0; + if let Some(diff) = execute_metered_stats.avg.diff { + *total_proof_time + .diff + .as_mut() + .expect("total_proof_time.diff should be initialized") += + diff / 1000.0; + *total_par_proof_time + .diff + .as_mut() + .expect("total_par_proof_time.diff should be initialized") += + diff / 1000.0; + } + } + + if let Some(execute_e1_stats) = execute_e1_stats { + total_proof_time.val += execute_e1_stats.avg.val / 1000.0; + total_par_proof_time.val += execute_e1_stats.avg.val / 1000.0; + if let Some(diff) = execute_e1_stats.avg.diff { + *total_proof_time + .diff + .as_mut() + .expect("total_proof_time.diff should be initialized") += + diff / 1000.0; + *total_par_proof_time + .diff + .as_mut() + .expect("total_par_proof_time.diff should be initialized") += + diff / 1000.0; + } + } } } } @@ -239,7 +277,13 @@ impl AggregateMetrics { .into_iter() .map(|group_name| { let key = group_name.clone(); - let value = self.by_group.get(group_name).unwrap().clone(); + let value = self + .by_group + .get(group_name) + .unwrap_or_else(|| { + panic!("Group '{}' should exist in by_group map", group_name) + }) + .clone(); (key, value) }) .collect() @@ -252,6 +296,7 @@ impl AggregateMetrics { .map(|(group_name, metrics)| { let metrics = metrics .iter() + .filter(|(_, stats)| stats.avg.val.is_finite() && stats.sum.val.is_finite()) .flat_map(|(metric_name, stats)| { [ (format!("{metric_name}::sum"), stats.sum.into()), @@ -295,11 +340,37 @@ impl AggregateMetrics { for metric_name in names { let summary = summaries.get(metric_name); if let Some(summary) = summary { - writeln!( - writer, - "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |", - metric_name, summary.avg, summary.sum, summary.max, summary.min, - )?; + // Special handling for execute_metered metrics (not aggregated across segments + // in the app proof case) + if metric_name == EXECUTE_METERED_TIME_LABEL + && group_name != "leaf" + && group_name != "root" + && group_name != "halo2_outer" + && group_name != "halo2_wrapper" + && !group_name.starts_with("internal") + { + writeln!( + writer, + "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |", + metric_name, summary.avg, "-", "-", "-", + )?; + } else if metric_name == EXECUTE_E1_INSN_MI_S_LABEL + || metric_name == EXECUTE_E3_INSN_MI_S_LABEL + || metric_name == EXECUTE_METERED_INSN_MI_S_LABEL + { + // skip sum because it is misleading + writeln!( + writer, + "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |", + metric_name, summary.avg, "-", summary.max, summary.min, + )?; + } else { + writeln!( + writer, + "| `{:<20}` | {:<10} | {:<10} | {:<10} | {:<10} |", + metric_name, summary.avg, summary.sum, summary.max, summary.min, + )?; + } } } writeln!(writer)?; @@ -317,11 +388,16 @@ impl AggregateMetrics { writeln!(writer, "|:---|---:|---:|")?; let mut rows = Vec::new(); for (group_name, summaries) in self.to_vec() { + if group_name.contains("keygen") { + continue; + } let stats = summaries.get(PROOF_TIME_LABEL); if stats.is_none() { continue; } - let stats = stats.unwrap(); + let stats = stats.unwrap_or_else(|| { + panic!("Missing proof time statistics for group '{}'", group_name) + }); let mut sum = stats.sum; let mut max = stats.max; // convert ms to s @@ -352,7 +428,12 @@ impl AggregateMetrics { self.by_group .keys() .find(|k| group_weight(k) == 0) - .unwrap_or_else(|| self.by_group.keys().next().unwrap()) + .unwrap_or_else(|| { + self.by_group + .keys() + .next() + .expect("by_group should contain at least one group") + }) .clone() } } @@ -383,16 +464,32 @@ impl BenchmarkOutput { pub const PROOF_TIME_LABEL: &str = "total_proof_time_ms"; pub const CELLS_USED_LABEL: &str = "main_cells_used"; pub const CYCLES_LABEL: &str = "total_cycles"; -pub const EXECUTE_TIME_LABEL: &str = "execute_time_ms"; +pub const EXECUTE_E1_TIME_LABEL: &str = "execute_e1_time_ms"; +pub const EXECUTE_E1_INSN_MI_S_LABEL: &str = "execute_e1_insn_mi/s"; +pub const EXECUTE_METERED_TIME_LABEL: &str = "execute_metered_time_ms"; +pub const EXECUTE_METERED_INSN_MI_S_LABEL: &str = "execute_metered_insn_mi/s"; +pub const EXECUTE_E3_TIME_LABEL: &str = "execute_e3_time_ms"; +pub const EXECUTE_E3_INSN_MI_S_LABEL: &str = "execute_e3_insn_mi/s"; pub const TRACE_GEN_TIME_LABEL: &str = "trace_gen_time_ms"; +pub const MEM_FIN_TIME_LABEL: &str = "memory_finalize_time_ms"; +pub const BOUNDARY_FIN_TIME_LABEL: &str = "boundary_finalize_time_ms"; +pub const MERKLE_FIN_TIME_LABEL: &str = "merkle_finalize_time_ms"; pub const PROVE_EXCL_TRACE_TIME_LABEL: &str = "stark_prove_excluding_trace_time_ms"; pub const VM_METRIC_NAMES: &[&str] = &[ PROOF_TIME_LABEL, CELLS_USED_LABEL, CYCLES_LABEL, - EXECUTE_TIME_LABEL, + EXECUTE_E1_TIME_LABEL, + EXECUTE_E1_INSN_MI_S_LABEL, + EXECUTE_METERED_TIME_LABEL, + EXECUTE_METERED_INSN_MI_S_LABEL, + EXECUTE_E3_TIME_LABEL, + EXECUTE_E3_INSN_MI_S_LABEL, TRACE_GEN_TIME_LABEL, + MEM_FIN_TIME_LABEL, + BOUNDARY_FIN_TIME_LABEL, + MERKLE_FIN_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, "main_trace_commit_time_ms", "generate_perm_trace_time_ms", diff --git a/crates/prof/src/lib.rs b/crates/prof/src/lib.rs index 58440a8e02..4da92d36ee 100644 --- a/crates/prof/src/lib.rs +++ b/crates/prof/src/lib.rs @@ -1,12 +1,13 @@ use std::{collections::HashMap, fs::File, path::Path}; -use aggregate::{ - EXECUTE_TIME_LABEL, PROOF_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, TRACE_GEN_TIME_LABEL, -}; +use aggregate::{PROOF_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, TRACE_GEN_TIME_LABEL}; use eyre::Result; use memmap2::Mmap; -use crate::types::{Labels, Metric, MetricDb, MetricsFile}; +use crate::{ + aggregate::{EXECUTE_E3_TIME_LABEL, EXECUTE_METERED_TIME_LABEL}, + types::{Labels, Metric, MetricDb, MetricsFile}, +}; pub mod aggregate; pub mod summary; @@ -45,13 +46,17 @@ impl MetricDb { pub fn apply_aggregations(&mut self) { for metrics in self.flat_dict.values_mut() { let get = |key: &str| metrics.iter().find(|m| m.name == key).map(|m| m.value); - let execute_time = get(EXECUTE_TIME_LABEL); + let execute_metered_time = get(EXECUTE_METERED_TIME_LABEL); + let execute_e3_time = get(EXECUTE_E3_TIME_LABEL); let trace_gen_time = get(TRACE_GEN_TIME_LABEL); let prove_excl_trace_time = get(PROVE_EXCL_TRACE_TIME_LABEL); - if let (Some(execute_time), Some(trace_gen_time), Some(prove_excl_trace_time)) = - (execute_time, trace_gen_time, prove_excl_trace_time) + if let (Some(execute_e3_time), Some(trace_gen_time), Some(prove_excl_trace_time)) = + (execute_e3_time, trace_gen_time, prove_excl_trace_time) { - let total_time = execute_time + trace_gen_time + prove_excl_trace_time; + let total_time = execute_metered_time.unwrap_or(0.0) + + execute_e3_time + + trace_gen_time + + prove_excl_trace_time; metrics.push(Metric::new(PROOF_TIME_LABEL.to_string(), total_time)); } } @@ -90,7 +95,12 @@ impl MetricDb { let label_values: Vec = label_keys .iter() - .map(|key| label_dict.get(key).unwrap().clone()) + .map(|key| { + label_dict + .get(key) + .unwrap_or_else(|| panic!("Label key '{}' should exist in label_dict", key)) + .clone() + }) .collect(); // Add to dict_by_label_types diff --git a/crates/prof/src/main.rs b/crates/prof/src/main.rs index 31ddb2b359..1474153a9f 100644 --- a/crates/prof/src/main.rs +++ b/crates/prof/src/main.rs @@ -84,8 +84,9 @@ fn main() -> Result<()> { // If this is a new benchmark, prev_path will not exist if let Ok(prev_db) = MetricDb::new(&prev_path) { let prev_grouped = GroupedMetrics::new(&prev_db, "group")?; - prev_aggregated = Some(prev_grouped.aggregate()); - aggregated.set_diff(prev_aggregated.as_ref().unwrap()); + let prev_grouped_aggregated = prev_grouped.aggregate(); + aggregated.set_diff(&prev_grouped_aggregated); + prev_aggregated = Some(prev_grouped_aggregated); } } if name.is_empty() { diff --git a/crates/prof/src/summary.rs b/crates/prof/src/summary.rs index 9501b03e05..f8c014532a 100644 --- a/crates/prof/src/summary.rs +++ b/crates/prof/src/summary.rs @@ -4,7 +4,11 @@ use eyre::Result; use itertools::Itertools; use crate::{ - aggregate::{AggregateMetrics, CELLS_USED_LABEL, CYCLES_LABEL, PROOF_TIME_LABEL}, + aggregate::{ + AggregateMetrics, CELLS_USED_LABEL, CYCLES_LABEL, EXECUTE_E3_TIME_LABEL, + EXECUTE_METERED_TIME_LABEL, PROOF_TIME_LABEL, PROVE_EXCL_TRACE_TIME_LABEL, + TRACE_GEN_TIME_LABEL, + }, types::MdTableCell, }; @@ -52,8 +56,14 @@ impl GithubSummary { .zip_eq(md_paths.iter()) .zip_eq(names) .map(|(((aggregated, prev_aggregated), md_path), name)| { - let md_filename = md_path.file_name().unwrap().to_str().unwrap(); - let mut row = aggregated.get_summary_row(md_filename).unwrap(); + let md_filename = md_path + .file_name() + .expect("Path should have a filename") + .to_str() + .expect("Filename should be valid UTF-8"); + let mut row = aggregated.get_summary_row(md_filename).unwrap_or_else(|| { + panic!("Failed to get summary row for file '{}'", md_filename) + }); if let Some(prev_aggregated) = prev_aggregated { // md_filename doesn't matter if let Some(prev_row) = prev_aggregated.get_summary_row(md_filename) { @@ -152,8 +162,56 @@ impl AggregateMetrics { pub fn get_single_summary(&self, name: &str) -> Option { let stats = self.by_group.get(name)?; // Any group must have proof_time, but may not have cells_used or cycles (e.g., halo2) - let proof_time_ms = stats.get(PROOF_TIME_LABEL)?.sum; - let par_proof_time_ms = stats.get(PROOF_TIME_LABEL)?.max; + let proof_time_ms = if let Some(proof_stats) = stats.get(PROOF_TIME_LABEL) { + proof_stats.sum + } else { + // Note: execute_metered is outside any segment scope, so it should have sum = max = avg + let execute_metered = stats + .get(EXECUTE_METERED_TIME_LABEL) + .map(|s| s.sum.val) + .unwrap_or(0.0); + let execute_e3 = stats + .get(EXECUTE_E3_TIME_LABEL) + .map(|s| s.sum.val) + .unwrap_or(0.0); + // If total_proof_time_ms is not available, compute it from components + let trace_gen = stats + .get(TRACE_GEN_TIME_LABEL) + .map(|s| s.sum.val) + .unwrap_or(0.0); + let stark_prove = stats + .get(PROVE_EXCL_TRACE_TIME_LABEL) + .map(|s| s.sum.val) + .unwrap_or(0.0); + println!( + "{} {} {} {}", + execute_metered, execute_e3, trace_gen, stark_prove + ); + MdTableCell::new(execute_metered + execute_e3 + trace_gen + stark_prove, None) + }; + println!("{}", self.total_proof_time.val); + let par_proof_time_ms = if let Some(proof_stats) = stats.get(PROOF_TIME_LABEL) { + proof_stats.max + } else { + // Use the same computation for max + let execute_metered = stats + .get(EXECUTE_METERED_TIME_LABEL) + .map(|s| s.max.val) + .unwrap_or(0.0); + let execute_e3 = stats + .get(EXECUTE_E3_TIME_LABEL) + .map(|s| s.max.val) + .unwrap_or(0.0); + let trace_gen = stats + .get(TRACE_GEN_TIME_LABEL) + .map(|s| s.max.val) + .unwrap_or(0.0); + let stark_prove = stats + .get(PROVE_EXCL_TRACE_TIME_LABEL) + .map(|s| s.max.val) + .unwrap_or(0.0); + MdTableCell::new(execute_metered + execute_e3 + trace_gen + stark_prove, None) + }; let cells_used = stats .get(CELLS_USED_LABEL) .map(|s| s.sum) diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 6a868a3beb..cf5b46d930 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -18,8 +18,8 @@ openvm-ecc-circuit = { workspace = true } openvm-ecc-transpiler = { workspace = true } openvm-keccak256-circuit = { workspace = true } openvm-keccak256-transpiler = { workspace = true } -openvm-sha256-circuit = { workspace = true } -openvm-sha256-transpiler = { workspace = true } +openvm-sha2-circuit = { workspace = true } +openvm-sha2-transpiler = { workspace = true } openvm-pairing-circuit = { workspace = true } openvm-pairing-transpiler = { workspace = true } openvm-native-circuit = { workspace = true } @@ -51,6 +51,7 @@ clap = { workspace = true, features = ["derive"] } serde_with = { workspace = true, features = ["hex"] } serde_json.workspace = true thiserror.workspace = true +rand.workspace = true snark-verifier = { workspace = true, optional = true } snark-verifier-sdk = { workspace = true, optional = true } tempfile.workspace = true @@ -93,3 +94,6 @@ nightly-features = ["openvm-circuit/nightly-features"] name = "sdk_evm" path = "examples/sdk_evm.rs" required-features = ["evm-verify"] + +[package.metadata.cargo-shear] +ignored = ["rand"] diff --git a/crates/sdk/guest/fib/src/main.rs b/crates/sdk/guest/fib/src/main.rs index bc6d94cda8..6bbc6ca425 100644 --- a/crates/sdk/guest/fib/src/main.rs +++ b/crates/sdk/guest/fib/src/main.rs @@ -4,7 +4,7 @@ openvm::entry!(main); pub fn main() { - let n = core::hint::black_box(1 << 3); + let n = core::hint::black_box(1 << 8); let mut a: u32 = 0; let mut b: u32 = 1; for _ in 1..n { diff --git a/crates/sdk/src/codec.rs b/crates/sdk/src/codec.rs index 9d0ab48a93..c75268bbe3 100644 --- a/crates/sdk/src/codec.rs +++ b/crates/sdk/src/codec.rs @@ -1,7 +1,7 @@ use std::io::{self, Cursor, Read, Result, Write}; use openvm_circuit::{ - arch::ContinuationVmProof, system::memory::tree::public_values::UserPublicValuesProof, + arch::ContinuationVmProof, system::memory::merkle::public_values::UserPublicValuesProof, }; use openvm_continuations::verifier::{ internal::types::VmStarkProof, root::types::RootVmVerifierInput, diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index faf8182246..dc3fc05965 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -9,11 +9,11 @@ use openvm_bigint_circuit::{Int256, Int256Executor, Int256Periphery}; use openvm_bigint_transpiler::Int256TranspilerExtension; use openvm_circuit::{ arch::{ - InitFileGenerator, SystemConfig, SystemExecutor, SystemPeriphery, VmChipComplex, VmConfig, - VmInventoryError, + instructions::NATIVE_AS, InitFileGenerator, SystemConfig, SystemExecutor, SystemPeriphery, + VmChipComplex, VmConfig, VmInventoryError, }, circuit_derive::{Chip, ChipUsageGetter}, - derive::{AnyEnum, InstructionExecutor}, + derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}, }; use openvm_ecc_circuit::{ WeierstrassExtension, WeierstrassExtensionExecutor, WeierstrassExtensionPeriphery, @@ -37,8 +37,8 @@ use openvm_rv32im_circuit::{ use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; -use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; +use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_backend::p3_field::PrimeField32; use openvm_transpiler::transpiler::Transpiler; use serde::{Deserialize, Serialize}; @@ -46,14 +46,14 @@ use serde::{Deserialize, Serialize}; use crate::F; #[derive(Builder, Clone, Debug, Serialize, Deserialize)] +#[serde(from = "SdkVmConfigWithDefaultDeser")] pub struct SdkVmConfig { - #[serde(default)] pub system: SdkSystemConfig, pub rv32i: Option, pub io: Option, pub keccak: Option, - pub sha256: Option, + pub sha2: Option, pub native: Option, pub castf: Option, @@ -65,7 +65,9 @@ pub struct SdkVmConfig { pub ecc: Option, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, InsExecutorE1, InsExecutorE2, +)] pub enum SdkVmConfigExecutor { #[any_enum] System(SystemExecutor), @@ -76,7 +78,7 @@ pub enum SdkVmConfigExecutor { #[any_enum] Keccak(Keccak256Executor), #[any_enum] - Sha256(Sha256Executor), + Sha2(Sha2Executor), #[any_enum] Native(NativeExecutor), #[any_enum] @@ -106,7 +108,7 @@ pub enum SdkVmConfigPeriphery { #[any_enum] Keccak(Keccak256Periphery), #[any_enum] - Sha256(Sha256Periphery), + Sha2(Sha2Periphery), #[any_enum] Native(NativePeriphery), #[any_enum] @@ -137,8 +139,8 @@ impl SdkVmConfig { if self.keccak.is_some() { transpiler = transpiler.with_extension(Keccak256TranspilerExtension); } - if self.sha256.is_some() { - transpiler = transpiler.with_extension(Sha256TranspilerExtension); + if self.sha2.is_some() { + transpiler = transpiler.with_extension(Sha2TranspilerExtension); } if self.native.is_some() { transpiler = transpiler.with_extension(LongFormTranspilerExtension); @@ -191,8 +193,8 @@ impl VmConfig for SdkVmConfig { if self.keccak.is_some() { complex = complex.extend(&Keccak256)?; } - if self.sha256.is_some() { - complex = complex.extend(&Sha256)?; + if self.sha2.is_some() { + complex = complex.extend(&Sha2)?; } if self.native.is_some() { complex = complex.extend(&Native)?; @@ -318,8 +320,8 @@ impl From for UnitStruct { } } -impl From for UnitStruct { - fn from(_: Sha256) -> Self { +impl From for UnitStruct { + fn from(_: Sha2) -> Self { UnitStruct {} } } @@ -335,3 +337,49 @@ impl From for UnitStruct { UnitStruct {} } } + +#[derive(Deserialize)] +struct SdkVmConfigWithDefaultDeser { + #[serde(default)] + pub system: SdkSystemConfig, + + pub rv32i: Option, + pub io: Option, + pub keccak: Option, + pub sha2: Option, + pub native: Option, + pub castf: Option, + + pub rv32m: Option, + pub bigint: Option, + pub modular: Option, + pub fp2: Option, + pub pairing: Option, + pub ecc: Option, +} + +impl From for SdkVmConfig { + fn from(config: SdkVmConfigWithDefaultDeser) -> Self { + let mut system = config.system; + if config.native.is_none() && config.castf.is_none() { + // There should be no need to write to native address space if Native extension and + // CastF extension are not enabled. + system.config.memory_config.addr_space_sizes[NATIVE_AS as usize] = 0; + } + Self { + system, + rv32i: config.rv32i, + io: config.io, + keccak: config.keccak, + sha2: config.sha2, + native: config.native, + castf: config.castf, + rv32m: config.rv32m, + bigint: config.bigint, + modular: config.modular, + fp2: config.fp2, + pairing: config.pairing, + ecc: config.ecc, + } + } +} diff --git a/crates/sdk/src/config/mod.rs b/crates/sdk/src/config/mod.rs index 3a231f180d..9079f7efe7 100644 --- a/crates/sdk/src/config/mod.rs +++ b/crates/sdk/src/config/mod.rs @@ -33,7 +33,7 @@ pub struct AppConfig { pub compiler_options: CompilerOptions, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct AggConfig { /// STARK aggregation config pub agg_stark_config: AggStarkConfig, @@ -55,7 +55,7 @@ pub struct AggStarkConfig { pub root_max_constraint_degree: usize, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Halo2Config { /// Log degree for the outer recursion verifier circuit. pub verifier_k: usize, diff --git a/crates/sdk/src/keygen/dummy.rs b/crates/sdk/src/keygen/dummy.rs index 3fe2bcd300..edd687a512 100644 --- a/crates/sdk/src/keygen/dummy.rs +++ b/crates/sdk/src/keygen/dummy.rs @@ -6,8 +6,8 @@ use openvm_circuit::{ exe::VmExe, instruction::Instruction, program::Program, LocalOpcode, SystemOpcode::TERMINATE, }, - ContinuationVmProof, SingleSegmentVmExecutor, VirtualMachine, VmComplexTraceHeights, - VmConfig, VmExecutor, + ContinuationVmProof, InsExecutorE1, SingleSegmentVmExecutor, VirtualMachine, + VmComplexTraceHeights, VmConfig, VmExecutor, }, system::program::trace::VmCommittedExe, utils::next_power_of_two_or_zero, @@ -21,6 +21,7 @@ use openvm_native_circuit::NativeConfig; use openvm_native_compiler::ir::DIGEST_SIZE; use openvm_native_recursion::hints::Hintable; use openvm_rv32im_circuit::Rv32ImConfig; +use openvm_stark_backend::config::Val; use openvm_stark_sdk::{ config::{ baby_bear_poseidon2::BabyBearPoseidon2Engine, @@ -48,6 +49,7 @@ pub(super) fn compute_root_proof_heights( root_vm_config: NativeConfig, root_exe: VmExe, dummy_internal_proof: &Proof, + interactions: &[usize], ) -> (Vec, VmComplexTraceHeights) { let num_user_public_values = root_vm_config.system.num_public_values - 2 * DIGEST_SIZE; let root_input = RootVmVerifierInput { @@ -55,8 +57,11 @@ pub(super) fn compute_root_proof_heights( public_values: vec![F::ZERO; num_user_public_values], }; let vm = SingleSegmentVmExecutor::new(root_vm_config); + let max_trace_heights = vm + .execute_metered(root_exe.clone(), root_input.write(), interactions) + .unwrap(); let res = vm - .execute_and_compute_heights(root_exe, root_input.write()) + .execute_and_compute_heights(root_exe, root_input.write(), &max_trace_heights) .unwrap(); let air_heights: Vec<_> = res .air_heights @@ -104,7 +109,7 @@ pub fn dummy_leaf_proof>( overridden_heights: Option, ) -> Proof where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let app_proof = dummy_app_proof_impl(app_vm_pk.clone(), overridden_heights); @@ -168,7 +173,7 @@ fn dummy_app_proof_impl>( overridden_heights: Option, ) -> ContinuationVmProof where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let fri_params = app_vm_pk.fri_params; @@ -179,12 +184,16 @@ where } else { // We first execute once to get the trace heights from dummy_exe, then pad to powers of 2 // (forcing trace height 0 to 1) + let vm_vk = app_vm_pk.vm_pk.get_vk(); let executor = VmExecutor::new(app_vm_pk.vm_config.clone()); + let segments = executor + .execute_metered(dummy_exe.exe.clone(), vec![], &vm_vk.num_interactions()) + .unwrap(); + assert_eq!(segments.len(), 1, "dummy exe should have only 1 segment"); let mut results = executor - .execute_segments(dummy_exe.exe.clone(), vec![]) + .execute_segments(dummy_exe.exe.clone(), vec![], &segments) .unwrap(); // ASSUMPTION: the dummy exe has only 1 segment - assert_eq!(results.len(), 1, "dummy exe should have only 1 segment"); let mut result = results.pop().unwrap(); result.chip_complex.finalize_memory(); let mut vm_heights = result.chip_complex.get_internal_trace_heights(); @@ -192,12 +201,8 @@ where vm_heights }; // For the dummy proof, we must override the trace heights. - let app_prover = - VmLocalProver::::new_with_overridden_trace_heights( - app_vm_pk, - dummy_exe, - Some(overridden_heights), - ); + let app_prover = VmLocalProver::::new(app_vm_pk, dummy_exe) + .with_overridden_continuation_trace_heights(overridden_heights); ContinuationVmProver::prove(&app_prover, vec![]) } diff --git a/crates/sdk/src/keygen/mod.rs b/crates/sdk/src/keygen/mod.rs index 0806cc6f3d..930b7553ac 100644 --- a/crates/sdk/src/keygen/mod.rs +++ b/crates/sdk/src/keygen/mod.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use derivative::Derivative; use dummy::{compute_root_proof_heights, dummy_internal_proof_riscv_app_vm}; use openvm_circuit::{ - arch::{VirtualMachine, VmComplexTraceHeights, VmConfig}, + arch::{InsExecutorE1, VirtualMachine, VmComplexTraceHeights, VmConfig}, system::{memory::dimensions::MemoryDimensions, program::trace::VmCommittedExe}, }; use openvm_continuations::verifier::{ @@ -99,7 +99,7 @@ pub struct Halo2ProvingKey { impl> AppProvingKey where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { pub fn keygen(config: AppConfig) -> Self { @@ -342,10 +342,12 @@ impl AggStarkProvingKey { let mut vm_pk = vm.keygen(); assert!(vm_pk.max_constraint_degree <= config.root_fri_params.max_constraint_degree()); + let vm_vk = vm_pk.get_vk(); let (air_heights, vm_heights) = compute_root_proof_heights( root_vm_config.clone(), root_committed_exe.exe.clone(), &internal_proof, + &vm_vk.num_interactions(), ); let root_air_perm = AirIdPermutation::compute(&air_heights); root_air_perm.permute(&mut vm_pk.per_air); diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index c2c874d3f1..2a9857fa84 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -13,12 +13,12 @@ use openvm_circuit::{ arch::{ hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, instructions::exe::VmExe, - verify_segments, ContinuationVmProof, ExecutionError, InitFileGenerator, + verify_segments, ContinuationVmProof, ExecutionError, InitFileGenerator, InsExecutorE1, VerifiedExecutionPayload, VmConfig, VmExecutor, CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, PUBLIC_VALUES_AIR_ID, }, system::{ - memory::{tree::public_values::extract_public_values, CHUNK}, + memory::{merkle::public_values::extract_public_values, CHUNK}, program::trace::{compute_exe_commit, VmCommittedExe}, }, }; @@ -35,7 +35,7 @@ use openvm_continuations::verifier::{ pub use openvm_continuations::{RootSC, C, F, SC}; #[cfg(feature = "evm-prove")] use openvm_native_recursion::halo2::utils::Halo2ParamsReader; -use openvm_stark_backend::proof::Proof; +use openvm_stark_backend::{config::Val, proof::Proof}; use openvm_stark_sdk::{ config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, engine::StarkFriEngine, @@ -175,16 +175,13 @@ impl> GenericSdk { inputs: StdIn, ) -> Result, ExecutionError> where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let vm = VmExecutor::new(vm_config); - let final_memory = vm.execute(exe, inputs)?; - let public_values = extract_public_values( - &vm.config.system().memory_config.memory_dimensions(), - vm.config.system().num_public_values, - final_memory.as_ref().unwrap(), - ); + let final_memory = vm.execute_e1(exe, inputs, None)?.memory; + let public_values = + extract_public_values(vm.config.system().num_public_values, &final_memory); Ok(public_values) } @@ -199,7 +196,7 @@ impl> GenericSdk { pub fn app_keygen>(&self, config: AppConfig) -> Result> where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let app_pk = AppProvingKey::keygen(config); @@ -213,7 +210,7 @@ impl> GenericSdk { inputs: StdIn, ) -> Result> where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let app_prover = AppProver::::new(app_pk.app_vm_pk.clone(), app_committed_exe); @@ -303,7 +300,7 @@ impl> GenericSdk { inputs: StdIn, ) -> Result> where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let stark_prover = @@ -320,7 +317,7 @@ impl> GenericSdk { inputs: StdIn, ) -> Result> where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1, VC::Periphery: Chip, { let stark_prover = @@ -439,7 +436,7 @@ impl> GenericSdk { inputs: StdIn, ) -> Result where - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let e2e_prover = diff --git a/crates/sdk/src/prover/agg.rs b/crates/sdk/src/prover/agg.rs index d3c5fd29c1..822564f19a 100644 --- a/crates/sdk/src/prover/agg.rs +++ b/crates/sdk/src/prover/agg.rs @@ -6,7 +6,7 @@ use openvm_continuations::verifier::{ leaf::types::LeafVmVerifierInput, root::types::RootVmVerifierInput, }; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{NativeConfig, NATIVE_MAX_TRACE_HEIGHTS}; use openvm_native_compiler::ir::DIGEST_SIZE; use openvm_native_recursion::hints::Hintable; use openvm_stark_sdk::{engine::StarkFriEngine, openvm_stark_backend::proof::Proof}; @@ -45,14 +45,16 @@ impl> AggStarkProver { tree_config: AggregationTreeConfig, ) -> Self { let leaf_prover = - VmLocalProver::::new(agg_stark_pk.leaf_vm_pk, leaf_committed_exe); + VmLocalProver::::new(agg_stark_pk.leaf_vm_pk, leaf_committed_exe) + .with_overridden_single_segment_trace_heights(NATIVE_MAX_TRACE_HEIGHTS.to_vec()); let leaf_controller = LeafProvingController { num_children: tree_config.num_children_leaf, }; let internal_prover = VmLocalProver::::new( agg_stark_pk.internal_vm_pk, agg_stark_pk.internal_committed_exe, - ); + ) + .with_overridden_single_segment_trace_heights(NATIVE_MAX_TRACE_HEIGHTS.to_vec()); let root_prover = RootVerifierLocalProver::new(agg_stark_pk.root_verifier_pk); Self { leaf_prover, @@ -222,10 +224,11 @@ pub fn wrap_e2e_stark_proof>( } = e2e_stark_proof; let mut wrapper_layers = 0; loop { - let actual_air_heights = root_prover.execute_for_air_heights(RootVmVerifierInput { + let input = RootVmVerifierInput { proofs: vec![proof.clone()], public_values: user_public_values.clone(), - }); + }; + let actual_air_heights = root_prover.execute_for_air_heights(input); // Root verifier can handle the internal proof. We can stop here. if heights_le( &actual_air_heights, diff --git a/crates/sdk/src/prover/app.rs b/crates/sdk/src/prover/app.rs index 095351677e..fa91b47ab5 100644 --- a/crates/sdk/src/prover/app.rs +++ b/crates/sdk/src/prover/app.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use getset::Getters; -use openvm_circuit::arch::{ContinuationVmProof, VmConfig}; -use openvm_stark_backend::{proof::Proof, Chip}; +use openvm_circuit::arch::{ContinuationVmProof, InsExecutorE1, VmConfig}; +use openvm_stark_backend::{config::Val, proof::Proof, Chip}; use openvm_stark_sdk::engine::StarkFriEngine; use tracing::info_span; @@ -45,7 +45,7 @@ impl> AppProver { pub fn generate_app_proof(&self, input: StdIn) -> ContinuationVmProof where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { assert!( @@ -70,7 +70,7 @@ impl> AppProver { pub fn generate_app_proof_without_continuations(&self, input: StdIn) -> Proof where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { assert!( diff --git a/crates/sdk/src/prover/mod.rs b/crates/sdk/src/prover/mod.rs index 67ccfe1eb8..94b6fa3f5c 100644 --- a/crates/sdk/src/prover/mod.rs +++ b/crates/sdk/src/prover/mod.rs @@ -19,8 +19,9 @@ pub use stark::*; mod evm { use std::sync::Arc; - use openvm_circuit::arch::VmConfig; + use openvm_circuit::arch::{InsExecutorE1, VmConfig}; use openvm_native_recursion::halo2::utils::Halo2ParamsReader; + use openvm_stark_backend::config::Val; use openvm_stark_sdk::{engine::StarkFriEngine, openvm_stark_backend::Chip}; use super::{Halo2Prover, StarkProver}; @@ -68,7 +69,7 @@ mod evm { pub fn generate_proof_for_evm(&self, input: StdIn) -> EvmProof where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let root_proof = self.stark_prover.generate_proof_for_outer_recursion(input); diff --git a/crates/sdk/src/prover/root.rs b/crates/sdk/src/prover/root.rs index 6e69aa0f13..0015c6fe99 100644 --- a/crates/sdk/src/prover/root.rs +++ b/crates/sdk/src/prover/root.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use itertools::Itertools; use openvm_circuit::arch::{SingleSegmentVmExecutor, Streams}; use openvm_continuations::verifier::root::types::RootVmVerifierInput; use openvm_native_circuit::NativeConfig; @@ -31,11 +32,21 @@ impl RootVerifierLocalProver { } } pub fn execute_for_air_heights(&self, input: RootVmVerifierInput) -> Vec { + let vm_vk = self.root_verifier_pk.vm_pk.vm_pk.get_vk(); + let max_trace_heights = self + .executor_for_heights + .execute_metered( + self.root_verifier_pk.root_committed_exe.exe.clone(), + input.write(), + &vm_vk.num_interactions(), + ) + .unwrap(); let result = self .executor_for_heights .execute_and_compute_heights( self.root_verifier_pk.root_committed_exe.exe.clone(), input.write(), + &max_trace_heights, ) .unwrap(); result.air_heights @@ -54,8 +65,18 @@ impl SingleSegmentVmProver for RootVerifierLocalProver { let input = input.into(); let mut vm = SingleSegmentVmExecutor::new(self.vm_config().clone()); vm.set_override_trace_heights(self.root_verifier_pk.vm_heights.clone()); + let trace_heights = self + .root_verifier_pk + .air_heights + .iter() + .map(|&height| height as u32) + .collect_vec(); let mut proof_input = vm - .execute_and_generate(self.root_verifier_pk.root_committed_exe.clone(), input) + .execute_and_generate( + self.root_verifier_pk.root_committed_exe.clone(), + input, + &trace_heights, + ) .unwrap(); assert_eq!( proof_input.per_air.len(), diff --git a/crates/sdk/src/prover/stark.rs b/crates/sdk/src/prover/stark.rs index fdec583f0f..adabe8391d 100644 --- a/crates/sdk/src/prover/stark.rs +++ b/crates/sdk/src/prover/stark.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use openvm_circuit::arch::VmConfig; +use openvm_circuit::arch::{InsExecutorE1, VmConfig}; use openvm_continuations::verifier::{ internal::types::VmStarkProof, root::types::RootVmVerifierInput, }; -use openvm_stark_backend::{proof::Proof, Chip}; +use openvm_stark_backend::{config::Val, proof::Proof, Chip}; use openvm_stark_sdk::engine::StarkFriEngine; use crate::{ @@ -54,7 +54,7 @@ impl> StarkProver { pub fn generate_proof_for_outer_recursion(&self, input: StdIn) -> Proof where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let app_proof = self.app_prover.generate_app_proof(input); @@ -64,7 +64,7 @@ impl> StarkProver { pub fn generate_root_verifier_input(&self, input: StdIn) -> RootVmVerifierInput where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let app_proof = self.app_prover.generate_app_proof(input); @@ -74,7 +74,7 @@ impl> StarkProver { pub fn generate_e2e_stark_proof(&self, input: StdIn) -> VmStarkProof where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1, VC::Periphery: Chip, { let app_proof = self.app_prover.generate_app_proof(input); diff --git a/crates/sdk/src/prover/vm/local.rs b/crates/sdk/src/prover/vm/local.rs index b56c6a1ad3..4f82607f63 100644 --- a/crates/sdk/src/prover/vm/local.rs +++ b/crates/sdk/src/prover/vm/local.rs @@ -1,12 +1,14 @@ -use std::{marker::PhantomData, mem, sync::Arc}; +use std::{marker::PhantomData, sync::Arc}; use async_trait::async_trait; use openvm_circuit::{ arch::{ - hasher::poseidon2::vm_poseidon2_hasher, GenerationError, SingleSegmentVmExecutor, Streams, - VirtualMachine, VmComplexTraceHeights, VmConfig, + hasher::poseidon2::vm_poseidon2_hasher, GenerationError, InsExecutorE1, + SingleSegmentVmExecutor, Streams, VirtualMachine, VmComplexTraceHeights, VmConfig, + }, + system::{ + memory::merkle::public_values::UserPublicValuesProof, program::trace::VmCommittedExe, }, - system::{memory::tree::public_values::UserPublicValuesProof, program::trace::VmCommittedExe}, }; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, @@ -25,7 +27,8 @@ use crate::prover::vm::{ pub struct VmLocalProver> { pub pk: Arc>, pub committed_exe: Arc>, - overridden_heights: Option, + continuation_heights: Option, + single_segment_heights: Option>, _marker: PhantomData, } @@ -34,26 +37,26 @@ impl> VmLocalProver Self { pk, committed_exe, - overridden_heights: None, + continuation_heights: None, + single_segment_heights: None, _marker: PhantomData, } } - pub fn new_with_overridden_trace_heights( - pk: Arc>, - committed_exe: Arc>, - overridden_heights: Option, + pub fn with_overridden_continuation_trace_heights( + mut self, + overridden_heights: VmComplexTraceHeights, ) -> Self { - Self { - pk, - committed_exe, - overridden_heights, - _marker: PhantomData, - } + self.continuation_heights = Some(overridden_heights); + self } - pub fn set_override_trace_heights(&mut self, overridden_heights: VmComplexTraceHeights) { - self.overridden_heights = Some(overridden_heights); + pub fn with_overridden_single_segment_trace_heights( + mut self, + overridden_heights: Vec, + ) -> Self { + self.single_segment_heights = Some(overridden_heights); + self } pub fn vm_config(&self) -> &VC { @@ -65,13 +68,11 @@ impl> VmLocalProver } } -const MAX_SEGMENTATION_RETRIES: usize = 4; - impl>, E: StarkFriEngine> ContinuationVmProver for VmLocalProver where Val: PrimeField32, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { fn prove(&self, input: impl Into>>) -> ContinuationVmProof { @@ -81,52 +82,50 @@ where let mut vm = VirtualMachine::new_with_overridden_trace_heights( e, self.pk.vm_config.clone(), - self.overridden_heights.clone(), + self.continuation_heights.clone(), ); vm.set_trace_height_constraints(trace_height_constraints.clone()); - let mut final_memory = None; let VmCommittedExe { exe, committed_program, } = self.committed_exe.as_ref(); let input = input.into(); - // This loop should typically iterate exactly once. Only in exceptional cases will the - // segmentation produce an invalid segment and we will have to retry. - let mut retries = 0; - let per_segment = loop { - match vm.executor.execute_and_then( + let vm_vk = self.pk.vm_pk.get_vk(); + let segments = vm + .executor + .execute_metered(exe.clone(), input.clone(), &vm_vk.num_interactions()) + .expect("execute_metered failed"); + + let mut final_memory = None; + let per_segment = vm + .executor + .execute_and_then( exe.clone(), - input.clone(), - |seg_idx, mut seg| { - final_memory = mem::take(&mut seg.final_memory); + input, + &segments, + |seg_idx, seg| { + final_memory = Some( + seg.chip_complex + .memory_controller() + .memory + .data + .memory + .clone(), + ); let proof_input = info_span!("trace_gen", segment = seg_idx) .in_scope(|| seg.generate_proof_input(Some(committed_program.clone())))?; - info_span!("prove_segment", segment = seg_idx) - .in_scope(|| Ok(vm.engine.prove(&self.pk.vm_pk, proof_input))) + info_span!("prove_segment", segment = seg_idx).in_scope(|| { + let proof = vm.engine.prove(&self.pk.vm_pk, proof_input); + vm.engine + .verify(&self.pk.vm_pk.get_vk(), &proof) + .expect("verification failed"); + Ok(proof) + }) }, GenerationError::Execution, - ) { - Ok(per_segment) => break per_segment, - Err(GenerationError::Execution(err)) => panic!("execution error: {err}"), - Err(GenerationError::TraceHeightsLimitExceeded) => { - if retries >= MAX_SEGMENTATION_RETRIES { - panic!( - "trace heights limit exceeded after {MAX_SEGMENTATION_RETRIES} retries" - ); - } - retries += 1; - tracing::info!( - "trace heights limit exceeded; retrying execution (attempt {retries})" - ); - let sys_config = vm.executor.config.system_mut(); - let new_seg_strat = sys_config.segmentation_strategy.stricter_strategy(); - sys_config.set_segmentation_strategy(new_seg_strat); - // continue - } - }; - }; - + ) + .expect("execute_and_then failed"); let user_public_values = UserPublicValuesProof::compute( self.pk.vm_config.system().memory_config.memory_dimensions(), self.pk.vm_config.system().num_public_values, @@ -146,7 +145,7 @@ impl>, E: StarkFriEngine> where VmLocalProver: Send + Sync, Val: PrimeField32, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { async fn prove( @@ -161,7 +160,7 @@ impl>, E: StarkFriEngine> Singl for VmLocalProver where Val: PrimeField32, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { fn prove(&self, input: impl Into>>) -> Proof { @@ -173,9 +172,24 @@ where executor.set_trace_height_constraints(self.pk.vm_pk.trace_height_constraints.clone()); executor }; + + let vm_vk = self.pk.vm_pk.get_vk(); + let input = input.into(); + let max_trace_heights = if let Some(overridden_heights) = &self.single_segment_heights { + overridden_heights + } else { + &executor + .execute_metered( + self.committed_exe.exe.clone(), + input.clone(), + &vm_vk.num_interactions(), + ) + .expect("execute_metered failed") + }; let proof_input = executor - .execute_and_generate(self.committed_exe.clone(), input) + .execute_and_generate(self.committed_exe.clone(), input, max_trace_heights) .unwrap(); + let vm = VirtualMachine::new(e, executor.config); vm.prove_single(&self.pk.vm_pk, proof_input) } @@ -187,7 +201,7 @@ impl>, E: StarkFriEngine> where VmLocalProver: Send + Sync, Val: PrimeField32, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { async fn prove(&self, input: impl Into>> + Send + Sync) -> Proof { diff --git a/crates/sdk/tests/integration_test.rs b/crates/sdk/tests/integration_test.rs index 9248fc5445..124a81cf1a 100644 --- a/crates/sdk/tests/integration_test.rs +++ b/crates/sdk/tests/integration_test.rs @@ -5,17 +5,19 @@ use openvm_build::GuestOptions; use openvm_circuit::{ arch::{ hasher::poseidon2::vm_poseidon2_hasher, ContinuationVmProof, ExecutionError, - GenerationError, SingleSegmentVmExecutor, SystemConfig, VmConfig, VmExecutor, + SingleSegmentVmExecutor, VirtualMachine, }, - system::{memory::tree::public_values::UserPublicValuesProof, program::trace::VmCommittedExe}, + system::{ + memory::merkle::public_values::UserPublicValuesProof, program::trace::VmCommittedExe, + }, + utils::test_system_config_with_continuations, }; use openvm_continuations::verifier::{ common::types::VmVerifierPvs, leaf::types::{LeafVmVerifierInput, UserPublicValuesRootProof}, }; -use openvm_native_circuit::{Native, NativeConfig}; +use openvm_native_circuit::NativeConfig; use openvm_native_compiler::{conversion::CompilerOptions, prelude::*}; -use openvm_native_recursion::types::InnerConfig; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; @@ -25,14 +27,13 @@ use openvm_sdk::{ keygen::AppProvingKey, Sdk, StdIn, }; -use openvm_stark_backend::{keygen::types::LinearConstraint, p3_matrix::Matrix}; use openvm_stark_sdk::{ config::{ - baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, - setup_tracing, FriParameters, + baby_bear_poseidon2::{default_engine, BabyBearPoseidon2Config, BabyBearPoseidon2Engine}, + FriParameters, }, engine::{StarkEngine, StarkFriEngine}, - openvm_stark_backend::{p3_field::FieldAlgebra, Chip}, + openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear, }; use openvm_transpiler::transpiler::Transpiler; @@ -65,7 +66,6 @@ use { }; type SC = BabyBearPoseidon2Config; -type C = InnerConfig; type F = BabyBear; const NUM_PUB_VALUES: usize = 16; @@ -91,18 +91,27 @@ fn verify_evm_halo2_proof_with_fallback( Ok(gas_cost) } -fn run_leaf_verifier>( - leaf_vm: &SingleSegmentVmExecutor, +fn run_leaf_verifier( + leaf_vm_config: &NativeConfig, leaf_committed_exe: Arc>, verifier_input: LeafVmVerifierInput, -) -> Result, ExecutionError> -where - VC::Executor: Chip, - VC::Periphery: Chip, -{ - let exe_result = leaf_vm.execute_and_compute_heights( +) -> Result, ExecutionError> { + let leaf_vm = VirtualMachine::new(default_engine(), leaf_vm_config.clone()); + let leaf_vm_pk = leaf_vm.keygen(); + let leaf_vm_vk = leaf_vm_pk.get_vk(); + + let executor = SingleSegmentVmExecutor::new(leaf_vm.config().clone()); + + let max_trace_heights = executor.execute_metered( + leaf_committed_exe.exe.clone(), + verifier_input.write_to_stream(), + &leaf_vm_vk.num_interactions(), + )?; + + let exe_result = executor.execute_and_compute_heights( leaf_committed_exe.exe.clone(), verifier_input.write_to_stream(), + &max_trace_heights, )?; let runtime_pvs: Vec<_> = exe_result .public_values @@ -113,25 +122,21 @@ where } fn app_committed_exe_for_test(app_log_blowup: usize) -> Arc> { - let program = { - let n = 200; - let mut builder = Builder::::default(); - let a: Felt = builder.eval(F::ZERO); - let b: Felt = builder.eval(F::ONE); - let c: Felt = builder.uninit(); - builder.range(0, n).for_each(|_, builder| { - builder.assign(&c, a + b); - builder.assign(&a, b); - builder.assign(&b, c); - }); - builder.halt(); - builder.compile_isa() - }; - Sdk::new() - .commit_app_exe( - FriParameters::new_for_testing(app_log_blowup), - program.into(), + let sdk = Sdk::new(); + let mut pkg_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).to_path_buf(); + pkg_dir.push("guest/fib"); + let vm_config = app_vm_config_for_test(); + let elf = sdk + .build( + Default::default(), + &vm_config, + pkg_dir, + &Default::default(), + None, ) + .unwrap(); + let exe = sdk.transpile(elf, vm_config.transpiler()).unwrap(); + sdk.commit_app_exe(FriParameters::new_for_testing(app_log_blowup), exe) .unwrap() } @@ -162,16 +167,22 @@ fn agg_stark_config_for_test() -> AggStarkConfig { } } -fn small_test_app_config(app_log_blowup: usize) -> AppConfig { +fn app_vm_config_for_test() -> SdkVmConfig { + let config = test_system_config_with_continuations() + .with_max_segment_len(200) + .with_public_values(NUM_PUB_VALUES); + SdkVmConfig::builder() + .system(SdkSystemConfig { config }) + .rv32i(Default::default()) + .rv32m(Default::default()) + .io(Default::default()) + .build() +} + +fn small_test_app_config(app_log_blowup: usize) -> AppConfig { AppConfig { app_fri_params: FriParameters::new_for_testing(app_log_blowup).into(), - app_vm_config: NativeConfig::new( - SystemConfig::default() - .with_max_segment_len(200) - .with_continuations() - .with_public_values(NUM_PUB_VALUES), - Native, - ), + app_vm_config: app_vm_config_for_test(), leaf_fri_params: FriParameters::new_for_testing(LEAF_LOG_BLOWUP).into(), compiler_options: CompilerOptions { enable_cycle_tracker: true, @@ -182,23 +193,37 @@ fn small_test_app_config(app_log_blowup: usize) -> AppConfig { #[test] fn test_public_values_and_leaf_verification() { - let app_log_blowup = 3; + let app_log_blowup = 1; let app_config = small_test_app_config(app_log_blowup); let app_pk = AppProvingKey::keygen(app_config); let app_committed_exe = app_committed_exe_for_test(app_log_blowup); + let pc_start = app_committed_exe.exe.pc_start; let agg_stark_config = agg_stark_config_for_test(); let leaf_vm_config = agg_stark_config.leaf_vm_config(); - let leaf_vm = SingleSegmentVmExecutor::new(leaf_vm_config); let leaf_committed_exe = app_pk.leaf_committed_exe.clone(); let app_engine = BabyBearPoseidon2Engine::new(app_pk.app_vm_pk.fri_params); - let app_vm = VmExecutor::new(app_pk.app_vm_pk.vm_config.clone()); + let app_vm = VirtualMachine::new(app_engine, app_pk.app_vm_pk.vm_config.clone()); + + let app_vm_pk = app_vm.keygen(); + let app_vm_vk = app_vm_pk.get_vk(); + let segments = app_vm + .executor + .execute_metered( + app_committed_exe.exe.clone(), + vec![], + &app_vm_vk.num_interactions(), + ) + .unwrap(); + let app_vm_result = app_vm - .execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]) + .executor + .execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![], &segments) .unwrap(); assert!(app_vm_result.per_segment.len() > 2); + let app_engine = BabyBearPoseidon2Engine::new(app_pk.app_vm_pk.fri_params); let mut app_vm_seg_proofs: Vec<_> = app_vm_result .per_segment .into_iter() @@ -211,7 +236,7 @@ fn test_public_values_and_leaf_verification() { // Verify all segments except the last one. let (first_seg_final_pc, first_seg_final_mem_root) = { let runtime_pvs = run_leaf_verifier( - &leaf_vm, + &leaf_vm_config, leaf_committed_exe.clone(), LeafVmVerifierInput { proofs: app_vm_seg_proofs.clone(), @@ -224,7 +249,10 @@ fn test_public_values_and_leaf_verification() { assert_eq!(leaf_vm_pvs.app_commit, expected_app_commit); assert_eq!(leaf_vm_pvs.connector.is_terminate, F::ZERO); - assert_eq!(leaf_vm_pvs.connector.initial_pc, F::ZERO); + assert_eq!( + leaf_vm_pvs.connector.initial_pc, + F::from_canonical_u32(pc_start) + ); ( leaf_vm_pvs.connector.final_pc, leaf_vm_pvs.memory.final_root, @@ -232,7 +260,12 @@ fn test_public_values_and_leaf_verification() { }; let pv_proof = UserPublicValuesProof::compute( - app_vm.config.system.memory_config.memory_dimensions(), + app_vm + .config() + .system + .config + .memory_config + .memory_dimensions(), NUM_PUB_VALUES, &vm_poseidon2_hasher(), app_vm_result.final_memory.as_ref().unwrap(), @@ -242,7 +275,7 @@ fn test_public_values_and_leaf_verification() { // Verify the last segment with the correct public values root proof. { let runtime_pvs = run_leaf_verifier( - &leaf_vm, + &leaf_vm_config, leaf_committed_exe.clone(), LeafVmVerifierInput { proofs: vec![app_last_proof.clone()], @@ -268,7 +301,7 @@ fn test_public_values_and_leaf_verification() { let mut wrong_pv_root_proof = pv_root_proof.clone(); wrong_pv_root_proof.public_values_commit[0] += F::ONE; let execution_result = run_leaf_verifier( - &leaf_vm, + &leaf_vm_config, leaf_committed_exe.clone(), LeafVmVerifierInput { proofs: vec![app_last_proof.clone()], @@ -287,7 +320,7 @@ fn test_public_values_and_leaf_verification() { let mut wrong_pv_root_proof = pv_root_proof.clone(); wrong_pv_root_proof.sibling_hashes[0][0] += F::ONE; let execution_result = run_leaf_verifier( - &leaf_vm, + &leaf_vm_config, leaf_committed_exe.clone(), LeafVmVerifierInput { proofs: vec![app_last_proof.clone()], @@ -304,6 +337,7 @@ fn test_public_values_and_leaf_verification() { #[cfg(feature = "evm-verify")] #[test] +#[ignore = "slow"] fn test_static_verifier_custom_pv_handler() { // Define custom public values handler and implement StaticVerifierPvHandler trait on it pub struct CustomPvHandler { @@ -403,33 +437,8 @@ fn test_static_verifier_custom_pv_handler() { #[cfg(feature = "evm-verify")] #[test] fn test_e2e_proof_generation_and_verification_with_pvs() { - let mut pkg_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).to_path_buf(); - pkg_dir.push("guest/fib"); - - let vm_config = SdkVmConfig::builder() - .system(SdkSystemConfig { - config: SystemConfig::default() - .with_max_segment_len(200) - .with_continuations() - .with_public_values(NUM_PUB_VALUES), - }) - .rv32i(Default::default()) - .rv32m(Default::default()) - .io(Default::default()) - .native(Default::default()) - .build(); - + let vm_config = app_vm_config_for_test(); let sdk = Sdk::new(); - let elf = sdk - .build( - Default::default(), - &vm_config, - pkg_dir, - &Default::default(), - None, - ) - .unwrap(); - let exe = sdk.transpile(elf, vm_config.transpiler()).unwrap(); let app_log_blowup = 1; let app_fri_params = FriParameters::new_for_testing(app_log_blowup); @@ -438,10 +447,7 @@ fn test_e2e_proof_generation_and_verification_with_pvs() { AppConfig::new_with_leaf_fri_params(app_fri_params, vm_config, leaf_fri_params); app_config.compiler_options.enable_cycle_tracker = true; - let app_committed_exe = sdk - .commit_app_exe(app_fri_params, exe) - .expect("failed to commit exe"); - + let app_committed_exe = app_committed_exe_for_test(app_log_blowup); let app_pk = sdk.app_keygen(app_config).unwrap(); let params_reader = CacheHalo2ParamsReader::new_with_default_params_dir(); @@ -475,25 +481,11 @@ fn test_e2e_proof_generation_and_verification_with_pvs() { #[test] fn test_sdk_guest_build_and_transpile() { let sdk = Sdk::new(); - let guest_opts = GuestOptions::default() - // .with_features(vec!["zkvm"]) - // .with_options(vec!["--release"]); - ; + let guest_opts = GuestOptions::default(); let mut pkg_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).to_path_buf(); pkg_dir.push("guest/fib"); - let vm_config = SdkVmConfig::builder() - .system(SdkSystemConfig { - config: SystemConfig::default() - .with_max_segment_len(200) - .with_continuations() - .with_public_values(NUM_PUB_VALUES), - }) - .rv32i(Default::default()) - .rv32m(Default::default()) - .io(Default::default()) - .native(Default::default()) - .build(); + let vm_config = app_vm_config_for_test(); let one = sdk .build( @@ -526,33 +518,11 @@ fn test_sdk_guest_build_and_transpile() { fn test_inner_proof_codec_roundtrip() -> eyre::Result<()> { // generate a proof let sdk = Sdk::new(); - let mut pkg_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).to_path_buf(); - pkg_dir.push("guest/fib"); - - let vm_config = SdkVmConfig::builder() - .system(SdkSystemConfig { - config: SystemConfig::default() - .with_max_segment_len(200) - .with_continuations() - .with_public_values(NUM_PUB_VALUES), - }) - .rv32i(Default::default()) - .rv32m(Default::default()) - .io(Default::default()) - .native(Default::default()) - .build(); - let elf = sdk.build( - Default::default(), - &vm_config, - pkg_dir, - &Default::default(), - None, - )?; + let vm_config = app_vm_config_for_test(); assert!(vm_config.system.config.continuation_enabled); - let exe = sdk.transpile(elf, vm_config.transpiler())?; let fri_params = FriParameters::standard_fast(); let app_config = AppConfig::new(fri_params, vm_config); - let committed_exe = sdk.commit_app_exe(fri_params, exe)?; + let committed_exe = app_committed_exe_for_test(fri_params.log_blowup); let app_pk = Arc::new(sdk.app_keygen(app_config)?); let app_proof = sdk.generate_app_proof(app_pk.clone(), committed_exe, StdIn::default())?; let mut app_proof_bytes = Vec::new(); @@ -567,59 +537,3 @@ fn test_inner_proof_codec_roundtrip() -> eyre::Result<()> { sdk.verify_app_proof(&app_pk.get_app_vk(), &decoded_app_proof)?; Ok(()) } - -#[test] -fn test_segmentation_retry() { - setup_tracing(); - let app_log_blowup = 3; - let app_config = small_test_app_config(app_log_blowup); - let app_pk = AppProvingKey::keygen(app_config); - let app_committed_exe = app_committed_exe_for_test(app_log_blowup); - - let app_vm = VmExecutor::new(app_pk.app_vm_pk.vm_config.clone()); - let app_vm_result = app_vm - .execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]) - .unwrap(); - assert!(app_vm_result.per_segment.len() > 2); - - let total_height: usize = app_vm_result.per_segment[0] - .per_air - .iter() - .map(|(_, input)| { - let main = input.raw.common_main.as_ref(); - main.map(|mat| mat.height()).unwrap_or(0) - }) - .sum(); - - // Re-run with a threshold that will be violated. - let mut app_vm = VmExecutor::new(app_pk.app_vm_pk.vm_config.clone()); - let num_airs = app_pk.app_vm_pk.vm_pk.per_air.len(); - app_vm.set_trace_height_constraints(vec![LinearConstraint { - coefficients: vec![1; num_airs], - threshold: total_height as u32 - 1, - }]); - let app_vm_result = - app_vm.execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]); - assert!(matches!( - app_vm_result, - Err(GenerationError::TraceHeightsLimitExceeded) - )); - - // Try lowering segmentation threshold. - let config = VmConfig::::system_mut(&mut app_vm.config); - config.set_segmentation_strategy(config.segmentation_strategy.stricter_strategy()); - let app_vm_result = app_vm - .execute_and_generate_with_cached_program(app_committed_exe.clone(), vec![]) - .unwrap(); - - // New max height should indeed by smaller. - let new_total_height: usize = app_vm_result.per_segment[0] - .per_air - .iter() - .map(|(_, input)| { - let main = input.raw.common_main.as_ref(); - main.map(|mat| mat.height()).unwrap_or(0) - }) - .sum(); - assert!(new_total_height < total_height); -} diff --git a/crates/toolchain/instructions/src/exe.rs b/crates/toolchain/instructions/src/exe.rs index fb84ec7da5..9db5f242ac 100644 --- a/crates/toolchain/instructions/src/exe.rs +++ b/crates/toolchain/instructions/src/exe.rs @@ -5,8 +5,9 @@ use serde::{Deserialize, Serialize}; use crate::program::Program; -/// Memory image is a map from (address space, address) to word. -pub type MemoryImage = BTreeMap<(u32, u32), F>; +// TODO[jpw]: delete this +/// Memory image is a map from (address space, address * size_of) to u8. +pub type SparseMemoryImage = BTreeMap<(u32, u32), u8>; /// Stores the starting address, end address, and name of a set of function. pub type FnBounds = BTreeMap; @@ -22,7 +23,7 @@ pub struct VmExe { /// Start address of pc. pub pc_start: u32, /// Initial memory image. - pub init_memory: MemoryImage, + pub init_memory: SparseMemoryImage, /// Starting + ending bounds for each function. pub fn_bounds: FnBounds, } @@ -40,7 +41,7 @@ impl VmExe { self.pc_start = pc_start; self } - pub fn with_init_memory(mut self, init_memory: MemoryImage) -> Self { + pub fn with_init_memory(mut self, init_memory: SparseMemoryImage) -> Self { self.init_memory = init_memory; self } diff --git a/crates/toolchain/instructions/src/lib.rs b/crates/toolchain/instructions/src/lib.rs index c251e77d0d..9515180d15 100644 --- a/crates/toolchain/instructions/src/lib.rs +++ b/crates/toolchain/instructions/src/lib.rs @@ -18,6 +18,8 @@ pub mod utils; pub use phantom::*; +pub const NATIVE_AS: u32 = 4; + pub trait LocalOpcode { const CLASS_OFFSET: usize; /// Convert from the discriminant of the enum to the typed enum variant. @@ -25,8 +27,11 @@ pub trait LocalOpcode { fn from_usize(value: usize) -> Self; fn local_usize(&self) -> usize; + fn global_opcode_usize(&self) -> usize { + self.local_usize() + Self::CLASS_OFFSET + } fn global_opcode(&self) -> VmOpcode { - VmOpcode::from_usize(self.local_usize() + Self::CLASS_OFFSET) + VmOpcode::from_usize(self.global_opcode_usize()) } } diff --git a/crates/toolchain/instructions/src/riscv.rs b/crates/toolchain/instructions/src/riscv.rs index b2998c4539..720b323d52 100644 --- a/crates/toolchain/instructions/src/riscv.rs +++ b/crates/toolchain/instructions/src/riscv.rs @@ -5,3 +5,5 @@ pub const RV32_CELL_BITS: usize = 8; pub const RV32_IMM_AS: u32 = 0; pub const RV32_REGISTER_AS: u32 = 1; pub const RV32_MEMORY_AS: u32 = 2; + +pub const RV32_NUM_REGISTERS: usize = 32; diff --git a/crates/toolchain/platform/src/alloc.rs b/crates/toolchain/platform/src/alloc.rs new file mode 100644 index 0000000000..0af25a3671 --- /dev/null +++ b/crates/toolchain/platform/src/alloc.rs @@ -0,0 +1,62 @@ +extern crate alloc; + +use alloc::alloc::{alloc, dealloc, handle_alloc_error, Layout}; +use core::ptr::NonNull; + +/// Bytes allocated according to the given Layout +pub struct AlignedBuf { + pub ptr: *mut u8, + pub layout: Layout, +} + +impl AlignedBuf { + /// Allocate a new buffer whose start address is aligned to `align` bytes. + /// *NOTE* if `len` is zero then a creates new `NonNull` that is dangling and 16-byte aligned. + pub fn uninit(len: usize, align: usize) -> Self { + let layout = Layout::from_size_align(len, align).unwrap(); + if layout.size() == 0 { + return Self { + ptr: NonNull::::dangling().as_ptr() as *mut u8, + layout, + }; + } + // SAFETY: `len` is nonzero + let ptr = unsafe { alloc(layout) }; + if ptr.is_null() { + handle_alloc_error(layout); + } + AlignedBuf { ptr, layout } + } + + /// Allocate a new buffer whose start address is aligned to `align` bytes + /// and copy the given data into it. + /// + /// # Safety + /// - `bytes` must not be null + /// - `len` should not be zero + /// + /// See [alloc]. In particular `data` should not be empty. + pub unsafe fn new(bytes: *const u8, len: usize, align: usize) -> Self { + let buf = Self::uninit(len, align); + // SAFETY: + // - src and dst are not null + // - src and dst are allocated for size + // - no alignment requirements on u8 + // - non-overlapping since ptr is newly allocated + unsafe { + core::ptr::copy_nonoverlapping(bytes, buf.ptr, len); + } + + buf + } +} + +impl Drop for AlignedBuf { + fn drop(&mut self) { + if self.layout.size() != 0 { + unsafe { + dealloc(self.ptr, self.layout); + } + } + } +} diff --git a/crates/toolchain/platform/src/lib.rs b/crates/toolchain/platform/src/lib.rs index 1ace328a66..2a0beedef1 100644 --- a/crates/toolchain/platform/src/lib.rs +++ b/crates/toolchain/platform/src/lib.rs @@ -4,12 +4,15 @@ #![deny(rustdoc::broken_intra_doc_links)] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] -#[cfg(all(feature = "rust-runtime", target_os = "zkvm"))] +#[cfg(target_os = "zkvm")] pub use openvm_custom_insn::{custom_insn_i, custom_insn_r}; +#[cfg(target_os = "zkvm")] +pub mod alloc; #[cfg(all(feature = "rust-runtime", target_os = "zkvm"))] pub mod heap; #[cfg(all(feature = "export-libm", target_os = "zkvm"))] mod libm_extern; + pub mod memory; pub mod print; #[cfg(feature = "rust-runtime")] @@ -19,9 +22,6 @@ pub mod rust_rt; /// 4 bytes (i.e. 32 bits) as the zkVM is an implementation of the rv32im ISA. pub const WORD_SIZE: usize = core::mem::size_of::(); -/// Size of a zkVM memory page. -pub const PAGE_SIZE: usize = 1024; - /// Standard IO file descriptors for use with sys_read and sys_write. pub mod fileno { pub const STDIN: u32 = 0; diff --git a/crates/toolchain/tests/Cargo.toml b/crates/toolchain/tests/Cargo.toml index 9f3e3caa82..c2349b893f 100644 --- a/crates/toolchain/tests/Cargo.toml +++ b/crates/toolchain/tests/Cargo.toml @@ -8,11 +8,16 @@ homepage.workspace = true repository.workspace = true [dependencies] +openvm-build.workspace = true +openvm-circuit.workspace = true +openvm-transpiler.workspace = true +eyre.workspace = true +tempfile.workspace = true + +[dev-dependencies] openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-transpiler.workspace = true -openvm-build.workspace = true openvm-algebra-transpiler.workspace = true openvm-bigint-circuit.workspace = true openvm-rv32im-circuit.workspace = true @@ -21,10 +26,8 @@ openvm-algebra-circuit.workspace = true openvm-ecc-circuit = { workspace = true } openvm-instructions = { workspace = true } openvm-platform = { workspace = true } - -eyre.workspace = true test-case.workspace = true -tempfile.workspace = true +rand = { workspace = true } serde = { workspace = true, features = ["alloc"] } derive_more = { workspace = true, features = ["from"] } @@ -36,4 +39,4 @@ default = ["parallel"] parallel = ["openvm-circuit/parallel"] [package.metadata.cargo-shear] -ignored = ["derive_more", "openvm-stark-backend"] +ignored = ["derive_more", "openvm-stark-backend", "rand"] diff --git a/crates/toolchain/tests/src/utils.rs b/crates/toolchain/tests/src/utils.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/toolchain/tests/tests/riscv_test_vectors.rs b/crates/toolchain/tests/tests/riscv_test_vectors.rs index 9516b0cd7b..919a92f97b 100644 --- a/crates/toolchain/tests/tests/riscv_test_vectors.rs +++ b/crates/toolchain/tests/tests/riscv_test_vectors.rs @@ -2,7 +2,7 @@ use std::{fs::read_dir, path::PathBuf}; use eyre::Result; use openvm_circuit::{ - arch::{instructions::exe::VmExe, VmExecutor}, + arch::{execution_mode::e1::E1Ctx, instructions::exe::VmExe, interpreter::InterpretedInstance}, utils::air_test, }; use openvm_rv32im_circuit::Rv32ImConfig; @@ -39,9 +39,10 @@ fn test_rv32im_riscv_vector_runtime() -> Result<()> { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - let executor = VmExecutor::::new(config.clone()); - let res = executor.execute(exe, vec![])?; - Ok(res) + let interpreter = InterpretedInstance::new(config.clone(), exe); + let state = interpreter.execute(E1Ctx::new(None), vec![])?; + state.exit_code?; + Ok(()) }); match result { diff --git a/crates/toolchain/tests/tests/transpiler_tests.rs b/crates/toolchain/tests/tests/transpiler_tests.rs index bf07eccc42..42b62be437 100644 --- a/crates/toolchain/tests/tests/transpiler_tests.rs +++ b/crates/toolchain/tests/tests/transpiler_tests.rs @@ -12,7 +12,10 @@ use openvm_algebra_circuit::{ use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; use openvm_bigint_circuit::{Int256, Int256Executor, Int256Periphery}; use openvm_circuit::{ - arch::{InitFileGenerator, SystemConfig, VmExecutor}, + arch::{ + execution_mode::e1::E1Ctx, interpreter::InterpretedInstance, InitFileGenerator, + SystemConfig, + }, derive::VmConfig, utils::air_test, }; @@ -80,8 +83,8 @@ fn test_rv32im_runtime(elf_path: &str) -> Result<()> { .with_extension(Rv32IoTranspilerExtension), )?; let config = Rv32ImConfig::default(); - let executor = VmExecutor::::new(config); - executor.execute(exe, vec![])?; + let interpreter = InterpretedInstance::new(config, exe); + interpreter.execute(E1Ctx::new(None), vec![])?; Ok(()) } @@ -143,8 +146,8 @@ fn test_intrinsic_runtime(elf_path: &str) -> Result<()> { .with_extension(ModularTranspilerExtension) .with_extension(Fp2TranspilerExtension), )?; - let executor = VmExecutor::::new(config); - executor.execute(openvm_exe, vec![])?; + let interpreter = InterpretedInstance::new(config, openvm_exe); + interpreter.execute(E1Ctx::new(None), vec![])?; Ok(()) } diff --git a/crates/toolchain/transpiler/src/util.rs b/crates/toolchain/transpiler/src/util.rs index d9135de153..c5711653ff 100644 --- a/crates/toolchain/transpiler/src/util.rs +++ b/crates/toolchain/transpiler/src/util.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use openvm_instructions::{ - exe::MemoryImage, + exe::SparseMemoryImage, instruction::Instruction, riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS}, utils::isize_to_field, @@ -165,17 +165,14 @@ pub fn nop() -> Instruction { } } -/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as, address) -> word) -pub fn elf_memory_image_to_openvm_memory_image( +/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as=2, address) -> byte) +pub fn elf_memory_image_to_openvm_memory_image( memory_image: BTreeMap, -) -> MemoryImage { - let mut result = MemoryImage::new(); +) -> SparseMemoryImage { + let mut result = SparseMemoryImage::new(); for (addr, word) in memory_image { for (i, byte) in word.to_le_bytes().into_iter().enumerate() { - result.insert( - (RV32_MEMORY_AS, addr + i as u32), - F::from_canonical_u8(byte), - ); + result.insert((RV32_MEMORY_AS, addr + i as u32), byte); } } result diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 80e6794b48..22a6bda84e 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -35,24 +35,38 @@ eyre.workspace = true derivative.workspace = true static_assertions.workspace = true getset.workspace = true +dashmap.workspace = true + +[target.'cfg(any(unix, windows))'.dependencies] +memmap2.workspace = true [dev-dependencies] test-log.workspace = true openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-stark-sdk.workspace = true -openvm-native-circuit.workspace = true +openvm-native-circuit = { workspace = true, features = ["test-utils"] } openvm-native-compiler.workspace = true openvm-rv32im-transpiler.workspace = true [features] default = ["parallel", "jemalloc"] -parallel = ["openvm-stark-backend/parallel"] -test-utils = ["dep:openvm-stark-sdk"] -bench-metrics = ["dep:metrics", "openvm-stark-backend/bench-metrics"] +parallel = [ + "openvm-stark-backend/parallel", + "dashmap/rayon", + "openvm-stark-sdk?/parallel", +] +test-utils = ["openvm-stark-sdk"] +bench-metrics = [ + "dep:metrics", + "openvm-stark-backend/bench-metrics", + "openvm-stark-sdk?/bench-metrics", +] function-span = ["bench-metrics"] +# use basic memory instead of mmap: +basic-memory = [] # performance features: mimalloc = ["openvm-stark-backend/mimalloc"] jemalloc = ["openvm-stark-backend/jemalloc"] jemalloc-prof = ["openvm-stark-backend/jemalloc-prof"] -nightly-features = ["openvm-stark-sdk/nightly-features"] +nightly-features = ["openvm-stark-sdk?/nightly-features"] diff --git a/crates/vm/derive/src/lib.rs b/crates/vm/derive/src/lib.rs index 37dca6e4ed..0b0718802e 100644 --- a/crates/vm/derive/src/lib.rs +++ b/crates/vm/derive/src/lib.rs @@ -11,8 +11,23 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; - let generics = &ast.generics; - let (impl_generics, ty_generics, _) = generics.split_for_impl(); + let (_, ty_generics, _) = ast.generics.split_for_impl(); + let mut generics = ast.generics.clone(); + + // Check if first generic is 'F' + let needs_f = match generics.params.first() { + Some(GenericParam::Type(type_param)) => type_param.ident != "F", + Some(_) => true, // First param is lifetime or const, so we need F + None => true, // No generics at all, so we need F + }; + if needs_f { + // Create new F generic parameter + let f_param: GenericParam = syn::parse_quote!(F); + + // Insert at the beginning + generics.params.insert(0, f_param); + } + let (impl_generics, _, _) = generics.split_for_impl(); match &ast.data { Data::Struct(inner) => { @@ -38,10 +53,12 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { fn execute( &mut self, memory: &mut ::openvm_circuit::system::memory::MemoryController, + streams: &mut ::openvm_circuit::arch::Streams, + rng: &mut ::rand::rngs::StdRng, instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, from_state: ::openvm_circuit::arch::ExecutionState, ) -> ::openvm_circuit::arch::Result<::openvm_circuit::arch::ExecutionState> { - self.0.execute(memory, instruction, from_state) + self.0.execute(memory, streams, rng, instruction, from_state) } fn get_opcode_name(&self, opcode: usize) -> String { @@ -79,7 +96,7 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { multiunzip(variants.iter().map(|(variant_name, field)| { let field_ty = &field.ty; let execute_arm = quote! { - #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic>>::execute(x, memory, instruction, from_state) + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic>>::execute(x, memory, streams, rng, instruction, from_state) }; let get_opcode_name_arm = quote! { #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic>>::get_opcode_name(x, opcode) @@ -92,6 +109,8 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { fn execute( &mut self, memory: &mut ::openvm_circuit::system::memory::MemoryController<#first_ty_generic>, + streams: &mut ::openvm_circuit::arch::Streams, + rng: &mut ::rand::rngs::StdRng, instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>, from_state: ::openvm_circuit::arch::ExecutionState, ) -> ::openvm_circuit::arch::Result<::openvm_circuit::arch::ExecutionState> { @@ -113,6 +132,416 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { } } +#[proc_macro_derive(TraceStep)] +pub fn trace_step_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let (_, ty_generics, _) = ast.generics.split_for_impl(); + let mut generics = ast.generics.clone(); + + // Check if first generic is 'F' + let needs_f = match generics.params.first() { + Some(GenericParam::Type(type_param)) => type_param.ident != "F", + Some(_) => true, // First param is lifetime or const, so we need F + None => true, // No generics at all, so we need F + }; + if needs_f { + // Create new F generic parameter + let f_param: GenericParam = + syn::parse_quote!(F: ::openvm_stark_backend::p3_field::PrimeField32); + + // Insert at the beginning + generics.params.insert(0, f_param); + } + let need_ctx = if generics.params.len() >= 2 { + match &generics.params[2] { + GenericParam::Type(type_param) => type_param.ident != "CTX", + _ => true, + } + } else { + true + }; + if need_ctx { + // Create new F generic parameter + let ctx_param: GenericParam = syn::parse_quote!(CTX); + + // Insert at the beginning + generics.params.insert(0, ctx_param); + } + let (impl_generics, _, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + let inner_ty = match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + quote! { + impl #impl_generics ::openvm_circuit::arch::TraceStep for #name #ty_generics { + type RecordLayout = <#inner_ty as ::openvm_circuit::arch::TraceStep>::RecordLayout; + type RecordMut<'a> = <#inner_ty as ::openvm_circuit::arch::TraceStep>::RecordMut<'a>; + + fn execute<'buf, RA>( + &mut self, + state: ::openvm_circuit::arch::execution::VmStateMut, CTX>, + instruction: &::openvm_instructions::instruction::Instruction, + arena: &'buf mut RA, + ) -> ::openvm_circuit::arch::Result<()> + where + RA: ::openvm_circuit::arch::RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + self.0.execute(state, instruction, arena) + } + + fn get_opcode_name(&self, opcode: usize) -> String { + ::openvm_circuit::arch::TraceStep::::get_opcode_name(&self.0, opcode) + } + } + } + .into() + } + _ => unimplemented!(), + } +} + +#[proc_macro_derive(TraceFiller)] +pub fn trace_filler_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let (_, ty_generics, _) = ast.generics.split_for_impl(); + let mut generics = ast.generics.clone(); + + // Check if first generic is 'F' + let needs_f = match generics.params.first() { + Some(GenericParam::Type(type_param)) => type_param.ident != "F", + Some(_) => true, // First param is lifetime or const, so we need F + None => true, // No generics at all, so we need F + }; + if needs_f { + // Create new F generic parameter + let f_param: GenericParam = + syn::parse_quote!(F: ::openvm_stark_backend::p3_field::PrimeField32); + + // Insert at the beginning + generics.params.insert(0, f_param); + } + let need_ctx = if generics.params.len() >= 2 { + match &generics.params[2] { + GenericParam::Type(type_param) => type_param.ident != "CTX", + _ => true, + } + } else { + true + }; + if need_ctx { + // Create new F generic parameter + let ctx_param: GenericParam = syn::parse_quote!(CTX); + + // Insert at the beginning + generics.params.insert(0, ctx_param); + } + let (impl_generics, _, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + quote! { + impl #impl_generics ::openvm_circuit::arch::TraceFiller for #name #ty_generics { + fn fill_trace( + &self, + mem_helper: &::openvm_circuit::system::memory::MemoryAuxColsFactory, + trace: &mut ::openvm_stark_backend::p3_matrix::dense::RowMajorMatrix, + rows_used: usize, + ) where + Self: Send + Sync, + F: Send + Sync + Clone, + { + ::openvm_circuit::arch::TraceFiller::::fill_trace(&self.0, mem_helper, trace, rows_used); + } + + fn fill_trace_row(&self, mem_helper: &::openvm_circuit::system::memory::MemoryAuxColsFactory, row_slice: &mut [F]) { + ::openvm_circuit::arch::TraceFiller::::fill_trace_row(&self.0, mem_helper, row_slice); + } + + fn fill_dummy_trace_row(&self, mem_helper: &::openvm_circuit::system::memory::MemoryAuxColsFactory, row_slice: &mut [F]) { + ::openvm_circuit::arch::TraceFiller::::fill_dummy_trace_row(&self.0, mem_helper, row_slice); + } + } + } + .into() + } + _ => unimplemented!(), + } +} + +#[proc_macro_derive(InsExecutorE1)] +pub fn ins_executor_e1_executor_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + let inner_ty = match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let mut new_generics = generics.clone(); + let where_clause = new_generics.make_where_clause(); + where_clause + .predicates + .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InsExecutorE1 }); + quote! { + impl #impl_generics ::openvm_circuit::arch::InsExecutorE1 for #name #ty_generics #where_clause { + #[inline(always)] + fn pre_compute_size(&self) -> usize { + self.0.pre_compute_size() + } + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> ::openvm_circuit::arch::execution::Result<::openvm_circuit::arch::ExecuteFunc> + where + Ctx: ::openvm_circuit::arch::execution_mode::E1ExecutionCtx, { + self.0.pre_compute_e1(pc, inst, data) + } + + fn set_trace_height(&mut self, height: usize) { + self.0.set_trace_buffer_height(height); + } + } + } + .into() + } + Data::Enum(e) => { + let variants = e + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!(fields.next().is_none(), "Only one field is supported"); + (variant_name, field) + }) + .collect::>(); + let first_ty_generic = ast + .generics + .params + .first() + .and_then(|param| match param { + GenericParam::Type(type_param) => Some(&type_param.ident), + _ => None, + }) + .expect("First generic must be type for Field"); + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let pre_compute_size_arms = variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InsExecutorE1<#first_ty_generic>>::pre_compute_size(x) + } + }).collect::>(); + let pre_compute_e1_arms = variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InsExecutorE1<#first_ty_generic>>::pre_compute_e1(x, pc, instruction, data) + } + }).collect::>(); + let set_trace_height_arms = variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InsExecutorE1<#first_ty_generic>>::set_trace_height(x, height) + } + }).collect::>(); + + quote! { + impl #impl_generics ::openvm_circuit::arch::InsExecutorE1<#first_ty_generic> for #name #ty_generics { + #[inline(always)] + fn pre_compute_size(&self) -> usize { + match self { + #(#pre_compute_size_arms,)* + } + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> ::openvm_circuit::arch::execution::Result<::openvm_circuit::arch::ExecuteFunc> + where + Ctx: ::openvm_circuit::arch::execution_mode::E1ExecutionCtx, { + match self { + #(#pre_compute_e1_arms,)* + } + } + + fn set_trace_height( + &mut self, + height: usize, + ) { + match self { + #(#set_trace_height_arms,)* + } + } + } + } + .into() + } + Data::Union(_) => unimplemented!("Unions are not supported"), + } +} + +#[proc_macro_derive(InsExecutorE2)] +pub fn ins_executor_e2_executor_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + let inner_ty = match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let mut new_generics = generics.clone(); + let where_clause = new_generics.make_where_clause(); + where_clause + .predicates + .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InsExecutorE2 }); + quote! { + impl #impl_generics ::openvm_circuit::arch::InsExecutorE2 for #name #ty_generics #where_clause { + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + self.0.e2_pre_compute_size() + } + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> ::openvm_circuit::arch::execution::Result<::openvm_circuit::arch::ExecuteFunc> + where + Ctx: ::openvm_circuit::arch::execution_mode::E2ExecutionCtx, { + self.0.pre_compute_e2(chip_idx, pc, inst, data) + } + } + } + .into() + } + Data::Enum(e) => { + let variants = e + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!(fields.next().is_none(), "Only one field is supported"); + (variant_name, field) + }) + .collect::>(); + let first_ty_generic = ast + .generics + .params + .first() + .and_then(|param| match param { + GenericParam::Type(type_param) => Some(&type_param.ident), + _ => None, + }) + .expect("First generic must be type for Field"); + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let pre_compute_size_arms = variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InsExecutorE2<#first_ty_generic>>::e2_pre_compute_size(x) + } + }).collect::>(); + let pre_compute_e2_arms = variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InsExecutorE2<#first_ty_generic>>::pre_compute_e2(x, chip_idx, pc, instruction, data) + } + }).collect::>(); + + quote! { + impl #impl_generics ::openvm_circuit::arch::InsExecutorE2<#first_ty_generic> for #name #ty_generics { + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + match self { + #(#pre_compute_size_arms,)* + } + } + + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, + data: &mut [u8], + ) -> ::openvm_circuit::arch::execution::Result<::openvm_circuit::arch::ExecuteFunc> + where + Ctx: ::openvm_circuit::arch::execution_mode::E2ExecutionCtx, { + match self { + #(#pre_compute_e2_arms,)* + } + } + } + } + .into() + } + Data::Union(_) => unimplemented!("Unions are not supported"), + } +} + /// Derives `AnyEnum` trait on an enum type. /// By default an enum arm will just return `self` as `&dyn Any`. /// @@ -347,7 +776,7 @@ pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::T let periphery_type = Ident::new(&format!("{}Periphery", name), name.span()); TokenStream::from(quote! { - #[derive(::openvm_circuit::circuit_derive::ChipUsageGetter, ::openvm_circuit::circuit_derive::Chip, ::openvm_circuit::derive::InstructionExecutor, ::derive_more::derive::From, ::openvm_circuit::derive::AnyEnum)] + #[derive(::openvm_circuit::circuit_derive::ChipUsageGetter, ::openvm_circuit::circuit_derive::Chip, ::openvm_circuit::derive::InstructionExecutor, ::openvm_circuit::derive::InsExecutorE1, ::openvm_circuit::derive::InsExecutorE2, ::derive_more::derive::From, ::openvm_circuit::derive::AnyEnum)] pub enum #executor_type { #[any_enum] #source_name_upper(#source_executor_type), diff --git a/crates/vm/src/arch/config.rs b/crates/vm/src/arch/config.rs index d82b5f7cf0..59ce57c94b 100644 --- a/crates/vm/src/arch/config.rs +++ b/crates/vm/src/arch/config.rs @@ -2,16 +2,17 @@ use std::{fs::File, io::Write, path::Path, sync::Arc}; use derive_new::new; use openvm_circuit::system::memory::MemoryTraceHeights; +use openvm_instructions::NATIVE_AS; use openvm_poseidon2_air::Poseidon2Config; use openvm_stark_backend::{p3_field::PrimeField32, ChipUsageGetter}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use super::{ - segment::DefaultSegmentationStrategy, AnyEnum, InstructionExecutor, SegmentationStrategy, - SystemComplex, SystemExecutor, SystemPeriphery, VmChipComplex, VmInventoryError, - PUBLIC_VALUES_AIR_ID, + segmentation_strategy::{DefaultSegmentationStrategy, SegmentationStrategy}, + AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor, SystemComplex, SystemExecutor, + SystemPeriphery, VmChipComplex, VmInventoryError, PUBLIC_VALUES_AIR_ID, }; -use crate::system::memory::BOUNDARY_AIR_OFFSET; +use crate::system::memory::{merkle::public_values::PUBLIC_VALUES_AS, BOUNDARY_AIR_OFFSET}; // sbox is decomposed to have this max degree for Poseidon2. We set to 3 so quotient_degree = 2 // allows log_blowup = 1 @@ -19,6 +20,8 @@ const DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE: usize = 3; pub const DEFAULT_MAX_NUM_PUBLIC_VALUES: usize = 32; /// Width of Poseidon2 VM uses. pub const POSEIDON2_WIDTH: usize = 16; +/// Offset for address space indices. This is used to distinguish between different memory spaces. +pub const ADDR_SPACE_OFFSET: u32 = 1; /// Returns a Poseidon2 config for the VM. pub fn vm_poseidon2_config() -> Poseidon2Config { Poseidon2Config::default() @@ -27,7 +30,11 @@ pub fn vm_poseidon2_config() -> Poseidon2Config { pub trait VmConfig: Clone + Serialize + DeserializeOwned + InitFileGenerator { - type Executor: InstructionExecutor + AnyEnum + ChipUsageGetter; + type Executor: InstructionExecutor + + AnyEnum + + ChipUsageGetter + + InsExecutorE1 + + InsExecutorE2; type Periphery: AnyEnum + ChipUsageGetter; /// Must contain system config @@ -68,15 +75,15 @@ pub trait InitFileGenerator { } } -#[derive(Debug, Serialize, Deserialize, Clone, new, Copy)] +#[derive(Debug, Serialize, Deserialize, Clone, new)] pub struct MemoryConfig { - /// The maximum height of the address space. This means the trie has `as_height` layers for - /// searching the address space. The allowed address spaces are those in the range `[as_offset, - /// as_offset + 2^as_height)` where `as_offset` is currently fixed to `1` to not allow address - /// space `0` in memory. - pub as_height: usize, - /// The offset of the address space. Should be fixed to equal `1`. - pub as_offset: u32, + /// The maximum height of the address space. This means the trie has `addr_space_height` layers + /// for searching the address space. The allowed address spaces are those in the range `[1, + /// 1 + 2^addr_space_height)` where it starts from 1 to not allow address space 0 in memory. + pub addr_space_height: usize, + /// The number of cells in each address space. It is expected that the size of the list is + /// `1 << addr_space_height + 1` and the first element is 0, which means no address space. + pub addr_space_sizes: Vec, pub pointer_max_bits: usize, /// All timestamps must be in the range `[0, 2^clk_max_bits)`. Maximum allowed: 29. pub clk_max_bits: usize, @@ -84,13 +91,23 @@ pub struct MemoryConfig { pub decomp: usize, /// Maximum N AccessAdapter AIR to support. pub max_access_adapter_n: usize, - /// An expected upper bound on the number of memory accesses. - pub access_capacity: usize, } impl Default for MemoryConfig { fn default() -> Self { - Self::new(3, 1, 29, 29, 17, 32, 1 << 24) + let mut addr_space_sizes = vec![0; (1 << 3) + ADDR_SPACE_OFFSET as usize]; + addr_space_sizes[ADDR_SPACE_OFFSET as usize..=NATIVE_AS as usize].fill(1 << 29); + addr_space_sizes[PUBLIC_VALUES_AS as usize] = DEFAULT_MAX_NUM_PUBLIC_VALUES; + Self::new(3, addr_space_sizes, 29, 29, 17, 32) + } +} + +impl MemoryConfig { + /// Config for aggregation usage with only native address space. + pub fn aggregation() -> Self { + let mut addr_space_sizes = vec![0; (1 << 3) + ADDR_SPACE_OFFSET as usize]; + addr_space_sizes[NATIVE_AS as usize] = 1 << 29; + Self::new(3, addr_space_sizes, 29, 29, 17, 8) } } @@ -139,7 +156,7 @@ pub struct SystemTraceHeights { impl SystemConfig { pub fn new( max_constraint_degree: usize, - memory_config: MemoryConfig, + mut memory_config: MemoryConfig, num_public_values: usize, ) -> Self { let segmentation_strategy = get_default_segmentation_strategy(); @@ -147,6 +164,7 @@ impl SystemConfig { memory_config.clk_max_bits <= 29, "Timestamp max bits must be <= 29 for LessThan to work in 31-bit field" ); + memory_config.addr_space_sizes[PUBLIC_VALUES_AS as usize] = num_public_values; Self { max_constraint_degree, continuation_enabled: false, @@ -157,6 +175,14 @@ impl SystemConfig { } } + pub fn default_from_memory(memory_config: MemoryConfig) -> Self { + Self::new( + DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE, + memory_config, + DEFAULT_MAX_NUM_PUBLIC_VALUES, + ) + } + pub fn with_max_constraint_degree(mut self, max_constraint_degree: usize) -> Self { self.max_constraint_degree = max_constraint_degree; self @@ -174,6 +200,7 @@ impl SystemConfig { pub fn with_public_values(mut self, num_public_values: usize) -> Self { self.num_public_values = num_public_values; + self.memory_config.addr_space_sizes[PUBLIC_VALUES_AS as usize] = num_public_values; self } @@ -215,11 +242,7 @@ impl SystemConfig { impl Default for SystemConfig { fn default() -> Self { - Self::new( - DEFAULT_POSEIDON2_MAX_CONSTRAINT_DEGREE, - Default::default(), - DEFAULT_MAX_NUM_PUBLIC_VALUES, - ) + Self::default_from_memory(MemoryConfig::default()) } } diff --git a/crates/vm/src/arch/execution.rs b/crates/vm/src/arch/execution.rs index 4edc88d355..a91e471fad 100644 --- a/crates/vm/src/arch/execution.rs +++ b/crates/vm/src/arch/execution.rs @@ -1,18 +1,28 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; -use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow}; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, PhantomDiscriminant, VmOpcode, }; use openvm_stark_backend::{ interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, - p3_field::FieldAlgebra, + p3_field::{FieldAlgebra, PrimeField32}, }; +use rand::rngs::StdRng; use serde::{Deserialize, Serialize}; use thiserror::Error; -use super::Streams; -use crate::system::{memory::MemoryController, program::ProgramBus}; +use super::{execution_mode::E1ExecutionCtx, Streams, VmSegmentState}; +use crate::{ + arch::execution_mode::E2ExecutionCtx, + system::{ + memory::{ + online::{GuestMemory, TracingMemory}, + MemoryController, + }, + program::ProgramBus, + }, +}; pub type Result = std::result::Result; @@ -66,14 +76,42 @@ pub enum ExecutionError { DidNotTerminate, #[error("program exit code {0}")] FailedWithExitCode(u32), + #[error("trace buffer out of bounds: requested {requested} but capacity is {capacity}")] + TraceBufferOutOfBounds { requested: usize, capacity: usize }, + #[error("invalid instruction at pc {0}")] + InvalidInstruction(u32), +} + +/// Global VM state accessible during instruction execution. +/// The state is generic in guest memory `MEM` and additional host state `CTX`. +/// The host state is execution context specific. +#[derive(derive_new::new)] +pub struct VmStateMut<'a, F, MEM, CTX> { + pub pc: &'a mut u32, + pub memory: &'a mut MEM, + pub streams: &'a mut Streams, + pub rng: &'a mut StdRng, + pub ctx: &'a mut CTX, } +impl VmStateMut<'_, F, TracingMemory, CTX> { + // TODO: store as u32 directly + #[inline(always)] + pub fn ins_start(&self, from_state: &mut ExecutionState) { + from_state.pc = F::from_canonical_u32(*self.pc); + from_state.timestamp = F::from_canonical_u32(self.memory.timestamp); + } +} + +// TODO: old pub trait InstructionExecutor { /// Runtime execution of the instruction, if the instruction is owned by the /// current instance. May internally store records of this call for later trace generation. fn execute( &mut self, memory: &mut MemoryController, + streams: &mut Streams, + rng: &mut StdRng, instruction: &Instruction, from_state: ExecutionState, ) -> Result>; @@ -83,29 +121,110 @@ pub trait InstructionExecutor { fn get_opcode_name(&self, opcode: usize) -> String; } -impl> InstructionExecutor for RefCell { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - prev_state: ExecutionState, - ) -> Result> { - self.borrow_mut().execute(memory, instruction, prev_state) +pub type ExecuteFunc = unsafe fn(&[u8], &mut VmSegmentState); + +pub struct PreComputeInstruction<'a, F, CTX> { + pub handler: ExecuteFunc, + pub pre_compute: &'a [u8], +} + +#[derive(Clone, AlignedBytesBorrow)] +#[repr(C)] +pub struct E2PreCompute { + pub chip_idx: u32, + pub data: DATA, +} + +/// Trait for E1 execution +pub trait InsExecutorE1 { + fn pre_compute_size(&self) -> usize; + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx; + + fn set_trace_height(&mut self, height: usize); +} + +pub trait InsExecutorE2 { + fn e2_pre_compute_size(&self) -> usize; + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx; +} + +impl InsExecutorE1 for RefCell +where + C: InsExecutorE1, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + self.borrow().pre_compute_size() + } + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + self.borrow().pre_compute_e1(pc, inst, data) + } + #[inline(always)] + fn set_trace_height(&mut self, height: usize) { + self.borrow_mut().set_trace_height(height); } +} - fn get_opcode_name(&self, opcode: usize) -> String { - self.borrow().get_opcode_name(opcode) +impl InsExecutorE2 for RefCell +where + C: InsExecutorE2, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + self.borrow().e2_pre_compute_size() + } + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + self.borrow().pre_compute_e2(chip_idx, pc, inst, data) } } -impl> InstructionExecutor for Rc> { +impl> InstructionExecutor for RefCell { fn execute( &mut self, memory: &mut MemoryController, + streams: &mut Streams, + rng: &mut StdRng, instruction: &Instruction, prev_state: ExecutionState, ) -> Result> { - self.borrow_mut().execute(memory, instruction, prev_state) + self.borrow_mut() + .execute(memory, streams, rng, instruction, prev_state) } fn get_opcode_name(&self, opcode: usize) -> String { @@ -322,14 +441,16 @@ impl From<(u32, Option)> for PcIncOrSet { /// /// Phantom sub-instructions are only allowed to use operands /// `a,b` and `c_upper = c.as_canonical_u32() >> 16`. +#[allow(clippy::too_many_arguments)] pub trait PhantomSubExecutor: Send { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + rng: &mut StdRng, discriminant: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, c_upper: u16, ) -> eyre::Result<()>; } diff --git a/crates/vm/src/arch/execution_control.rs b/crates/vm/src/arch/execution_control.rs new file mode 100644 index 0000000000..5aba19cb3a --- /dev/null +++ b/crates/vm/src/arch/execution_control.rs @@ -0,0 +1,69 @@ +use openvm_instructions::instruction::Instruction; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{ExecutionError, VmChipComplex, VmConfig, VmSegmentState}; + +/// Trait for execution control, determining segmentation and stopping conditions +/// Invariants: +/// - `ExecutionControl` should be stateless. +/// - For E3/E4, `ExecutionControl` is for a specific execution and cannot be used for another +/// execution with different inputs or segmentation criteria. +pub trait ExecutionControl +where + F: PrimeField32, + VC: VmConfig, +{ + /// Host context + type Ctx; + + fn initialize_context(&self) -> Self::Ctx; + + /// Determines if execution should suspend + fn should_suspend( + &self, + state: &mut VmSegmentState, + chip_complex: &VmChipComplex, + ) -> bool; + + /// Called before execution begins + fn on_start( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ); + + /// Called after suspend or terminate + fn on_suspend_or_terminate( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: Option, + ); + + fn on_suspend( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) { + self.on_suspend_or_terminate(state, chip_complex, None); + } + + fn on_terminate( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: u32, + ) { + self.on_suspend_or_terminate(state, chip_complex, Some(exit_code)); + } + + /// Execute a single instruction + fn execute_instruction( + &self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32; +} diff --git a/crates/vm/src/arch/execution_mode/e1.rs b/crates/vm/src/arch/execution_mode/e1.rs new file mode 100644 index 0000000000..49cda03dad --- /dev/null +++ b/crates/vm/src/arch/execution_mode/e1.rs @@ -0,0 +1,32 @@ +use crate::arch::{execution_mode::E1ExecutionCtx, VmSegmentState}; + +pub struct E1Ctx { + instret_end: u64, +} + +impl E1Ctx { + pub fn new(instret_end: Option) -> Self { + E1Ctx { + instret_end: if let Some(end) = instret_end { + end + } else { + u64::MAX + }, + } + } +} + +impl Default for E1Ctx { + fn default() -> Self { + Self::new(None) + } +} + +impl E1ExecutionCtx for E1Ctx { + #[inline(always)] + fn on_memory_operation(&mut self, _address_space: u32, _ptr: u32, _size: u32) {} + #[inline(always)] + fn should_suspend(vm_state: &mut VmSegmentState) -> bool { + vm_state.instret >= vm_state.ctx.instret_end + } +} diff --git a/crates/vm/src/arch/execution_mode/metered/ctx.rs b/crates/vm/src/arch/execution_mode/metered/ctx.rs new file mode 100644 index 0000000000..6f4a23dbf4 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/ctx.rs @@ -0,0 +1,245 @@ +use openvm_instructions::riscv::{ + RV32_IMM_AS, RV32_NUM_REGISTERS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS, +}; +use openvm_stark_backend::{p3_field::PrimeField32, ChipUsageGetter}; + +use super::{ + memory_ctx::MemoryCtx, + segment_ctx::{Segment, SegmentationCtx}, +}; +use crate::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + VmChipComplex, VmSegmentState, + }, + system::memory::dimensions::MemoryDimensions, +}; + +pub const DEFAULT_PAGE_BITS: usize = 6; + +#[derive(Debug)] +pub struct MeteredCtx { + pub trace_heights: Vec, + pub is_trace_height_constant: Vec, + + pub memory_ctx: MemoryCtx, + pub segmentation_ctx: SegmentationCtx, + pub instret_end: u64, + pub continuations_enabled: bool, +} + +impl MeteredCtx { + pub fn new( + chip_complex: &VmChipComplex, + interactions: Vec, + ) -> Self { + let constant_trace_heights: Vec<_> = chip_complex.constant_trace_heights().collect(); + let has_public_values_chip = chip_complex.config().has_public_values_chip(); + let continuation_enabled = chip_complex.config().continuation_enabled; + let as_alignment = chip_complex + .memory_controller() + .memory + .address_space_alignment(); + let memory_dimensions = chip_complex.config().memory_config.memory_dimensions(); + let air_names = chip_complex.air_names(); + let widths = chip_complex.get_air_widths(); + Self::new_impl( + constant_trace_heights, + has_public_values_chip, + continuation_enabled, + as_alignment, + memory_dimensions, + air_names, + widths, + interactions, + ) + } + #[allow(clippy::too_many_arguments)] + pub fn new_impl( + constant_trace_heights: Vec>, + has_public_values_chip: bool, + continuations_enabled: bool, + as_byte_alignment_bits: Vec, + memory_dimensions: MemoryDimensions, + air_names: Vec, + widths: Vec, + interactions: Vec, + ) -> Self { + let (trace_heights, is_trace_height_constant): (Vec, Vec) = + constant_trace_heights + .iter() + .map(|&constant_height| { + if let Some(height) = constant_height { + (height as u32, true) + } else { + (0, false) + } + }) + .unzip(); + + let memory_ctx = MemoryCtx::new( + has_public_values_chip, + continuations_enabled, + as_byte_alignment_bits, + memory_dimensions, + ); + + // Assert that the indices are correct + debug_assert_eq!(&air_names[memory_ctx.boundary_idx], "Boundary"); + if let Some(merkle_tree_index) = memory_ctx.merkle_tree_index { + debug_assert_eq!(&air_names[merkle_tree_index], "Merkle"); + } + debug_assert_eq!(&air_names[memory_ctx.adapter_offset], "AccessAdapter<2>"); + + let segmentation_ctx = SegmentationCtx::new(air_names, widths, interactions); + + let mut ctx = Self { + trace_heights, + is_trace_height_constant, + memory_ctx, + segmentation_ctx, + instret_end: u64::MAX, + continuations_enabled, + }; + + // Add merkle height contributions for all registers + ctx.add_register_merkle_heights(); + + ctx + } + + fn add_register_merkle_heights(&mut self) { + if self.continuations_enabled { + self.memory_ctx.update_boundary_merkle_heights( + &mut self.trace_heights, + RV32_REGISTER_AS, + 0, + (RV32_NUM_REGISTERS * RV32_REGISTER_NUM_LIMBS) as u32, + ); + } + } + + pub fn with_max_trace_height(mut self, max_trace_height: u32) -> Self { + self.segmentation_ctx.set_max_trace_height(max_trace_height); + self + } + + pub fn with_max_cells(mut self, max_cells: usize) -> Self { + self.segmentation_ctx.set_max_cells(max_cells); + self + } + + pub fn with_max_interactions(mut self, max_interactions: usize) -> Self { + self.segmentation_ctx.set_max_interactions(max_interactions); + self + } + + pub fn with_segment_check_insns(mut self, segment_check_insns: u64) -> Self { + self.segmentation_ctx + .set_segment_check_insns(segment_check_insns); + self + } + + pub fn with_instret_end(mut self, target_instret: u64) -> Self { + self.instret_end = target_instret; + self + } + + pub fn segments(&self) -> &[Segment] { + &self.segmentation_ctx.segments + } + + pub fn into_segments(self) -> Vec { + self.segmentation_ctx.segments + } + + fn reset_segment(&mut self) { + self.memory_ctx.clear(); + for (i, &is_constant) in self.is_trace_height_constant.iter().enumerate() { + if !is_constant { + self.trace_heights[i] = 0; + } + } + + // Add merkle height contributions for all registers + self.add_register_merkle_heights(); + } + + pub fn check_and_segment(&mut self, instret: u64) { + let did_segment = self.segmentation_ctx.check_and_segment( + instret, + &self.trace_heights, + &self.is_trace_height_constant, + ); + + if did_segment { + self.reset_segment(); + } + } + + #[allow(dead_code)] + pub fn print_heights(&self) { + println!("{:>10} {:<30}", "Height", "Air Name"); + println!("{}", "-".repeat(42)); + for (i, height) in self.trace_heights.iter().enumerate() { + let air_name = self + .segmentation_ctx + .air_names + .get(i) + .map(|s| s.as_str()) + .unwrap_or("Unknown"); + println!("{:>10} {:<30}", height, air_name); + } + } +} + +impl E1ExecutionCtx for MeteredCtx { + #[inline(always)] + fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32) { + debug_assert!( + address_space != RV32_IMM_AS, + "address space must not be immediate" + ); + debug_assert!( + size.is_power_of_two(), + "size must be a power of 2, got {}", + size + ); + + // Handle access adapter updates + let size_bits = size.ilog2(); + self.memory_ctx + .update_adapter_heights(&mut self.trace_heights, address_space, size_bits); + + // Handle merkle tree updates + if address_space != RV32_REGISTER_AS { + self.memory_ctx.update_boundary_merkle_heights( + &mut self.trace_heights, + address_space, + ptr, + size, + ); + } + } + + #[inline(always)] + fn should_suspend(vm_state: &mut VmSegmentState) -> bool { + vm_state.ctx.check_and_segment(vm_state.instret); + vm_state.instret == vm_state.ctx.instret_end + } + + #[inline(always)] + fn on_terminate(vm_state: &mut VmSegmentState) { + vm_state + .ctx + .segmentation_ctx + .segment(vm_state.instret, &vm_state.ctx.trace_heights); + } +} + +impl E2ExecutionCtx for MeteredCtx { + #[inline(always)] + fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) { + self.trace_heights[chip_idx] += height_delta; + } +} diff --git a/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs b/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs new file mode 100644 index 0000000000..076cbd807c --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/memory_ctx.rs @@ -0,0 +1,175 @@ +use crate::{ + arch::PUBLIC_VALUES_AIR_ID, + system::memory::{dimensions::MemoryDimensions, CHUNK}, +}; + +#[derive(Debug)] +pub struct BitSet { + words: Box<[u64]>, +} + +impl BitSet { + pub fn new(size_bits: usize) -> Self { + let num_words = 1 << size_bits.saturating_sub(6); + Self { + words: vec![0; num_words].into_boxed_slice(), + } + } + + pub fn insert(&mut self, index: usize) -> bool { + let word_index = index / 64; + let bit_index = index % 64; + let mask = 1u64 << bit_index; + + let was_set = (self.words[word_index] & mask) != 0; + self.words[word_index] |= mask; + !was_set + } + + pub fn clear(&mut self) { + for item in self.words.iter_mut() { + *item = 0; + } + } +} + +#[derive(Debug)] +pub struct MemoryCtx { + pub page_indices: BitSet, + memory_dimensions: MemoryDimensions, + as_byte_alignment_bits: Vec, + pub boundary_idx: usize, + pub merkle_tree_index: Option, + pub adapter_offset: usize, + chunk: u32, + chunk_bits: u32, +} + +impl MemoryCtx { + pub fn new( + has_public_values_chip: bool, + continuations_enabled: bool, + as_byte_alignment_bits: Vec, + memory_dimensions: MemoryDimensions, + ) -> Self { + let boundary_idx = if has_public_values_chip { + PUBLIC_VALUES_AIR_ID + 1 + } else { + PUBLIC_VALUES_AIR_ID + }; + + let merkle_tree_index = if continuations_enabled { + Some(boundary_idx + 1) + } else { + None + }; + + let adapter_offset = if continuations_enabled { + boundary_idx + 2 + } else { + boundary_idx + 1 + }; + + let chunk = if continuations_enabled { + // Persistent memory uses CHUNK-sized blocks + CHUNK as u32 + } else { + // Volatile memory uses single units + 1 + }; + + let chunk_bits = chunk.ilog2(); + let merkle_height = memory_dimensions.overall_height(); + + Self { + page_indices: BitSet::new(merkle_height.saturating_sub(PAGE_BITS)), + as_byte_alignment_bits, + boundary_idx, + merkle_tree_index, + adapter_offset, + chunk, + chunk_bits, + memory_dimensions, + } + } + + pub fn clear(&mut self) { + self.page_indices.clear(); + } + + pub fn update_boundary_merkle_heights( + &mut self, + trace_heights: &mut [u32], + address_space: u32, + ptr: u32, + size: u32, + ) { + let num_blocks = (size + self.chunk - 1) >> self.chunk_bits; + let mut addr = ptr; + for _ in 0..num_blocks { + let block_id = addr >> self.chunk_bits; + let index = if self.chunk == 1 { + // Volatile + block_id + } else { + self.memory_dimensions + .label_to_index((address_space, block_id)) as u32 + } as usize; + + if self.page_indices.insert(index >> PAGE_BITS) { + // On page fault, assume we add all leaves in a page + let leaves = 1 << PAGE_BITS; + trace_heights[self.boundary_idx] += leaves; + + if let Some(merkle_tree_idx) = self.merkle_tree_index { + let poseidon2_idx = trace_heights.len() - 2; + trace_heights[poseidon2_idx] += leaves * 2; + + let merkle_height = self.memory_dimensions.overall_height(); + let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32; + trace_heights[poseidon2_idx] += nodes * 2; + trace_heights[merkle_tree_idx] += nodes * 2; + } + + // At finalize, we'll need to read it in chunk-sized units for the merkle chip + self.update_adapter_heights_batch( + trace_heights, + address_space, + self.chunk_bits, + leaves, + ); + } + + addr = addr.wrapping_add(self.chunk); + } + } + + pub fn update_adapter_heights( + &mut self, + trace_heights: &mut [u32], + address_space: u32, + size_bits: u32, + ) { + self.update_adapter_heights_batch(trace_heights, address_space, size_bits, 1); + } + + pub fn update_adapter_heights_batch( + &mut self, + trace_heights: &mut [u32], + address_space: u32, + size_bits: u32, + num: u32, + ) { + let align_bits = self.as_byte_alignment_bits[address_space as usize]; + debug_assert!( + align_bits as u32 <= size_bits, + "align_bits ({}) must be <= size_bits ({})", + align_bits, + size_bits + ); + for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() { + let adapter_idx = self.adapter_offset + adapter_bits as usize - 1; + trace_heights[adapter_idx] += num << (size_bits - adapter_bits + 1); + } + } +} diff --git a/crates/vm/src/arch/execution_mode/metered/mod.rs b/crates/vm/src/arch/execution_mode/metered/mod.rs new file mode 100644 index 0000000000..9bd0799194 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/mod.rs @@ -0,0 +1,6 @@ +pub mod ctx; +pub mod memory_ctx; +pub mod segment_ctx; + +pub use ctx::MeteredCtx; +pub use segment_ctx::Segment; diff --git a/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs b/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs new file mode 100644 index 0000000000..3b809d50a0 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/segment_ctx.rs @@ -0,0 +1,204 @@ +use openvm_stark_backend::p3_field::PrimeField32; +use p3_baby_bear::BabyBear; +use serde::{Deserialize, Serialize}; + +/// Check segment every 100 instructions. +const DEFAULT_SEGMENT_CHECK_INSNS: u64 = 100; + +const DEFAULT_MAX_TRACE_HEIGHT: u32 = (1 << 23) - 100; +const DEFAULT_MAX_CELLS: usize = 2_000_000_000; // 2B +const DEFAULT_MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize; + +#[derive(derive_new::new, Clone, Debug, Serialize, Deserialize)] +pub struct Segment { + pub instret_start: u64, + pub num_insns: u64, + pub trace_heights: Vec, +} + +#[derive(Debug)] +pub struct SegmentationLimits { + pub max_trace_height: u32, + pub max_cells: usize, + pub max_interactions: usize, +} + +impl Default for SegmentationLimits { + fn default() -> Self { + Self { + max_trace_height: DEFAULT_MAX_TRACE_HEIGHT, + max_cells: DEFAULT_MAX_CELLS, + max_interactions: DEFAULT_MAX_INTERACTIONS, + } + } +} + +#[derive(Debug)] +pub struct SegmentationCtx { + pub segments: Vec, + instret_last_segment_check: u64, + pub(crate) air_names: Vec, + widths: Vec, + interactions: Vec, + segment_check_insns: u64, + segmentation_limits: SegmentationLimits, +} + +impl SegmentationCtx { + pub fn new(air_names: Vec, widths: Vec, interactions: Vec) -> Self { + Self { + segments: Vec::new(), + air_names, + widths, + interactions, + segment_check_insns: DEFAULT_SEGMENT_CHECK_INSNS, + segmentation_limits: SegmentationLimits::default(), + instret_last_segment_check: 0, + } + } + + pub fn set_max_trace_height(&mut self, max_trace_height: u32) { + self.segmentation_limits.max_trace_height = max_trace_height; + } + + pub fn set_max_cells(&mut self, max_cells: usize) { + self.segmentation_limits.max_cells = max_cells; + } + + pub fn set_max_interactions(&mut self, max_interactions: usize) { + self.segmentation_limits.max_interactions = max_interactions; + } + + pub fn set_segment_check_insns(&mut self, segment_check_insns: u64) { + self.segment_check_insns = segment_check_insns; + } + + /// Calculate the total cells used based on trace heights and widths + fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize { + trace_heights + .iter() + .zip(&self.widths) + .map(|(&height, &width)| height as usize * width) + .sum() + } + + /// Calculate the total interactions based on trace heights and interaction counts + fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize { + trace_heights + .iter() + .zip(&self.interactions) + // We add 1 for the zero messages from the padding rows + .map(|(&height, &interactions)| (height + 1) as usize * interactions) + .sum() + } + + fn should_segment( + &self, + instret: u64, + trace_heights: &[u32], + is_trace_height_constant: &[bool], + ) -> bool { + let instret_start = self + .segments + .last() + .map_or(0, |s| s.instret_start + s.num_insns); + let num_insns = instret - instret_start; + // Segment should contain at least one cycle + if num_insns == 0 { + return false; + } + for (i, &height) in trace_heights.iter().enumerate() { + // Only segment if the height is not constant and exceeds the maximum height + if !is_trace_height_constant[i] && height > self.segmentation_limits.max_trace_height { + tracing::info!( + "Segment {:2} | instret {:9} | chip {} ({}) height ({:8}) > max ({:8})", + self.segments.len(), + instret, + i, + self.air_names[i], + height, + self.segmentation_limits.max_trace_height + ); + return true; + } + } + + let total_cells = self.calculate_total_cells(trace_heights); + if total_cells > self.segmentation_limits.max_cells { + tracing::info!( + "Segment {:2} | instret {:9} | total cells ({:10}) > max ({:10})", + self.segments.len(), + instret, + total_cells, + self.segmentation_limits.max_cells + ); + return true; + } + + let total_interactions = self.calculate_total_interactions(trace_heights); + if total_interactions > self.segmentation_limits.max_interactions { + tracing::info!( + "Segment {:2} | instret {:9} | total interactions ({:11}) > max ({:11})", + self.segments.len(), + instret, + total_interactions, + self.segmentation_limits.max_interactions + ); + return true; + } + + false + } + + pub fn check_and_segment( + &mut self, + instret: u64, + trace_heights: &[u32], + is_trace_height_constant: &[bool], + ) -> bool { + // Avoid checking segment too often. + if instret < self.instret_last_segment_check + self.segment_check_insns { + return false; + } + + let ret = self.should_segment(instret, trace_heights, is_trace_height_constant); + if ret { + self.segment(instret, trace_heights); + } + self.instret_last_segment_check = instret; + ret + } + + /// Try segment if there is at least one cycle + pub fn segment(&mut self, instret: u64, trace_heights: &[u32]) { + let instret_start = self + .segments + .last() + .map_or(0, |s| s.instret_start + s.num_insns); + let num_insns = instret - instret_start; + self.segments.push(Segment { + instret_start, + num_insns, + trace_heights: trace_heights.to_vec(), + }); + } + + pub fn add_final_segment(&mut self, instret: u64, trace_heights: &[u32]) { + tracing::info!( + "Segment {:2} | instret {:9} | terminated", + self.segments.len(), + instret, + ); + // Add the last segment + let instret_start = self + .segments + .last() + .map_or(0, |s| s.instret_start + s.num_insns); + let segment = Segment { + instret_start, + num_insns: instret - instret_start, + trace_heights: trace_heights.to_vec(), + }; + self.segments.push(segment); + } +} diff --git a/crates/vm/src/arch/execution_mode/mod.rs b/crates/vm/src/arch/execution_mode/mod.rs new file mode 100644 index 0000000000..24fd4ba28e --- /dev/null +++ b/crates/vm/src/arch/execution_mode/mod.rs @@ -0,0 +1,15 @@ +use crate::arch::VmSegmentState; + +pub mod e1; +pub mod metered; +pub mod tracegen; + +pub trait E1ExecutionCtx: Sized { + fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32); + fn should_suspend(vm_state: &mut VmSegmentState) -> bool; + fn on_terminate(_vm_state: &mut VmSegmentState) {} +} + +pub trait E2ExecutionCtx: E1ExecutionCtx { + fn on_height_change(&mut self, chip_idx: usize, height_delta: u32); +} diff --git a/crates/vm/src/arch/execution_mode/tracegen.rs b/crates/vm/src/arch/execution_mode/tracegen.rs new file mode 100644 index 0000000000..6aec38e170 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/tracegen.rs @@ -0,0 +1,97 @@ +use openvm_instructions::instruction::Instruction; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + arch::{ + execution_control::ExecutionControl, ExecutionError, ExecutionState, InstructionExecutor, + VmChipComplex, VmConfig, VmSegmentState, + }, + system::memory::INITIAL_TIMESTAMP, +}; + +#[derive(Default, derive_new::new)] +pub struct TracegenCtx { + pub instret_end: Option, +} + +#[derive(Default)] +pub struct TracegenExecutionControl; + +impl ExecutionControl for TracegenExecutionControl +where + F: PrimeField32, + VC: VmConfig, +{ + type Ctx = TracegenCtx; + + fn initialize_context(&self) -> Self::Ctx { + TracegenCtx { instret_end: None } + } + + fn should_suspend( + &self, + state: &mut VmSegmentState, + _chip_complex: &VmChipComplex, + ) -> bool { + state + .ctx + .instret_end + .is_some_and(|instret_end| state.instret >= instret_end) + } + + fn on_start( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) { + chip_complex + .connector_chip_mut() + .begin(ExecutionState::new(state.pc, INITIAL_TIMESTAMP + 1)); + } + + fn on_suspend_or_terminate( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: Option, + ) { + let timestamp = chip_complex.memory_controller().timestamp(); + chip_complex + .connector_chip_mut() + .end(ExecutionState::new(state.pc, timestamp), exit_code); + } + + /// Execute a single instruction + fn execute_instruction( + &self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32, + { + let timestamp = chip_complex.memory_controller().timestamp(); + + let &Instruction { opcode, .. } = instruction; + + if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) { + let memory_controller = &mut chip_complex.base.memory_controller; + let new_state = executor.execute( + memory_controller, + &mut state.streams, + &mut state.rng, + instruction, + ExecutionState::new(state.pc, timestamp), + )?; + state.pc = new_state.pc; + } else { + return Err(ExecutionError::DisabledOperation { + pc: state.pc, + opcode, + }); + }; + + Ok(()) + } +} diff --git a/crates/vm/src/arch/extensions.rs b/crates/vm/src/arch/extensions.rs index adda318f6a..0a676cabff 100644 --- a/crates/vm/src/arch/extensions.rs +++ b/crates/vm/src/arch/extensions.rs @@ -2,7 +2,7 @@ use std::{ any::{Any, TypeId}, cell::RefCell, iter::once, - sync::{Arc, Mutex}, + sync::Arc, }; use derive_more::derive::From; @@ -10,14 +10,15 @@ use getset::Getters; use itertools::{zip_eq, Itertools}; #[cfg(feature = "bench-metrics")] use metrics::counter; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}; use openvm_circuit_primitives::{ utils::next_power_of_two_or_zero, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{ - program::Program, LocalOpcode, PhantomDiscriminant, PublishOpcode, SystemOpcode, VmOpcode, + program::Program, LocalOpcode, PhantomDiscriminant, PublishOpcode, SysPhantom, SysPhantom::Nop, + SystemOpcode, VmOpcode, }; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, @@ -36,21 +37,29 @@ use serde::{Deserialize, Serialize}; use super::{ vm_poseidon2_config, ExecutionBus, GenerationError, InstructionExecutor, PhantomSubExecutor, - Streams, SystemConfig, SystemTraceHeights, + SystemConfig, SystemTraceHeights, }; #[cfg(feature = "bench-metrics")] use crate::metrics::VmMetrics; -use crate::system::{ - connector::VmConnectorChip, - memory::{ - offline_checker::{MemoryBridge, MemoryBus}, - MemoryController, MemoryImage, OfflineMemory, BOUNDARY_AIR_OFFSET, MERKLE_AIR_OFFSET, +use crate::{ + arch::{ExecutionBridge, VmAirWrapper}, + system::{ + connector::VmConnectorChip, + memory::{ + offline_checker::{MemoryBridge, MemoryBus}, + MemoryController, MemoryImage, BOUNDARY_AIR_OFFSET, MERKLE_AIR_OFFSET, + }, + native_adapter::{NativeAdapterAir, NativeAdapterStep}, + phantom::{ + CycleEndPhantomExecutor, CycleStartPhantomExecutor, NopPhantomExecutor, PhantomChip, + }, + poseidon2::Poseidon2PeripheryChip, + program::{ProgramBus, ProgramChip}, + public_values::{ + core::{PublicValuesCoreAir, PublicValuesCoreStep}, + PublicValuesChip, + }, }, - native_adapter::NativeAdapterChip, - phantom::PhantomChip, - poseidon2::Poseidon2PeripheryChip, - program::{ProgramBus, ProgramChip}, - public_values::{core::PublicValuesCoreChip, PublicValuesChip}, }; /// Global AIR ID in the VM circuit verifying key. @@ -116,7 +125,6 @@ pub struct SystemPort { pub struct VmInventoryBuilder<'a, F: PrimeField32> { system_config: &'a SystemConfig, system: &'a SystemBase, - streams: &'a Arc>>, bus_idx_mgr: BusIndexManager, /// Chips that are already included in the chipset and may be used /// as dependencies. The order should be that depended-on chips are ordered @@ -128,13 +136,11 @@ impl<'a, F: PrimeField32> VmInventoryBuilder<'a, F> { pub fn new( system_config: &'a SystemConfig, system: &'a SystemBase, - streams: &'a Arc>>, bus_idx_mgr: BusIndexManager, ) -> Self { Self { system_config, system, - streams, bus_idx_mgr, chips: Vec::new(), } @@ -187,11 +193,6 @@ impl<'a, F: PrimeField32> VmInventoryBuilder<'a, F> { Ok(()) } - /// Shareable streams. Clone to get a shared mutable reference. - pub fn streams(&self) -> &Arc>> { - self.streams - } - fn add_chip(&mut self, chip: &'a E) { self.chips.push(chip); } @@ -205,7 +206,7 @@ pub struct VmInventory { pub periphery: Vec

, /// Order of insertion. The reverse of this will be the order the chips are destroyed /// to generate trace. - insertion_order: Vec, + pub insertion_order: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -326,11 +327,44 @@ impl VmInventory { self.executors.get(*id) } + pub fn get_executor_id(&self, opcode: VmOpcode) -> Option { + self.instruction_lookup.get(&opcode).cloned() + } + pub fn get_mut_executor(&mut self, opcode: &VmOpcode) -> Option<&mut E> { let id = self.instruction_lookup.get(opcode)?; self.executors.get_mut(*id) } + pub fn get_executor_idx_in_vkey(&self, opcode: &VmOpcode) -> Option { + let id = *self.instruction_lookup.get(opcode)?; + self.insertion_order + .iter() + .rev() + .position(|chip_id| match chip_id { + ChipId::Executor(exec_id) => *exec_id == id, + _ => false, + }) + } + + pub fn get_mut_executor_with_index(&mut self, opcode: &VmOpcode) -> Option<(&mut E, usize)> { + let id = *self.instruction_lookup.get(opcode)?; + + self.executors.get_mut(id).map(|executor| { + let insertion_id = self + .insertion_order + .iter() + .rev() + .position(|chip_id| match chip_id { + ChipId::Executor(exec_id) => *exec_id == id, + _ => false, + }) + .unwrap(); + + (executor, insertion_id) + }) + } + pub fn executors(&self) -> &[E] { &self.executors } @@ -441,7 +475,6 @@ pub struct VmChipComplex { /// Absolute maximum value a trace height can be and still be provable. max_trace_height: usize, - streams: Arc>>, bus_idx_mgr: BusIndexManager, } @@ -494,10 +527,6 @@ impl SystemBase { self.memory_controller.memory_bridge() } - pub fn offline_memory(&self) -> Arc>> { - self.memory_controller.offline_memory().clone() - } - pub fn execution_bus(&self) -> ExecutionBus { self.connector_chip.air.execution_bus } @@ -519,7 +548,9 @@ impl SystemBase { } } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor)] +#[derive( + ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor, InsExecutorE1, InsExecutorE2, +)] pub enum SystemExecutor { PublicValues(PublicValuesChip), Phantom(RefCell>), @@ -544,7 +575,7 @@ impl SystemComplex { let memory_controller = if config.continuation_enabled { MemoryController::with_persistent_memory( memory_bus, - config.memory_config, + config.memory_config.clone(), range_checker.clone(), PermutationCheckBus::new(bus_idx_mgr.new_bus_idx()), PermutationCheckBus::new(bus_idx_mgr.new_bus_idx()), @@ -552,12 +583,11 @@ impl SystemComplex { } else { MemoryController::with_volatile_memory( memory_bus, - config.memory_config, + config.memory_config.clone(), range_checker.clone(), ) }; let memory_bridge = memory_controller.memory_bridge(); - let offline_memory = memory_controller.offline_memory(); let program_chip = ProgramChip::new(program_bus); let connector_chip = VmConnectorChip::new( execution_bus, @@ -570,14 +600,26 @@ impl SystemComplex { // PublicValuesChip is required when num_public_values > 0 in single segment mode. if config.has_public_values_chip() { assert_eq!(inventory.executors().len(), Self::PV_EXECUTOR_IDX); + let chip = PublicValuesChip::new( - NativeAdapterChip::new(execution_bus, program_bus, memory_bridge), - PublicValuesCoreChip::new( + VmAirWrapper::new( + NativeAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + PublicValuesCoreAir::new( + config.num_public_values, + config.max_constraint_degree as u32 - 1, + ), + ), + PublicValuesCoreStep::new( + NativeAdapterStep::new(), config.num_public_values, config.max_constraint_degree as u32 - 1, ), - offline_memory, + memory_controller.helper(), ); + inventory .add_executor(chip, [PublishOpcode::PUBLISH.global_opcode()]) .unwrap(); @@ -600,11 +642,23 @@ impl SystemComplex { ); inventory.add_periphery_chip(chip); } - let streams = Arc::new(Mutex::new(Streams::default())); let phantom_opcode = SystemOpcode::PHANTOM.global_opcode(); let mut phantom_chip = PhantomChip::new(execution_bus, program_bus, SystemOpcode::CLASS_OFFSET); - phantom_chip.set_streams(streams.clone()); + // Use NopPhantomExecutor so the discriminant is set but `DebugPanic` is handled specially. + phantom_chip.add_sub_executor( + NopPhantomExecutor, + PhantomDiscriminant(SysPhantom::DebugPanic as u16), + ); + phantom_chip.add_sub_executor(NopPhantomExecutor, PhantomDiscriminant(Nop as u16)); + phantom_chip.add_sub_executor( + CycleStartPhantomExecutor, + PhantomDiscriminant(SysPhantom::CtStart as u16), + ); + phantom_chip.add_sub_executor( + CycleEndPhantomExecutor, + PhantomDiscriminant(SysPhantom::CtEnd as u16), + ); inventory .add_executor(RefCell::new(phantom_chip), [phantom_opcode]) .unwrap(); @@ -631,7 +685,6 @@ impl SystemComplex { base, inventory, bus_idx_mgr, - streams, overridden_inventory_heights: None, max_trace_height, } @@ -650,8 +703,7 @@ impl VmChipComplex { E: AnyEnum, P: AnyEnum, { - let mut builder = - VmInventoryBuilder::new(&self.config, &self.base, &self.streams, self.bus_idx_mgr); + let mut builder = VmInventoryBuilder::new(&self.config, &self.base, self.bus_idx_mgr); // Add range checker for convenience, the other system base chips aren't included - they can // be accessed directly from builder builder.add_chip(&self.base.range_checker_chip); @@ -696,7 +748,6 @@ impl VmChipComplex { base: self.base, inventory: self.inventory.transmute(), bus_idx_mgr: self.bus_idx_mgr, - streams: self.streams, overridden_inventory_heights: self.overridden_inventory_heights, max_trace_height: self.max_trace_height, } @@ -776,11 +827,11 @@ impl VmChipComplex { .as_any_kind_mut() .downcast_mut() .expect("Poseidon2 chip required for persistent memory"); - self.base.memory_controller.finalize(Some(hasher)) + self.base.memory_controller.finalize(Some(hasher)); } else { self.base .memory_controller - .finalize(None::<&mut Poseidon2PeripheryChip>) + .finalize(None::<&mut Poseidon2PeripheryChip>); }; } @@ -788,28 +839,17 @@ impl VmChipComplex { self.base.program_chip.set_program(program); } - pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage) { + pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage) { self.base.memory_controller.set_initial_memory(memory); } - /// Warning: this sets the stream in all chips which have a shared mutable reference to the - /// streams. - pub(crate) fn set_streams(&mut self, streams: Streams) { - *self.streams.lock().unwrap() = streams; - } - - /// This should **only** be called after segment execution has finished. - pub fn take_streams(&mut self) -> Streams { - std::mem::take(&mut self.streams.lock().unwrap()) - } - // This is O(1). pub fn num_airs(&self) -> usize { 3 + self.memory_controller().num_airs() + self.inventory.num_airs() } // we always need to special case it because we need to fix the air id. - fn public_values_chip_idx(&self) -> Option { + pub(crate) fn public_values_chip_idx(&self) -> Option { self.config .has_public_values_chip() .then_some(Self::PV_EXECUTOR_IDX) @@ -822,6 +862,15 @@ impl VmChipComplex { .then(|| &self.inventory.executors[Self::PV_EXECUTOR_IDX]) } + // The index at which the executor chips start in vkey. + pub(crate) fn get_executor_offset_in_vkey(&self) -> usize { + if self.config.has_public_values_chip() { + PUBLIC_VALUES_AIR_ID + 1 + self.memory_controller().num_airs() + } else { + PUBLIC_VALUES_AIR_ID + self.memory_controller().num_airs() + } + } + // All inventory chips except public values chip, in reverse order they were added. pub(crate) fn chips_excluding_pv_chip(&self) -> impl Iterator> { let public_values_chip_idx = self.public_values_chip_idx(); @@ -838,7 +887,7 @@ impl VmChipComplex { } /// Return air names of all chips in order. - pub(crate) fn air_names(&self) -> Vec + pub fn air_names(&self) -> Vec where E: ChipUsageGetter, P: ChipUsageGetter, @@ -851,6 +900,7 @@ impl VmChipComplex { .chain([self.range_checker_chip().air_name()]) .collect() } + /// Return trace heights of all chips in order corresponding to `air_names`. pub(crate) fn current_trace_heights(&self) -> Vec where @@ -903,6 +953,14 @@ impl VmChipComplex { ) } + pub(crate) fn set_adapter_heights(&mut self, heights: &[u32]) { + self.base + .memory_controller + .memory + .access_adapter_inventory + .set_arena_from_trace_heights(heights); + } + /// Override the trace heights for chips in the inventory. Usually this is for aggregation to /// generate a dummy proof and not useful for regular users. pub(crate) fn set_override_inventory_trace_heights( @@ -920,38 +978,35 @@ impl VmChipComplex { memory_controller.set_override_trace_heights(overridden_system_heights.memory); } - /// Return dynamic trace heights of all chips in order, or 0 if - /// chip has constant height. - // Used for continuation segmentation logic, so this is performance-sensitive. - // Return iterator so we can break early. - pub(crate) fn dynamic_trace_heights(&self) -> impl Iterator + '_ + /// Return constant trace heights of all chips in order, or None if + /// chip has dynamic height. + pub(crate) fn constant_trace_heights(&self) -> impl Iterator> + '_ where E: ChipUsageGetter, P: ChipUsageGetter, { - // program_chip, connector_chip - [0, 0] - .into_iter() - .chain(self._public_values_chip().map(|c| c.current_trace_height())) - .chain(self.memory_controller().current_trace_heights()) - .chain(self.chips_excluding_pv_chip().map(|c| match c { - // executor should never be constant height - Either::Executor(c) => c.current_trace_height(), - Either::Periphery(c) => { - if c.constant_trace_height().is_some() { - 0 - } else { - c.current_trace_height() - } - } - })) - .chain([0]) // range_checker_chip + [ + self.program_chip().constant_trace_height(), + self.connector_chip().constant_trace_height(), + ] + .into_iter() + .chain( + self._public_values_chip() + .map(|c| c.constant_trace_height()), + ) + .chain(std::iter::repeat(None).take(self.memory_controller().num_airs())) + .chain(self.chips_excluding_pv_chip().map(|c| match c { + Either::Periphery(c) => c.constant_trace_height(), + Either::Executor(c) => c.constant_trace_height(), + })) + .chain([self.range_checker_chip().constant_trace_height()]) } /// Return trace cells of all chips in order. /// This returns 0 cells for chips with preprocessed trace because the number of trace cells is /// constant in those cases. This function is used to sample periodically and provided to /// the segmentation strategy to decide whether to segment during execution. + #[cfg(feature = "bench-metrics")] pub(crate) fn current_trace_cells(&self) -> Vec where E: ChipUsageGetter, @@ -997,6 +1052,20 @@ impl VmChipComplex { .collect() } + pub fn get_air_widths(&self) -> Vec + where + E: ChipUsageGetter, + P: ChipUsageGetter, + { + once(self.program_chip().trace_width()) + .chain([self.connector_chip().trace_width()]) + .chain(self._public_values_chip().map(|c| c.trace_width())) + .chain(self.memory_controller().get_memory_trace_widths()) + .chain(self.chips_excluding_pv_chip().map(|c| c.trace_width())) + .chain([self.range_checker_chip().trace_width()]) + .collect() + } + pub(crate) fn generate_proof_input( mut self, cached_program: Option>, diff --git a/crates/vm/src/arch/hasher/mod.rs b/crates/vm/src/arch/hasher/mod.rs index df90a55e4b..e858da25f9 100644 --- a/crates/vm/src/arch/hasher/mod.rs +++ b/crates/vm/src/arch/hasher/mod.rs @@ -24,10 +24,10 @@ pub trait Hasher { leaves[0] } } -pub trait HasherChip: Hasher { +pub trait HasherChip: Hasher + Send + Sync { /// Stateful version of `hash` for recording the event in the chip. - fn compress_and_record(&mut self, left: &[F; CHUNK], right: &[F; CHUNK]) -> [F; CHUNK]; - fn hash_and_record(&mut self, values: &[F; CHUNK]) -> [F; CHUNK] { + fn compress_and_record(&self, left: &[F; CHUNK], right: &[F; CHUNK]) -> [F; CHUNK]; + fn hash_and_record(&self, values: &[F; CHUNK]) -> [F; CHUNK] { self.compress_and_record(values, &[F::ZERO; CHUNK]) } } diff --git a/crates/vm/src/arch/integration_api.rs b/crates/vm/src/arch/integration_api.rs index b1116d8c48..61cfcdd6c1 100644 --- a/crates/vm/src/arch/integration_api.rs +++ b/crates/vm/src/arch/integration_api.rs @@ -1,28 +1,39 @@ use std::{ + any::type_name, array::from_fn, - borrow::Borrow, + borrow::{Borrow, BorrowMut}, + io::Cursor, marker::PhantomData, - sync::{Arc, Mutex}, + ptr::{copy_nonoverlapping, slice_from_raw_parts_mut}, + sync::Arc, }; use openvm_circuit_primitives::utils::next_power_of_two_or_zero; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_stark_backend::{ - air_builders::{debug::DebugConstraintBuilder, symbolic::SymbolicRapBuilder}, config::{StarkGenericConfig, Val}, p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{FieldAlgebra, PrimeField32}, + p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, prover::types::AirProofInput, - rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, AirRef, Chip, ChipUsageGetter, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use rand::rngs::StdRng; +use serde::{Deserialize, Serialize}; -use super::{ExecutionState, InstructionExecutor, Result}; -use crate::system::memory::{MemoryController, OfflineMemory}; +use super::{ + execution_mode::E1ExecutionCtx, ExecuteFunc, ExecutionState, InsExecutorE1, InsExecutorE2, + InstructionExecutor, Result, Streams, VmStateMut, +}; +use crate::{ + arch::execution_mode::E2ExecutionCtx, + system::memory::{ + online::TracingMemory, MemoryAuxColsFactory, MemoryController, SharedMemoryHelper, + }, +}; /// The interface between primitive AIR and machine adapter AIR. pub trait VmAdapterInterface { @@ -37,60 +48,6 @@ pub trait VmAdapterInterface { type ProcessedInstruction; } -/// The adapter owns all memory accesses and timestamp changes. -/// The adapter AIR should also own `ExecutionBridge` and `MemoryBridge`. -pub trait VmAdapterChip { - /// Records generated by adapter before main instruction execution - type ReadRecord: Send + Serialize + DeserializeOwned; - /// Records generated by adapter after main instruction execution - type WriteRecord: Send + Serialize + DeserializeOwned; - /// AdapterAir should not have public values - type Air: BaseAir + Clone; - - type Interface: VmAdapterInterface; - - /// Given instruction, perform memory reads and return only the read data that the integrator - /// needs to use. This is called at the start of instruction execution. - /// - /// The implementer may choose to store data in the `Self::ReadRecord` struct, for example in - /// an [Option], which will later be sent to the `postprocess` method. - #[allow(clippy::type_complexity)] - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )>; - - /// Given instruction and the data to write, perform memory writes and return the `(record, - /// next_timestamp)` of the full adapter record for this instruction. This is guaranteed to - /// be called after `preprocess`. - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)>; - - /// Populates `row_slice` with values corresponding to `record`. - /// The provided `row_slice` will have length equal to `self.air().width()`. - /// This function will be called for each row in the trace which is being used, and all other - /// rows in the trace will be filled with zeroes. - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ); - - fn air(&self) -> &Self::Air; -} - pub trait VmAdapterAir: BaseAir { type Interface: VmAdapterInterface; @@ -111,47 +68,6 @@ pub trait VmAdapterAir: BaseAir { fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var; } -/// Trait to be implemented on primitive chip to integrate with the machine. -pub trait VmCoreChip> { - /// Minimum data that must be recorded to be able to generate trace for one row of - /// `PrimitiveAir`. - type Record: Send + Serialize + DeserializeOwned; - /// The primitive AIR with main constraints that do not depend on memory and other - /// architecture-specifics. - type Air: BaseAirWithPublicValues + Clone; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)>; - - fn get_opcode_name(&self, opcode: usize) -> String; - - /// Populates `row_slice` with values corresponding to `record`. - /// The provided `row_slice` will have length equal to `self.air().width()`. - /// This function will be called for each row in the trace which is being used, and all other - /// rows in the trace will be filled with zeroes. - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record); - - /// Returns a list of public values to publish. - fn generate_public_values(&self) -> Vec { - vec![] - } - - fn air(&self) -> &Self::Air; - - /// Finalize the trace, especially the padded rows if the all-zero rows don't satisfy the - /// constraints. This is done **after** records are consumed and the trace matrix is - /// generated. Most implementations should just leave the default implementation if padding - /// with rows of all 0s satisfies the constraints. - fn finalize(&self, _trace: &mut RowMajorMatrix, _num_records: usize) { - // do nothing by default - } -} - pub trait VmCoreAir: BaseAirWithPublicValues where AB: AirBuilder, @@ -183,164 +99,932 @@ where } } -pub struct AdapterRuntimeContext> { +pub struct AdapterAirContext> { /// Leave as `None` to allow the adapter to decide the `to_pc` automatically. - pub to_pc: Option, + pub to_pc: Option, + pub reads: I::Reads, pub writes: I::Writes, + pub instruction: I::ProcessedInstruction, +} + +/// Given some minimum layout of type `Layout`, the `RecordArena` should allocate a buffer, of +/// size possibly larger than the record, and then return mutable pointers to the record within the +/// buffer. +pub trait RecordArena<'a, Layout, RecordMut> { + /// Allocates underlying buffer and returns a mutable reference `RecordMut`. + /// Note that calling this function may not call an underlying memory allocation as the record + /// arena may be virtual. + fn alloc(&'a mut self, layout: Layout) -> RecordMut; +} + +/// Interface for trace generation of a single instruction.The trace is provided as a mutable +/// buffer during both instruction execution and trace generation. +/// It is expected that no additional memory allocation is necessary and the trace buffer +/// is sufficient, with possible overwriting. +pub trait TraceStep { + type RecordLayout; + type RecordMut<'a>; + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>; + + /// Returns a list of public values to publish. + fn generate_public_values(&self) -> Vec { + vec![] + } + + /// Displayable opcode name for logging and debugging purposes. + fn get_opcode_name(&self, opcode: usize) -> String; +} + +// TODO[jpw]: this might be temporary trait before moving trace to CTX +pub trait RowMajorMatrixArena { + /// Set the arena's capacity based on the projected trace height. + fn set_capacity(&mut self, trace_height: usize); + fn with_capacity(height: usize, width: usize) -> Self; + fn width(&self) -> usize; + fn trace_offset(&self) -> usize; + fn into_matrix(self) -> RowMajorMatrix; +} + +// TODO[jpw]: revisit if this trait makes sense +pub trait TraceFiller { + /// Populates `trace`. This function will always be called after + /// [`TraceStep::execute`], so the `trace` should already contain the records necessary to fill + /// in the rest of it. + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) where + Self: Send + Sync, + F: Send + Sync + Clone, + { + let width = trace.width(); + trace.values[..rows_used * width] + .par_chunks_exact_mut(width) + .for_each(|row_slice| { + self.fill_trace_row(mem_helper, row_slice); + }); + trace.values[rows_used * width..] + .par_chunks_exact_mut(width) + .for_each(|row_slice| { + self.fill_dummy_trace_row(mem_helper, row_slice); + }); + } + + /// Populates `row_slice`. This function will always be called after + /// [`TraceStep::execute`], so the `row_slice` should already contain context necessary to + /// fill in the rest of the row. This function will be called for each row in the trace which + /// is being used, and for all other rows in the trace see `fill_dummy_trace_row`. + /// + /// The provided `row_slice` will have length equal to the width of the AIR. + fn fill_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, _row_slice: &mut [F]) { + unreachable!("fill_trace_row is not implemented") + } + + /// Populates `row_slice`. This function will be called on dummy rows. + /// By default the trace is padded with empty (all 0) rows to make the height a power of 2. + /// + /// The provided `row_slice` will have length equal to the width of the AIR. + fn fill_dummy_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, _row_slice: &mut [F]) { + // By default, the row is filled with zeroes + } +} + +/// Converts a field element slice into a record type. +/// This function transmutes the `&mut [F]` to raw bytes, +/// then uses the `CustomBorrow` trait to transmute to the desired record type `T`. +/// ## Safety +/// `slice` must satisfy the requirements of the `CustomBorrow` trait. +pub unsafe fn get_record_from_slice<'a, T, F, L>(slice: &mut &'a mut [F], layout: L) -> T +where + [u8]: CustomBorrow<'a, T, L>, +{ + // The alignment of `[u8]` is always satisfiedƒ + let record_buffer = + &mut *slice_from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, size_of_val::<[F]>(*slice)); + let record: T = record_buffer.custom_borrow(layout); + record +} + +/// Minimal layout information that [RecordArena] requires for record allocation +/// in scenarios involving chips that: +/// - have a single row per record, and +/// - have trace row = [adapter_row, core_row] +/// +/// **NOTE**: `M` is the metadata type that implements `AdapterCoreMetadata` +#[derive(Debug, Clone, Default)] +pub struct AdapterCoreLayout { + pub metadata: M, +} + +/// `Metadata` types need to implement this trait to be used with `AdapterCoreLayout` +/// **NOTE**: get_adapter_width returns the size in bytes +pub trait AdapterCoreMetadata { + fn get_adapter_width() -> usize; +} + +impl AdapterCoreLayout { + pub fn new() -> Self + where + M: Default, + { + Self::default() + } + + pub fn with_metadata(metadata: M) -> Self { + Self { metadata } + } +} + +/// Empty metadata that implements `AdapterCoreMetadata` +/// **NOTE**: `AS` is the adapter type that implements `AdapterTraceStep` +/// **WARNING**: `AS::WIDTH` is the number of field elements, not the size in bytes +pub struct AdapterCoreEmptyMetadata { + _phantom: PhantomData<(F, AS)>, } -impl> AdapterRuntimeContext { - /// Leave `to_pc` as `None` to allow the adapter to decide the `to_pc` automatically. - pub fn without_pc(writes: impl Into) -> Self { +impl Clone for AdapterCoreEmptyMetadata { + fn clone(&self) -> Self { Self { - to_pc: None, - writes: writes.into(), + _phantom: PhantomData, } } } -pub struct AdapterAirContext> { - /// Leave as `None` to allow the adapter to decide the `to_pc` automatically. - pub to_pc: Option, - pub reads: I::Reads, - pub writes: I::Writes, - pub instruction: I::ProcessedInstruction, +impl AdapterCoreEmptyMetadata { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } } -pub struct VmChipWrapper, C: VmCoreChip> { - pub adapter: A, - pub core: C, - pub records: Vec<(A::ReadRecord, A::WriteRecord, C::Record)>, - offline_memory: Arc>>, +impl Default for AdapterCoreEmptyMetadata { + fn default() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl AdapterCoreMetadata for AdapterCoreEmptyMetadata +where + AS: AdapterTraceStep, +{ + #[inline(always)] + fn get_adapter_width() -> usize { + AS::WIDTH * size_of::() + } +} + +/// AdapterCoreLayout with empty metadata that can be used by chips that have record type +/// (&mut A, &mut C) where `A` and `C` are `Sized` +pub type EmptyAdapterCoreLayout = AdapterCoreLayout>; + +/// Minimal layout information that [RecordArena] requires for record allocation +/// in scenarios involving chips that: +/// - can have multiple rows per record, and +/// - have possibly variable length records +/// +/// **NOTE**: `M` is the metadata type that implements `MultiRowMetadata` +#[derive(Debug, Clone, Default, derive_new::new)] +pub struct MultiRowLayout { + pub metadata: M, +} + +/// `Metadata` types need to implement this trait to be used with `MultiRowLayout` +pub trait MultiRowMetadata { + fn get_num_rows(&self) -> usize; +} + +/// Empty metadata that implements `MultiRowMetadata` with `get_num_rows` always returning 1 +#[derive(Debug, Clone, Default, derive_new::new)] +pub struct EmptyMultiRowMetadata {} + +impl MultiRowMetadata for EmptyMultiRowMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + 1 + } +} + +/// Empty metadata that implements `MultiRowMetadata` +pub type EmptyMultiRowLayout = MultiRowLayout; + +/// A trait that allows for custom implementation of `borrow` given the necessary information +/// This is useful for record structs that have dynamic size +pub trait CustomBorrow<'a, T, L> { + fn custom_borrow(&'a mut self, layout: L) -> T; + + /// Given `&self` as a valid starting pointer of a reference that has already been previously + /// allocated and written to, extracts and returns the corresponding layout. + /// This must work even if `T` is not sized. + /// + /// # Safety + /// - `&self` must be a valid starting pointer on which `custom_borrow` has already been called + /// - The data underlying `&self` has already been written to and is self-describing, so layout + /// can be extracted + unsafe fn extract_layout(&self) -> L; +} + +/// If a struct implements `BorrowMut`, then the same implementation can be used for +/// `CustomBorrow::custom_borrow` with any layout +impl<'a, T: Sized, L: Default> CustomBorrow<'a, &'a mut T, L> for [u8] +where + [u8]: BorrowMut, +{ + fn custom_borrow(&'a mut self, _layout: L) -> &'a mut T { + self.borrow_mut() + } + + unsafe fn extract_layout(&self) -> L { + L::default() + } +} + +/// `SizedRecord` is a trait that provides additional information about the size and alignment +/// requirements of a record. Should be implemented on RecordMut types +pub trait SizedRecord { + /// The minimal size in bytes that the RecordMut requires to be properly constructed + /// given the layout. + fn size(layout: &Layout) -> usize; + /// The minimal alignment required for the RecordMut to be properly constructed + /// given the layout. + fn alignment(layout: &Layout) -> usize; +} + +impl SizedRecord for &mut Record +where + Record: Sized, +{ + fn size(_layout: &Layout) -> usize { + size_of::() + } + + fn alignment(_layout: &Layout) -> usize { + align_of::() + } +} + +// TEMP[jpw]: buffer should be inside CTX +pub struct MatrixRecordArena { + pub trace_buffer: Vec, + pub width: usize, + pub trace_offset: usize, +} + +impl MatrixRecordArena { + pub fn alloc_single_row(&mut self) -> &mut [u8] { + self.alloc_buffer(1) + } + + pub fn alloc_buffer(&mut self, num_rows: usize) -> &mut [u8] { + let start = self.trace_offset; + self.trace_offset += num_rows * self.width; + let row_slice = &mut self.trace_buffer[start..self.trace_offset]; + let size = size_of_val(row_slice); + let ptr = row_slice as *mut [F] as *mut u8; + // SAFETY: + // - `ptr` is non-null + // - `size` is correct + // - alignment of `u8` is always satisfied + unsafe { &mut *std::ptr::slice_from_raw_parts_mut(ptr, size) } + } +} + +impl RowMajorMatrixArena for MatrixRecordArena { + fn set_capacity(&mut self, trace_height: usize) { + let size = trace_height * self.width; + // PERF: use memset + self.trace_buffer.resize(size, F::ZERO); + } + + fn with_capacity(height: usize, width: usize) -> Self { + let trace_buffer = F::zero_vec(height * width); + Self { + trace_buffer, + width, + trace_offset: 0, + } + } + + fn width(&self) -> usize { + self.width + } + + fn trace_offset(&self) -> usize { + self.trace_offset + } + + fn into_matrix(self) -> RowMajorMatrix { + RowMajorMatrix::new(self.trace_buffer, self.width) + } +} + +/// [RecordArena] implementation for [MatrixRecordArena], with [AdapterCoreLayout] +/// **NOTE**: `A` is the adapter RecordMut type and `C` is the core RecordMut type +impl<'a, F: Field, A, C, M: AdapterCoreMetadata> RecordArena<'a, AdapterCoreLayout, (A, C)> + for MatrixRecordArena +where + [u8]: CustomBorrow<'a, A, AdapterCoreLayout> + CustomBorrow<'a, C, AdapterCoreLayout>, + M: Clone, +{ + fn alloc(&'a mut self, layout: AdapterCoreLayout) -> (A, C) { + let adapter_width = M::get_adapter_width(); + let buffer = self.alloc_single_row(); + // Doing a unchecked split here for perf + let (adapter_buffer, core_buffer) = unsafe { buffer.split_at_mut_unchecked(adapter_width) }; + + let adapter_record: A = adapter_buffer.custom_borrow(layout.clone()); + let core_record: C = core_buffer.custom_borrow(layout); + + (adapter_record, core_record) + } +} + +/// [RecordArena] implementation for [MatrixRecordArena], with [MultiRowLayout] +/// **NOTE**: `R` is the RecordMut type +impl<'a, F: Field, M: MultiRowMetadata, R> RecordArena<'a, MultiRowLayout, R> + for MatrixRecordArena +where + [u8]: CustomBorrow<'a, R, MultiRowLayout>, +{ + fn alloc(&'a mut self, layout: MultiRowLayout) -> R { + let buffer = self.alloc_buffer(layout.metadata.get_num_rows()); + let record: R = buffer.custom_borrow(layout); + record + } +} + +pub struct DenseRecordArena { + pub records_buffer: Cursor>, +} + +const MAX_ALIGNMENT: usize = 32; + +impl DenseRecordArena { + /// Creates a new [DenseRecordArena] with the given capacity in bytes. + pub fn with_capacity(size_bytes: usize) -> Self { + let buffer = vec![0; size_bytes + MAX_ALIGNMENT]; + let offset = (MAX_ALIGNMENT - (buffer.as_ptr() as usize % MAX_ALIGNMENT)) % MAX_ALIGNMENT; + let mut cursor = Cursor::new(buffer); + cursor.set_position(offset as u64); + Self { + records_buffer: cursor, + } + } + + pub fn set_capacity(&mut self, size_bytes: usize) { + let buffer = vec![0; size_bytes + MAX_ALIGNMENT]; + let offset = (MAX_ALIGNMENT - (buffer.as_ptr() as usize % MAX_ALIGNMENT)) % MAX_ALIGNMENT; + let mut cursor = Cursor::new(buffer); + cursor.set_position(offset as u64); + self.records_buffer = cursor; + } + + /// Returns the current size of the allocated buffer so far. + pub fn current_size(&self) -> usize { + self.records_buffer.position() as usize + } + + /// Allocates `count` bytes and returns as a mutable slice. + pub fn alloc_bytes<'a>(&mut self, count: usize) -> &'a mut [u8] { + let begin = self.records_buffer.position(); + debug_assert!( + begin as usize + count <= self.records_buffer.get_ref().len(), + "failed to allocate {count} bytes from {begin} when the capacity is {}", + self.records_buffer.get_ref().len() + ); + self.records_buffer.set_position(begin + count as u64); + unsafe { + std::slice::from_raw_parts_mut( + self.records_buffer + .get_mut() + .as_mut_ptr() + .add(begin as usize), + count, + ) + } + } + + pub fn allocated(&self) -> &[u8] { + let size = self.records_buffer.position() as usize; + let offset = (MAX_ALIGNMENT + - (self.records_buffer.get_ref().as_ptr() as usize % MAX_ALIGNMENT)) + % MAX_ALIGNMENT; + &self.records_buffer.get_ref()[offset..size] + } + + pub fn allocated_mut(&mut self) -> &mut [u8] { + let size = self.records_buffer.position() as usize; + let offset = (MAX_ALIGNMENT + - (self.records_buffer.get_ref().as_ptr() as usize % MAX_ALIGNMENT)) + % MAX_ALIGNMENT; + &mut self.records_buffer.get_mut()[offset..size] + } + + pub fn align_to(&mut self, alignment: usize) { + debug_assert!(MAX_ALIGNMENT % alignment == 0); + let offset = + (alignment - (self.records_buffer.get_ref().as_ptr() as usize % alignment)) % alignment; + self.records_buffer.set_position(offset as u64); + } + + // Returns a [RecordSeeker] on the allocated buffer + pub fn get_record_seeker(&mut self) -> RecordSeeker { + RecordSeeker::new(self.allocated_mut()) + } +} + +/// [RecordArena] implementation for [DenseRecordArena], with [AdapterCoreLayout] +/// **NOTE**: `A` is the adapter RecordMut type and `C` is the core record type +impl<'a, A, C, M> RecordArena<'a, AdapterCoreLayout, (A, C)> for DenseRecordArena +where + [u8]: CustomBorrow<'a, A, AdapterCoreLayout> + CustomBorrow<'a, C, AdapterCoreLayout>, + M: Clone, + A: SizedRecord>, + C: SizedRecord>, +{ + fn alloc(&'a mut self, layout: AdapterCoreLayout) -> (A, C) { + let adapter_alignment = A::alignment(&layout); + let core_alignment = C::alignment(&layout); + let adapter_size = A::size(&layout); + let aligned_adapter_size = adapter_size.next_multiple_of(core_alignment); + let core_size = C::size(&layout); + let aligned_core_size = (aligned_adapter_size + core_size) + .next_multiple_of(adapter_alignment) + - aligned_adapter_size; + debug_assert_eq!(MAX_ALIGNMENT % adapter_alignment, 0); + debug_assert_eq!(MAX_ALIGNMENT % core_alignment, 0); + let buffer = self.alloc_bytes(aligned_adapter_size + aligned_core_size); + // Doing an unchecked split here for perf + let (adapter_buffer, core_buffer) = + unsafe { buffer.split_at_mut_unchecked(aligned_adapter_size) }; + + let adapter_record: A = adapter_buffer.custom_borrow(layout.clone()); + let core_record: C = core_buffer.custom_borrow(layout); + + (adapter_record, core_record) + } +} + +/// [RecordArena] implementation for [DenseRecordArena], with [MultiRowLayout] +/// **NOTE**: `R` is the RecordMut type +impl<'a, R, M> RecordArena<'a, MultiRowLayout, R> for DenseRecordArena +where + [u8]: CustomBorrow<'a, R, MultiRowLayout>, + R: SizedRecord>, +{ + fn alloc(&'a mut self, layout: MultiRowLayout) -> R { + let record_size = R::size(&layout); + let record_alignment = R::alignment(&layout); + let aligned_record_size = record_size.next_multiple_of(record_alignment); + let buffer = self.alloc_bytes(aligned_record_size); + let record: R = buffer.custom_borrow(layout); + record + } +} + +// This is a helper struct that implements a few utility methods +pub struct RecordSeeker<'a, RA, RecordMut, Layout> { + pub buffer: &'a mut [u8], // The buffer that the records are written to + _phantom: PhantomData<(RA, RecordMut, Layout)>, +} + +impl<'a, RA, RecordMut, Layout> RecordSeeker<'a, RA, RecordMut, Layout> { + pub fn new(record_buffer: &'a mut [u8]) -> Self { + Self { + buffer: record_buffer, + _phantom: PhantomData, + } + } +} + +// `RecordSeeker` implementation for [DenseRecordArena], with [MultiRowLayout] +// **NOTE** Assumes that `layout` can be extracted from the record alone +impl<'a, R, M> RecordSeeker<'a, DenseRecordArena, R, MultiRowLayout> +where + [u8]: CustomBorrow<'a, R, MultiRowLayout>, + R: SizedRecord>, + M: MultiRowMetadata + Clone, +{ + // Returns the layout at the given offset in the buffer + // **SAFETY**: `offset` has to be a valid offset, pointing to the start of a record + pub fn get_layout_at(offset: &mut usize, buffer: &[u8]) -> MultiRowLayout { + let buffer = &buffer[*offset..]; + unsafe { buffer.extract_layout() } + } + + // Returns a record at the given offset in the buffer + // **SAFETY**: `offset` has to be a valid offset, pointing to the start of a record + pub fn get_record_at(offset: &mut usize, buffer: &'a mut [u8]) -> R { + let layout = Self::get_layout_at(offset, buffer); + let buffer = &mut buffer[*offset..]; + let record_size = R::size(&layout); + let record_alignment = R::alignment(&layout); + let aligned_record_size = record_size.next_multiple_of(record_alignment); + let record: R = buffer.custom_borrow(layout); + *offset += aligned_record_size; + record + } + + // Returns a vector of all the records in the buffer + pub fn extract_records(&'a mut self) -> Vec { + let mut records = Vec::new(); + let len = self.buffer.len(); + let buff = &mut self.buffer[..]; + let mut offset = 0; + while offset < len { + let record: R = { + let buff = unsafe { &mut *slice_from_raw_parts_mut(buff.as_mut_ptr(), len) }; + Self::get_record_at(&mut offset, buff) + }; + records.push(record); + } + records + } + + // Transfers the records in the buffer to a [MatrixRecordArena], used in testing + pub fn transfer_to_matrix_arena( + &'a mut self, + arena: &mut MatrixRecordArena, + ) { + let len = self.buffer.len(); + arena.trace_offset = 0; + let mut offset = 0; + while offset < len { + let layout = Self::get_layout_at(&mut offset, self.buffer); + let record_size = R::size(&layout); + let record_alignment = R::alignment(&layout); + let aligned_record_size = record_size.next_multiple_of(record_alignment); + let src_ptr = unsafe { self.buffer.as_ptr().add(offset) }; + let dst_ptr = arena + .alloc_buffer(layout.metadata.get_num_rows()) + .as_mut_ptr(); + unsafe { copy_nonoverlapping(src_ptr, dst_ptr, aligned_record_size) }; + offset += aligned_record_size; + } + } +} + +// `RecordSeeker` implementation for [DenseRecordArena], with [AdapterCoreLayout] +// **NOTE** Assumes that `layout` is the same for all the records, so it is expected to be passed as +// a parameter +impl<'a, A, C, M> RecordSeeker<'a, DenseRecordArena, (A, C), AdapterCoreLayout> +where + [u8]: CustomBorrow<'a, A, AdapterCoreLayout> + CustomBorrow<'a, C, AdapterCoreLayout>, + A: SizedRecord>, + C: SizedRecord>, + M: AdapterCoreMetadata + Clone, +{ + // Returns the aligned sizes of the adapter and core records given their layout + pub fn get_aligned_sizes(layout: &AdapterCoreLayout) -> (usize, usize) { + let adapter_alignment = A::alignment(layout); + let core_alignment = C::alignment(layout); + let adapter_size = A::size(layout); + let aligned_adapter_size = adapter_size.next_multiple_of(core_alignment); + let core_size = C::size(layout); + let aligned_core_size = (aligned_adapter_size + core_size) + .next_multiple_of(adapter_alignment) + - aligned_adapter_size; + (aligned_adapter_size, aligned_core_size) + } + + // Returns the aligned size of a single record given its layout + pub fn get_aligned_record_size(layout: &AdapterCoreLayout) -> usize { + let (adapter_size, core_size) = Self::get_aligned_sizes(layout); + adapter_size + core_size + } + + // Returns a record at the given offset in the buffer + // **SAFETY**: `offset` has to be a valid offset, pointing to the start of a record + pub fn get_record_at( + offset: &mut usize, + buffer: &'a mut [u8], + layout: AdapterCoreLayout, + ) -> (A, C) { + let buffer = &mut buffer[*offset..]; + let (adapter_size, core_size) = Self::get_aligned_sizes(&layout); + let (adapter_buffer, core_buffer) = unsafe { buffer.split_at_mut_unchecked(adapter_size) }; + let adapter_record: A = adapter_buffer.custom_borrow(layout.clone()); + let core_record: C = core_buffer.custom_borrow(layout); + *offset += adapter_size + core_size; + (adapter_record, core_record) + } + + // Returns a vector of all the records in the buffer + pub fn extract_records(&'a mut self, layout: AdapterCoreLayout) -> Vec<(A, C)> { + let mut records = Vec::new(); + let len = self.buffer.len(); + let buff = &mut self.buffer[..]; + let mut offset = 0; + while offset < len { + let record: (A, C) = { + let buff = unsafe { &mut *slice_from_raw_parts_mut(buff.as_mut_ptr(), len) }; + Self::get_record_at(&mut offset, buff, layout.clone()) + }; + records.push(record); + } + records + } + + // Transfers the records in the buffer to a [MatrixRecordArena], used in testing + pub fn transfer_to_matrix_arena( + &'a mut self, + arena: &mut MatrixRecordArena, + layout: AdapterCoreLayout, + ) { + let len = self.buffer.len(); + arena.trace_offset = 0; + let mut offset = 0; + let (adapter_size, core_size) = Self::get_aligned_sizes(&layout); + while offset < len { + let dst_buffer = arena.alloc_single_row(); + let (adapter_buf, core_buf) = + unsafe { dst_buffer.split_at_mut_unchecked(M::get_adapter_width()) }; + unsafe { + let src_ptr = self.buffer.as_ptr().add(offset); + copy_nonoverlapping(src_ptr, adapter_buf.as_mut_ptr(), adapter_size); + copy_nonoverlapping(src_ptr.add(adapter_size), core_buf.as_mut_ptr(), core_size); + } + offset += adapter_size + core_size; + } + } } -const DEFAULT_RECORDS_CAPACITY: usize = 1 << 20; +pub struct NewVmChipWrapper { + pub air: AIR, + pub step: STEP, + pub arena: RA, + mem_helper: SharedMemoryHelper, +} -impl VmChipWrapper +// TODO(AG): more general RA +impl NewVmChipWrapper> where - A: VmAdapterChip, - C: VmCoreChip, + F: Field, + AIR: BaseAir, { - pub fn new(adapter: A, core: C, offline_memory: Arc>>) -> Self { + pub fn new(air: AIR, step: STEP, mem_helper: SharedMemoryHelper) -> Self { + let width = air.width(); + assert!( + align_of::() >= align_of::(), + "type {} should have at least alignment of u32", + type_name::() + ); + let arena = MatrixRecordArena::with_capacity(0, width); Self { - adapter, - core, - records: Vec::with_capacity(DEFAULT_RECORDS_CAPACITY), - offline_memory, + air, + step, + arena, + mem_helper, } } + + pub fn set_trace_buffer_height(&mut self, height: usize) { + self.arena.set_capacity(height); + } } -impl InstructionExecutor for VmChipWrapper +// TODO(AG): more general RA +impl NewVmChipWrapper +where + F: Field, + AIR: BaseAir, +{ + pub fn new(air: AIR, step: STEP, mem_helper: SharedMemoryHelper) -> Self { + assert!( + align_of::() >= align_of::(), + "type {} should have at least alignment of u32", + type_name::() + ); + let arena = DenseRecordArena::with_capacity(0); + Self { + air, + step, + arena, + mem_helper, + } + } + + pub fn set_trace_buffer_height(&mut self, height: usize) { + let width = self.air.width(); + self.arena.set_capacity(height * width * size_of::()); + } +} + +impl InstructionExecutor for NewVmChipWrapper where F: PrimeField32, - A: VmAdapterChip + Send + Sync, - M: VmCoreChip + Send + Sync, + STEP: TraceStep // TODO: CTX? + + StepExecutorE1, + for<'buf> RA: RecordArena<'buf, STEP::RecordLayout, STEP::RecordMut<'buf>>, { fn execute( &mut self, memory: &mut MemoryController, + streams: &mut Streams, + rng: &mut StdRng, instruction: &Instruction, from_state: ExecutionState, ) -> Result> { - let (reads, read_record) = self.adapter.preprocess(memory, instruction)?; - let (output, core_record) = - self.core - .execute_instruction(instruction, from_state.pc, reads)?; - let (to_state, write_record) = - self.adapter - .postprocess(memory, instruction, from_state, output, &read_record)?; - self.records.push((read_record, write_record, core_record)); - Ok(to_state) + let mut pc = from_state.pc; + let state = VmStateMut { + pc: &mut pc, + memory: &mut memory.memory, + streams, + rng, + ctx: &mut (), + }; + self.step.execute(state, instruction, &mut self.arena)?; + + Ok(ExecutionState { + pc, + timestamp: memory.memory.timestamp, + }) } fn get_opcode_name(&self, opcode: usize) -> String { - self.core.get_opcode_name(opcode) + self.step.get_opcode_name(opcode) } } // Note[jpw]: the statement we want is: -// - when A::Air is an AdapterAir for all AirBuilders needed by stark-backend -// - and when M::Air is an CoreAir for all AirBuilders needed by stark-backend, -// then VmAirWrapper is an Air for all AirBuilders needed -// by stark-backend, which is equivalent to saying it implements AirRef +// - `Air` is an `Air` for all `AB: AirBuilder`s needed by stark-backend +// which is equivalent to saying it implements AirRef // The where clauses to achieve this statement is unfortunately really verbose. -impl Chip for VmChipWrapper, A, C> +impl Chip for NewVmChipWrapper, AIR, STEP, RA> where SC: StarkGenericConfig, Val: PrimeField32, - A: VmAdapterChip> + Send + Sync, - C: VmCoreChip, A::Interface> + Send + Sync, - A::Air: Send + Sync + 'static, - A::Air: VmAdapterAir>>, - A::Air: for<'a> VmAdapterAir>, - C::Air: Send + Sync + 'static, - C::Air: VmCoreAir< - SymbolicRapBuilder>, - >>>::Interface, - >, - C::Air: for<'a> VmCoreAir< - DebugConstraintBuilder<'a, SC>, - >>::Interface, - >, + STEP: TraceStep, ()> + TraceFiller, ()> + Send + Sync, + AIR: Clone + AnyRap + 'static, + RA: RowMajorMatrixArena>, { fn air(&self) -> AirRef { - let air: VmAirWrapper = VmAirWrapper { - adapter: self.adapter.air().clone(), - core: self.core.air().clone(), - }; - Arc::new(air) + Arc::new(self.air.clone()) } fn generate_air_proof_input(self) -> AirProofInput { - let num_records = self.records.len(); - let height = next_power_of_two_or_zero(num_records); - let core_width = self.core.air().width(); - let adapter_width = self.adapter.air().width(); - let width = core_width + adapter_width; - let mut values = Val::::zero_vec(height * width); - - let memory = self.offline_memory.lock().unwrap(); - - // This zip only goes through records. - // The padding rows between records.len()..height are filled with zeros. - values - .par_chunks_mut(width) - .zip(self.records.into_par_iter()) - .for_each(|(row_slice, record)| { - let (adapter_row, core_row) = row_slice.split_at_mut(adapter_width); - self.adapter - .generate_trace_row(adapter_row, record.0, record.1, &memory); - self.core.generate_trace_row(core_row, record.2); - }); - - let mut trace = RowMajorMatrix::new(values, width); - self.core.finalize(&mut trace, num_records); - - AirProofInput::simple(trace, self.core.generate_public_values()) + let width = self.arena.width(); + assert_eq!(self.arena.trace_offset() % width, 0); + let rows_used = self.arena.trace_offset() / width; + let height = next_power_of_two_or_zero(rows_used); + let mut trace = self.arena.into_matrix(); + // This should be automatic since trace_buffer's height is a power of two: + assert!(height.checked_mul(width).unwrap() <= trace.values.len()); + trace.values.truncate(height * width); + let mem_helper = self.mem_helper.as_borrowed(); + self.step.fill_trace(&mem_helper, &mut trace, rows_used); + drop(self.mem_helper); + + AirProofInput::simple(trace, self.step.generate_public_values()) } } -impl ChipUsageGetter for VmChipWrapper +impl ChipUsageGetter for NewVmChipWrapper where - A: VmAdapterChip + Sync, - M: VmCoreChip + Sync, + C: Sync, + RA: RowMajorMatrixArena, { fn air_name(&self) -> String { - format!( - "<{},{}>", - get_air_name(self.adapter.air()), - get_air_name(self.core.air()) - ) + get_air_name(&self.air) } fn current_trace_height(&self) -> usize { - self.records.len() + self.arena.trace_offset() / self.arena.width() + } + fn trace_width(&self) -> usize { + self.arena.width() + } +} + +/// A helper trait for expressing generic state accesses within the implementation of +/// [TraceStep]. Note that this is only a helper trait when the same interface of state access +/// is reused or shared by multiple implementations. It is not required to implement this trait if +/// it is easier to implement the [TraceStep] trait directly without this trait. +pub trait AdapterTraceStep { + const WIDTH: usize; + type ReadData; + type WriteData; + // @dev This can either be a &mut _ type or a struct with &mut _ fields. + // The latter is helpful if we want to directly write certain values in place into a trace + // matrix. + type RecordMut<'a> + where + Self: 'a; + + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>); + + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData; + + fn write( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, + ); +} + +// NOTE[jpw]: cannot reuse `TraceSubRowGenerator` trait because we need associated constant +// `WIDTH`. +pub trait AdapterTraceFiller: AdapterTraceStep { + /// Post-execution filling of rest of adapter row. + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, adapter_row: &mut [F]); +} + +// TODO: Rename core/step to operator +pub trait StepExecutorE1 { + fn pre_compute_size(&self) -> usize; + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx; +} + +pub trait StepExecutorE2 { + fn e2_pre_compute_size(&self) -> usize; + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx; +} + +impl InsExecutorE1 for NewVmChipWrapper> +where + F: PrimeField32, + S: StepExecutorE1, + A: BaseAir, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + self.step.pre_compute_size() + } + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + self.step.pre_compute_e1(pc, inst, data) + } + + fn set_trace_height(&mut self, height: usize) { + self.set_trace_buffer_height(height); + } +} + +impl InsExecutorE2 for NewVmChipWrapper> +where + F: PrimeField32, + S: StepExecutorE2, +{ + fn e2_pre_compute_size(&self) -> usize { + self.step.e2_pre_compute_size() } - fn trace_width(&self) -> usize { - self.adapter.air().width() + self.core.air().width() + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + self.step.pre_compute_e2(chip_idx, pc, inst, data) } } +#[derive(Clone, Copy, derive_new::new)] pub struct VmAirWrapper { pub adapter: A, pub core: C, @@ -608,49 +1292,6 @@ mod conversions { } } - // AdapterRuntimeContext: VecHeapAdapterInterface -> DynInterface - impl< - T, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - >, - > for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - >, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: DynInterface -> VecHeapAdapterInterface impl< T, @@ -682,35 +1323,6 @@ mod conversions { } } - // AdapterRuntimeContext: DynInterface -> VecHeapAdapterInterface - impl< - T, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: DynInterface -> VecHeapTwoReadsAdapterInterface impl< T: Clone, @@ -742,95 +1354,6 @@ mod conversions { } } - // AdapterRuntimeContext: DynInterface -> VecHeapAdapterInterface - impl< - T, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - VecHeapTwoReadsAdapterInterface< - T, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - - // AdapterRuntimeContext: BasicInterface -> VecHeapAdapterInterface - impl< - T, - PI, - const BASIC_NUM_READS: usize, - const BASIC_NUM_WRITES: usize, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface< - T, - PI, - BASIC_NUM_READS, - BASIC_NUM_WRITES, - READ_SIZE, - WRITE_SIZE, - >, - >, - > - for AdapterRuntimeContext< - T, - VecHeapAdapterInterface< - T, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - > - { - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface< - T, - PI, - BASIC_NUM_READS, - BASIC_NUM_WRITES, - READ_SIZE, - WRITE_SIZE, - >, - >, - ) -> Self { - assert_eq!(BASIC_NUM_WRITES, BLOCKS_PER_WRITE); - let mut writes_it = ctx.writes.into_iter(); - let writes = from_fn(|_| writes_it.next().unwrap()); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - // AdapterAirContext: BasicInterface -> VecHeapAdapterInterface impl< T, @@ -985,79 +1508,6 @@ mod conversions { } } - // AdapterRuntimeContext: BasicInterface -> FlatInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - const READ_CELLS: usize, - const WRITE_CELLS: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - > for AdapterRuntimeContext> - { - /// ## Panics - /// If `WRITE_CELLS != NUM_WRITES * WRITE_SIZE`. - /// This is a runtime assertion until Rust const generics expressions are stabilized. - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - ) -> AdapterRuntimeContext> { - assert_eq!(WRITE_CELLS, NUM_WRITES * WRITE_SIZE); - let mut writes_it = ctx.writes.into_iter().flatten(); - let writes = from_fn(|_| writes_it.next().unwrap()); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - - // AdapterRuntimeContext: FlatInterface -> BasicInterface - impl< - T: FieldAlgebra, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - const READ_CELLS: usize, - const WRITE_CELLS: usize, - > From>> - for AdapterRuntimeContext< - T, - BasicAdapterInterface, - > - { - /// ## Panics - /// If `WRITE_CELLS != NUM_WRITES * WRITE_SIZE`. - /// This is a runtime assertion until Rust const generics expressions are stabilized. - fn from( - ctx: AdapterRuntimeContext>, - ) -> AdapterRuntimeContext< - T, - BasicAdapterInterface, - > { - assert_eq!(WRITE_CELLS, NUM_WRITES * WRITE_SIZE); - let mut writes_it = ctx.writes.into_iter(); - let writes: [[T; WRITE_SIZE]; NUM_WRITES] = - from_fn(|_| from_fn(|_| writes_it.next().unwrap())); - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes, - } - } - } - impl From> for DynArray { fn from(v: Vec) -> Self { Self(v) @@ -1169,35 +1619,6 @@ mod conversions { } } - // AdapterRuntimeContext: BasicInterface -> DynInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - From< - AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - > for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext< - T, - BasicAdapterInterface, - >, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: DynInterface -> BasicInterface impl< T, @@ -1224,28 +1645,6 @@ mod conversions { } } - // AdapterRuntimeContext: DynInterface -> BasicInterface - impl< - T, - PI, - const NUM_READS: usize, - const NUM_WRITES: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > From>> - for AdapterRuntimeContext< - T, - BasicAdapterInterface, - > - { - fn from(ctx: AdapterRuntimeContext>) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.into(), - } - } - } - // AdapterAirContext: FlatInterface -> DynInterface impl>, const READ_CELLS: usize, const WRITE_CELLS: usize> From>> @@ -1261,21 +1660,6 @@ mod conversions { } } - // AdapterRuntimeContext: FlatInterface -> DynInterface - impl - From>> - for AdapterRuntimeContext> - { - fn from( - ctx: AdapterRuntimeContext>, - ) -> Self { - AdapterRuntimeContext { - to_pc: ctx.to_pc, - writes: ctx.writes.to_vec().into(), - } - } - } - impl From> for DynArray { fn from(m: MinimalInstruction) -> Self { Self(vec![m.is_valid, m.opcode]) diff --git a/crates/vm/src/arch/interpreter.rs b/crates/vm/src/arch/interpreter.rs new file mode 100644 index 0000000000..e9a16a7db0 --- /dev/null +++ b/crates/vm/src/arch/interpreter.rs @@ -0,0 +1,477 @@ +use std::{ + alloc::{alloc, dealloc, handle_alloc_error, Layout}, + borrow::{Borrow, BorrowMut}, + ptr::NonNull, +}; + +use itertools::Itertools; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + exe::VmExe, instruction::Instruction, program::Program, LocalOpcode, SystemOpcode, +}; +use openvm_stark_backend::p3_field::{Field, PrimeField32}; +use rand::{rngs::StdRng, SeedableRng}; +use tracing::info_span; + +use crate::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + ExecuteFunc, + ExecutionError::{self, InvalidInstruction}, + ExitCode, InsExecutorE1, InsExecutorE2, PreComputeInstruction, Streams, VmChipComplex, + VmConfig, VmSegmentState, + }, + system::memory::{online::GuestMemory, AddressMap}, +}; + +/// VM pure executor(E1/E2 executor) which doesn't consider trace generation. +/// Note: This executor doesn't hold any VM state and can be used for multiple execution. +pub struct InterpretedInstance> { + exe: VmExe, + vm_config: VC, + e1_pre_compute_max_size: usize, + e2_pre_compute_max_size: usize, +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct TerminatePreCompute { + exit_code: u32, +} + +macro_rules! execute_with_metrics { + ($span:literal, $program:expr, $vm_state:expr, $pre_compute_insts:expr) => {{ + #[cfg(feature = "bench-metrics")] + let start = std::time::Instant::now(); + #[cfg(feature = "bench-metrics")] + let start_instret = $vm_state.instret; + + info_span!($span).in_scope(|| unsafe { + execute_impl($program, $vm_state, $pre_compute_insts); + }); + + #[cfg(feature = "bench-metrics")] + { + let elapsed = start.elapsed(); + let insns = $vm_state.instret - start_instret; + metrics::counter!("insns").absolute(insns); + metrics::gauge!(concat!($span, "_insn_mi/s")) + .set(insns as f64 / elapsed.as_micros() as f64); + } + }}; +} + +impl> InterpretedInstance { + pub fn new(vm_config: VC, exe: impl Into>) -> Self { + let exe = exe.into(); + let program = &exe.program; + let chip_complex = vm_config.create_chip_complex().unwrap(); + let e1_pre_compute_max_size = get_pre_compute_max_size(program, &chip_complex); + let e2_pre_compute_max_size = get_e2_pre_compute_max_size(program, &chip_complex); + Self { + exe, + vm_config, + e1_pre_compute_max_size, + e2_pre_compute_max_size, + } + } + + /// Execute the VM program with the given execution control and inputs. Returns the final VM + /// state with the `ExecutionControl` context. + pub fn execute( + &self, + ctx: Ctx, + inputs: impl Into>, + ) -> Result, ExecutionError> { + // Initialize the chip complex + let chip_complex = self.vm_config.create_chip_complex().unwrap(); + let mut vm_state = self.init_vm_state(ctx, inputs); + + // Start execution + let program = &self.exe.program; + let pre_compute_max_size = self.e1_pre_compute_max_size; + let mut pre_compute_buf = self.alloc_pre_compute_buf(pre_compute_max_size); + let mut split_pre_compute_buf = + self.split_pre_compute_buf(&mut pre_compute_buf, pre_compute_max_size); + + let pre_compute_insts = get_pre_compute_instructions::<_, _, _, Ctx>( + program, + &chip_complex, + &mut split_pre_compute_buf, + )?; + execute_with_metrics!("execute_e1", program, &mut vm_state, &pre_compute_insts); + if vm_state.exit_code.is_err() { + Err(vm_state.exit_code.err().unwrap()) + } else { + check_exit_code(&vm_state)?; + Ok(vm_state) + } + } + + /// Execute the VM program with the given execution control and inputs. Returns the final VM + /// state with the `ExecutionControl` context. + pub fn execute_e2( + &self, + ctx: Ctx, + inputs: impl Into>, + ) -> Result, ExecutionError> { + // Initialize the chip complex + let chip_complex = self.vm_config.create_chip_complex().unwrap(); + let mut vm_state = self.init_vm_state(ctx, inputs); + + // Start execution + let program = &self.exe.program; + let pre_compute_max_size = self.e2_pre_compute_max_size; + let mut pre_compute_buf = self.alloc_pre_compute_buf(pre_compute_max_size); + let mut split_pre_compute_buf = + self.split_pre_compute_buf(&mut pre_compute_buf, pre_compute_max_size); + + let pre_compute_insts = get_e2_pre_compute_instructions::<_, _, _, Ctx>( + program, + &chip_complex, + &mut split_pre_compute_buf, + )?; + execute_with_metrics!( + "execute_metered", + program, + &mut vm_state, + &pre_compute_insts + ); + if vm_state.exit_code.is_err() { + Err(vm_state.exit_code.err().unwrap()) + } else { + check_exit_code(&vm_state)?; + Ok(vm_state) + } + } + + pub fn init_vm_state( + &self, + ctx: Ctx, + inputs: impl Into>, + ) -> VmSegmentState { + let memory = if self.vm_config.system().continuation_enabled { + let mem_config = self.vm_config.system().memory_config.clone(); + Some(GuestMemory::new(AddressMap::from_sparse( + mem_config.addr_space_sizes.clone(), + self.exe.init_memory.clone(), + ))) + } else { + None + }; + + VmSegmentState::new( + 0, + self.exe.pc_start, + memory, + inputs.into(), + StdRng::seed_from_u64(0), + ctx, + ) + } + + #[inline(always)] + fn alloc_pre_compute_buf(&self, pre_compute_max_size: usize) -> AlignedBuf { + let program = &self.exe.program; + let program_len = program.instructions_and_debug_infos.len(); + let buf_len = program_len * pre_compute_max_size; + AlignedBuf::uninit(buf_len, pre_compute_max_size) + } + + #[inline(always)] + fn split_pre_compute_buf<'a>( + &self, + pre_compute_buf: &'a mut AlignedBuf, + pre_compute_max_size: usize, + ) -> Vec<&'a mut [u8]> { + let program = &self.exe.program; + let program_len = program.instructions_and_debug_infos.len(); + let buf_len = program_len * pre_compute_max_size; + let mut pre_compute_buf_ptr = + unsafe { std::slice::from_raw_parts_mut(pre_compute_buf.ptr, buf_len) }; + let mut split_pre_compute_buf = Vec::with_capacity(program_len); + for _ in 0..program_len { + let (first, last) = pre_compute_buf_ptr.split_at_mut(pre_compute_max_size); + pre_compute_buf_ptr = last; + split_pre_compute_buf.push(first); + } + split_pre_compute_buf + } +} + +#[inline(never)] +unsafe fn execute_impl( + program: &Program, + vm_state: &mut VmSegmentState, + fn_ptrs: &[PreComputeInstruction], +) { + // let start = Instant::now(); + while vm_state + .exit_code + .as_ref() + .is_ok_and(|exit_code| exit_code.is_none()) + { + if Ctx::should_suspend(vm_state) { + break; + } + let pc_index = get_pc_index(program, vm_state.pc).unwrap(); + let inst = &fn_ptrs[pc_index]; + unsafe { (inst.handler)(inst.pre_compute, vm_state) }; + } + if vm_state + .exit_code + .as_ref() + .is_ok_and(|exit_code| exit_code.is_some()) + { + Ctx::on_terminate(vm_state); + } + // println!("execute time: {}ms", start.elapsed().as_millis()); +} + +fn get_pc_index(program: &Program, pc: u32) -> Result { + let step = program.step; + let pc_base = program.pc_base; + let pc_index = ((pc - pc_base) / step) as usize; + if !(0..program.len()).contains(&pc_index) { + return Err(ExecutionError::PcOutOfBounds { + pc, + step, + pc_base, + program_len: program.len(), + }); + } + Ok(pc_index) +} + +/// Bytes allocated according to the given Layout +pub struct AlignedBuf { + pub ptr: *mut u8, + pub layout: Layout, +} + +impl AlignedBuf { + /// Allocate a new buffer whose start address is aligned to `align` bytes. + /// *NOTE* if `len` is zero then a creates new `NonNull` that is dangling and 16-byte aligned. + pub fn uninit(len: usize, align: usize) -> Self { + let layout = Layout::from_size_align(len, align).unwrap(); + if layout.size() == 0 { + return Self { + ptr: NonNull::::dangling().as_ptr() as *mut u8, + layout, + }; + } + // SAFETY: `len` is nonzero + let ptr = unsafe { alloc(layout) }; + if ptr.is_null() { + handle_alloc_error(layout); + } + AlignedBuf { ptr, layout } + } +} + +impl Drop for AlignedBuf { + fn drop(&mut self) { + if self.layout.size() != 0 { + unsafe { + dealloc(self.ptr, self.layout); + } + } + } +} + +unsafe fn terminate_execute_e12_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &TerminatePreCompute = pre_compute.borrow(); + vm_state.instret += 1; + vm_state.exit_code = Ok(Some(pre_compute.exit_code)); +} + +fn get_pre_compute_max_size, P>( + program: &Program, + chip_complex: &VmChipComplex, +) -> usize { + program + .instructions_and_debug_infos + .iter() + .map(|inst_opt| { + if let Some((inst, _)) = inst_opt { + if let Some(size) = system_opcode_pre_compute_size(inst) { + size + } else { + chip_complex + .inventory + .get_executor(inst.opcode) + .map(|executor| executor.pre_compute_size()) + .unwrap() + } + } else { + 0 + } + }) + .max() + .unwrap() + .next_power_of_two() +} + +fn get_e2_pre_compute_max_size, P>( + program: &Program, + chip_complex: &VmChipComplex, +) -> usize { + program + .instructions_and_debug_infos + .iter() + .map(|inst_opt| { + if let Some((inst, _)) = inst_opt { + if let Some(size) = system_opcode_pre_compute_size(inst) { + size + } else { + chip_complex + .inventory + .get_executor(inst.opcode) + .map(|executor| executor.e2_pre_compute_size()) + .unwrap() + } + } else { + 0 + } + }) + .max() + .unwrap() + .next_power_of_two() +} + +fn system_opcode_pre_compute_size(inst: &Instruction) -> Option { + if inst.opcode == SystemOpcode::TERMINATE.global_opcode() { + return Some(size_of::()); + } + None +} + +fn get_pre_compute_instructions< + 'a, + F: PrimeField32, + E: InsExecutorE1, + P, + Ctx: E1ExecutionCtx, +>( + program: &'a Program, + chip_complex: &'a VmChipComplex, + pre_compute: &'a mut [&mut [u8]], +) -> Result>, ExecutionError> { + program + .instructions_and_debug_infos + .iter() + .zip_eq(pre_compute.iter_mut()) + .enumerate() + .map(|(i, (inst_opt, buf))| { + let buf: &mut [u8] = buf; + let pre_inst = if let Some((inst, _)) = inst_opt { + let pc = program.pc_base + i as u32 * program.step; + if let Some(handler) = get_system_opcode_handler(inst, buf) { + PreComputeInstruction { + handler, + pre_compute: buf, + } + } else if let Some(executor) = chip_complex.inventory.get_executor(inst.opcode) { + PreComputeInstruction { + handler: executor.pre_compute_e1(pc, inst, buf)?, + pre_compute: buf, + } + } else { + return Err(ExecutionError::DisabledOperation { + pc, + opcode: inst.opcode, + }); + } + } else { + PreComputeInstruction { + handler: |_, vm_state| { + vm_state.exit_code = Err(InvalidInstruction(vm_state.pc)); + }, + pre_compute: buf, + } + }; + Ok(pre_inst) + }) + .collect::, _>>() +} + +fn get_e2_pre_compute_instructions< + 'a, + F: PrimeField32, + E: InsExecutorE2, + P, + Ctx: E2ExecutionCtx, +>( + program: &'a Program, + chip_complex: &'a VmChipComplex, + pre_compute: &'a mut [&mut [u8]], +) -> Result>, ExecutionError> { + let executor_idx_offset = chip_complex.get_executor_offset_in_vkey(); + program + .instructions_and_debug_infos + .iter() + .zip_eq(pre_compute.iter_mut()) + .enumerate() + .map(|(i, (inst_opt, buf))| { + let buf: &mut [u8] = buf; + let pre_inst = if let Some((inst, _)) = inst_opt { + let pc = program.pc_base + i as u32 * program.step; + if let Some(handler) = get_system_opcode_handler(inst, buf) { + PreComputeInstruction { + handler, + pre_compute: buf, + } + } else if let Some(executor) = chip_complex.inventory.get_executor(inst.opcode) { + let executor_idx = executor_idx_offset + + chip_complex + .inventory + .get_executor_idx_in_vkey(&inst.opcode) + .unwrap(); + PreComputeInstruction { + handler: executor.pre_compute_e2(executor_idx, pc, inst, buf)?, + pre_compute: buf, + } + } else { + return Err(ExecutionError::DisabledOperation { + pc, + opcode: inst.opcode, + }); + } + } else { + PreComputeInstruction { + handler: |_, vm_state| { + vm_state.exit_code = Err(InvalidInstruction(vm_state.pc)); + }, + pre_compute: buf, + } + }; + Ok(pre_inst) + }) + .collect::, _>>() +} + +fn get_system_opcode_handler( + inst: &Instruction, + buf: &mut [u8], +) -> Option> { + if inst.opcode == SystemOpcode::TERMINATE.global_opcode() { + let pre_compute: &mut TerminatePreCompute = buf.borrow_mut(); + pre_compute.exit_code = inst.c.as_canonical_u32(); + return Some(terminate_execute_e12_impl); + } + None +} + +fn check_exit_code( + vm_state: &VmSegmentState, +) -> Result<(), ExecutionError> { + if let Ok(Some(exit_code)) = vm_state.exit_code.as_ref() { + if *exit_code != ExitCode::Success as u32 { + return Err(ExecutionError::FailedWithExitCode(*exit_code)); + } + } + Ok(()) +} diff --git a/crates/vm/src/arch/mod.rs b/crates/vm/src/arch/mod.rs index 63ee5e6f8b..a3f7ab3e49 100644 --- a/crates/vm/src/arch/mod.rs +++ b/crates/vm/src/arch/mod.rs @@ -1,19 +1,25 @@ mod config; /// Instruction execution traits and types. /// Execution bus and interface. -mod execution; +pub mod execution; +/// Module for controlling VM execution flow, including segmentation and instruction execution +pub mod execution_control; +pub mod execution_mode; /// Traits and builders to compose collections of chips into a virtual machine. mod extensions; /// Traits and wrappers to facilitate VM chip integration mod integration_api; /// Runtime execution and segmentation pub mod segment; +/// Strategy for determining when to segment VM execution +pub mod segmentation_strategy; /// Top level [VirtualMachine] constructor and API. pub mod vm; pub use openvm_instructions as instructions; pub mod hasher; +pub mod interpreter; /// Testing framework #[cfg(any(test, feature = "test-utils"))] pub mod testing; @@ -23,4 +29,5 @@ pub use execution::*; pub use extensions::*; pub use integration_api::*; pub use segment::*; +pub use segmentation_strategy::*; pub use vm::*; diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 634632ce2b..be3410ad6d 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -1,10 +1,9 @@ -use std::sync::Arc; +use std::fmt::Debug; use backtrace::Backtrace; use openvm_instructions::{ exe::FnBounds, instruction::{DebugInfo, Instruction}, - program::Program, }; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, @@ -12,183 +11,167 @@ use openvm_stark_backend::{ p3_commit::PolynomialSpace, p3_field::PrimeField32, prover::types::{CommittedTraceData, ProofInput}, - utils::metrics_span, Chip, }; +use rand::rngs::StdRng; +use tracing::instrument; +#[cfg(feature = "bench-metrics")] +use super::InstructionExecutor; use super::{ - ExecutionError, GenerationError, Streams, SystemBase, SystemConfig, VmChipComplex, - VmComplexTraceHeights, VmConfig, + execution_control::ExecutionControl, ExecutionError, GenerationError, Streams, SystemConfig, + VmChipComplex, VmComplexTraceHeights, VmConfig, }; #[cfg(feature = "bench-metrics")] use crate::metrics::VmMetrics; use crate::{ - arch::{instructions::*, ExecutionState, InstructionExecutor}, - system::memory::MemoryImage, + arch::{execution_mode::E1ExecutionCtx, instructions::*}, + system::memory::online::GuestMemory, }; -/// Check segment every 100 instructions. -const SEGMENT_CHECK_INTERVAL: usize = 100; - -const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; -// a heuristic number for the maximum number of cells per chip in a segment -// a few reasons for this number: -// 1. `VmAirWrapper` is -// the chip with the most cells in a segment from the reth-benchmark. -// 2. `VmAirWrapper`: -// its trace width is 36 and its after challenge trace width is 80. -const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = DEFAULT_MAX_SEGMENT_LEN * 120; - -pub trait SegmentationStrategy: - std::fmt::Debug + Send + Sync + std::panic::UnwindSafe + std::panic::RefUnwindSafe -{ - /// Whether the execution should segment based on the trace heights and cells. - /// - /// Air names are provided for debugging purposes. - fn should_segment( - &self, - air_names: &[String], - trace_heights: &[usize], - trace_cells: &[usize], - ) -> bool; - - /// A strategy that segments more aggressively than the current one. - /// - /// Called when `should_segment` results in a segment that is infeasible. Execution will be - /// re-run with the stricter segmentation strategy. - fn stricter_strategy(&self) -> Arc; -} - -/// Default segmentation strategy: segment if any chip's height or cells exceed the limits. -#[derive(Debug, Clone)] -pub struct DefaultSegmentationStrategy { - max_segment_len: usize, - max_cells_per_chip_in_segment: usize, +pub struct VmSegmentState { + pub instret: u64, + pub pc: u32, + pub memory: GuestMemory, + pub streams: Streams, + pub rng: StdRng, + pub exit_code: Result, ExecutionError>, + pub ctx: Ctx, } -impl Default for DefaultSegmentationStrategy { - fn default() -> Self { +impl VmSegmentState { + pub fn new( + instret: u64, + pc: u32, + memory: Option, + streams: Streams, + rng: StdRng, + ctx: Ctx, + ) -> Self { Self { - max_segment_len: DEFAULT_MAX_SEGMENT_LEN, - max_cells_per_chip_in_segment: DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT, + instret, + pc, + memory: if let Some(mem) = memory { + mem + } else { + GuestMemory::new(Default::default()) + }, + streams, + rng, + ctx, + exit_code: Ok(None), } } -} - -impl DefaultSegmentationStrategy { - pub fn new_with_max_segment_len(max_segment_len: usize) -> Self { - Self { - max_segment_len, - max_cells_per_chip_in_segment: max_segment_len * 120, - } + /// Runtime read operation for a block of memory + #[inline(always)] + pub fn vm_read( + &mut self, + addr_space: u32, + ptr: u32, + ) -> [T; BLOCK_SIZE] + where + Ctx: E1ExecutionCtx, + { + self.ctx + .on_memory_operation(addr_space, ptr, BLOCK_SIZE as u32); + self.host_read(addr_space, ptr) } - pub fn new(max_segment_len: usize, max_cells_per_chip_in_segment: usize) -> Self { - Self { - max_segment_len, - max_cells_per_chip_in_segment, - } + /// Runtime write operation for a block of memory + #[inline(always)] + pub fn vm_write( + &mut self, + addr_space: u32, + ptr: u32, + data: &[T; BLOCK_SIZE], + ) where + Ctx: E1ExecutionCtx, + { + self.ctx + .on_memory_operation(addr_space, ptr, BLOCK_SIZE as u32); + self.host_write(addr_space, ptr, data) } - pub fn max_segment_len(&self) -> usize { - self.max_segment_len + #[inline(always)] + pub fn vm_read_slice(&mut self, addr_space: u32, ptr: u32, len: usize) -> &[T] + where + Ctx: E1ExecutionCtx, + { + self.ctx.on_memory_operation(addr_space, ptr, len as u32); + self.host_read_slice(addr_space, ptr, len) } -} - -const SEGMENTATION_BACKOFF_FACTOR: usize = 4; -impl SegmentationStrategy for DefaultSegmentationStrategy { - fn should_segment( + #[inline(always)] + pub fn host_read( &self, - air_names: &[String], - trace_heights: &[usize], - trace_cells: &[usize], - ) -> bool { - for (i, &height) in trace_heights.iter().enumerate() { - if height > self.max_segment_len { - tracing::info!( - "Should segment because chip {} (name: {}) has height {}", - i, - air_names[i], - height - ); - return true; - } - } - for (i, &num_cells) in trace_cells.iter().enumerate() { - if num_cells > self.max_cells_per_chip_in_segment { - tracing::info!( - "Should segment because chip {} (name: {}) has {} cells", - i, - air_names[i], - num_cells - ); - return true; - } - } - false + addr_space: u32, + ptr: u32, + ) -> [T; BLOCK_SIZE] + where + Ctx: E1ExecutionCtx, + { + unsafe { self.memory.read(addr_space, ptr) } } - - fn stricter_strategy(&self) -> Arc { - Arc::new(Self { - max_segment_len: self.max_segment_len / SEGMENTATION_BACKOFF_FACTOR, - max_cells_per_chip_in_segment: self.max_cells_per_chip_in_segment - / SEGMENTATION_BACKOFF_FACTOR, - }) + #[inline(always)] + pub fn host_write( + &mut self, + addr_space: u32, + ptr: u32, + data: &[T; BLOCK_SIZE], + ) where + Ctx: E1ExecutionCtx, + { + unsafe { self.memory.write(addr_space, ptr, *data) } + } + #[inline(always)] + pub fn host_read_slice(&self, addr_space: u32, ptr: u32, len: usize) -> &[T] + where + Ctx: E1ExecutionCtx, + { + unsafe { self.memory.get_slice(addr_space, ptr, len) } } } -pub struct ExecutionSegment +pub struct VmSegmentExecutor where F: PrimeField32, VC: VmConfig, + Ctrl: ExecutionControl, { pub chip_complex: VmChipComplex, - /// Memory image after segment was executed. Not used in trace generation. - pub final_memory: Option>, + /// Execution control for determining segmentation and stopping conditions + pub ctrl: Ctrl, - pub since_last_segment_check: usize, pub trace_height_constraints: Vec, /// Air names for debug purposes only. + #[cfg(feature = "bench-metrics")] pub(crate) air_names: Vec, /// Metrics collected for this execution segment alone. #[cfg(feature = "bench-metrics")] pub metrics: VmMetrics, } -pub struct ExecutionSegmentState { - pub pc: u32, - pub is_terminated: bool, -} - -impl> ExecutionSegment { +impl VmSegmentExecutor +where + F: PrimeField32, + VC: VmConfig, + Ctrl: ExecutionControl, +{ /// Creates a new execution segment from a program and initial state, using parent VM config pub fn new( - config: &VC, - program: Program, - init_streams: Streams, - initial_memory: Option>, + chip_complex: VmChipComplex, trace_height_constraints: Vec, #[allow(unused_variables)] fn_bounds: FnBounds, + ctrl: Ctrl, ) -> Self { - let mut chip_complex = config.create_chip_complex().unwrap(); - chip_complex.set_streams(init_streams); - let program = if !config.system().profiling { - program.strip_debug_infos() - } else { - program - }; - chip_complex.set_program(program); - - if let Some(initial_memory) = initial_memory { - chip_complex.set_initial_memory(initial_memory); - } + #[cfg(feature = "bench-metrics")] let air_names = chip_complex.air_names(); Self { chip_complex, - final_memory: None, + ctrl, + #[cfg(feature = "bench-metrics")] air_names, trace_height_constraints, #[cfg(feature = "bench-metrics")] @@ -196,7 +179,6 @@ impl> ExecutionSegment { fn_bounds, ..Default::default() }, - since_last_segment_check: 0, } } @@ -211,134 +193,122 @@ impl> ExecutionSegment { .set_override_inventory_trace_heights(overridden_heights.inventory); } - /// Stopping is triggered by should_segment() - pub fn execute_from_pc( + /// Stopping is triggered by should_stop() or if VM is terminated + pub fn execute_from_state( &mut self, - mut pc: u32, - ) -> Result { - let mut timestamp = self.chip_complex.memory_controller().timestamp(); + state: &mut VmSegmentState, + ) -> Result<(), ExecutionError> { let mut prev_backtrace: Option = None; - self.chip_complex - .connector_chip_mut() - .begin(ExecutionState::new(pc, timestamp)); - - let mut did_terminate = false; + // Call the pre-execution hook + self.ctrl.on_start(state, &mut self.chip_complex); loop { - #[allow(unused_variables)] - let (opcode, dsl_instr) = { - let Self { - chip_complex, - #[cfg(feature = "bench-metrics")] - metrics, - .. - } = self; - let SystemBase { - program_chip, - memory_controller, - .. - } = &mut chip_complex.base; - - let (instruction, debug_info) = program_chip.get_instruction(pc)?; - tracing::trace!("pc: {pc:#x} | time: {timestamp} | {:?}", instruction); - - #[allow(unused_variables)] - let (dsl_instr, trace) = debug_info.as_ref().map_or( - (None, None), - |DebugInfo { - dsl_instruction, - trace, - }| (Some(dsl_instruction), trace.as_ref()), - ); - - let &Instruction { opcode, c, .. } = instruction; - if opcode == SystemOpcode::TERMINATE.global_opcode() { - did_terminate = true; - self.chip_complex.connector_chip_mut().end( - ExecutionState::new(pc, timestamp), - Some(c.as_canonical_u32()), - ); - break; + if let Ok(Some(exit_code)) = state.exit_code { + self.ctrl + .on_terminate(state, &mut self.chip_complex, exit_code); + break; + } + if self.should_suspend(state) { + self.ctrl.on_suspend(state, &mut self.chip_complex); + break; + } + + // Fetch, decode and execute single instruction + self.execute_instruction(state, &mut prev_backtrace)?; + state.instret += 1; + } + Ok(()) + } + + /// Executes a single instruction and updates VM state + fn execute_instruction( + &mut self, + state: &mut VmSegmentState, + prev_backtrace: &mut Option, + ) -> Result<(), ExecutionError> { + let pc = state.pc; + let timestamp = self.chip_complex.memory_controller().timestamp(); + + // Process an instruction and update VM state + let (instruction, debug_info) = self.chip_complex.base.program_chip.get_instruction(pc)?; + + tracing::trace!("pc: {pc:#x} | time: {timestamp} | {:?}", instruction); + + let &Instruction { opcode, c, .. } = instruction; + + // Handle termination instruction + if opcode == SystemOpcode::TERMINATE.global_opcode() { + state.exit_code = Ok(Some(c.as_canonical_u32())); + return Ok(()); + } + + // Extract debug info components + #[allow(unused_variables)] + let (dsl_instr, trace) = debug_info.as_ref().map_or( + (None, None), + |DebugInfo { + dsl_instruction, + trace, + }| (Some(dsl_instruction.clone()), trace.as_ref()), + ); + + // Handle phantom instructions + if opcode == SystemOpcode::PHANTOM.global_opcode() { + let discriminant = c.as_canonical_u32() as u16; + if let Some(phantom) = SysPhantom::from_repr(discriminant) { + tracing::trace!("pc: {pc:#x} | system phantom: {phantom:?}"); + + if phantom == SysPhantom::DebugPanic { + if let Some(mut backtrace) = prev_backtrace.take() { + backtrace.resolve(); + eprintln!("openvm program failure; backtrace:\n{:?}", backtrace); + } else { + eprintln!("openvm program failure; no backtrace"); + } + return Err(ExecutionError::Fail { pc }); } - // Some phantom instruction handling is more convenient to do here than in - // PhantomChip. - if opcode == SystemOpcode::PHANTOM.global_opcode() { - // Note: the discriminant is the lower 16 bits of the c operand. - let discriminant = c.as_canonical_u32() as u16; - let phantom = SysPhantom::from_repr(discriminant); - tracing::trace!("pc: {pc:#x} | system phantom: {phantom:?}"); + #[cfg(feature = "bench-metrics")] + { + let dsl_str = dsl_instr.clone().unwrap_or_else(|| "Default".to_string()); match phantom { - Some(SysPhantom::DebugPanic) => { - if let Some(mut backtrace) = prev_backtrace { - backtrace.resolve(); - eprintln!("openvm program failure; backtrace:\n{:?}", backtrace); - } else { - eprintln!("openvm program failure; no backtrace"); - } - return Err(ExecutionError::Fail { pc }); - } - Some(SysPhantom::CtStart) => - { - #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .start(dsl_instr.cloned().unwrap_or("Default".to_string())) - } - Some(SysPhantom::CtEnd) => - { - #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .end(dsl_instr.cloned().unwrap_or("Default".to_string())) - } + SysPhantom::CtStart => self.metrics.cycle_tracker.start(dsl_str), + SysPhantom::CtEnd => self.metrics.cycle_tracker.end(dsl_str), _ => {} } } - prev_backtrace = trace.cloned(); - - if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) { - let next_state = InstructionExecutor::execute( - executor, - memory_controller, - instruction, - ExecutionState::new(pc, timestamp), - )?; - assert!(next_state.timestamp > timestamp); - pc = next_state.pc; - timestamp = next_state.timestamp; - } else { - return Err(ExecutionError::DisabledOperation { pc, opcode }); - }; - (opcode, dsl_instr.cloned()) - }; + } + } - #[cfg(feature = "bench-metrics")] + *prev_backtrace = trace.cloned(); + + // Execute the instruction using the control implementation + // TODO(AG): maybe avoid cloning the instruction? + self.ctrl + .execute_instruction(state, &instruction.clone(), &mut self.chip_complex)?; + + // Update metrics if enabled + #[cfg(feature = "bench-metrics")] + { self.update_instruction_metrics(pc, opcode, dsl_instr); + } - if self.should_segment() { - self.chip_complex - .connector_chip_mut() - .end(ExecutionState::new(pc, timestamp), None); - break; - } + Ok(()) + } + + /// Returns bool of whether to switch to next segment or not. + fn should_suspend(&mut self, state: &mut VmSegmentState) -> bool { + if !self.system_config().continuation_enabled { + return false; } - self.final_memory = Some( - self.chip_complex - .base - .memory_controller - .memory_image() - .clone(), - ); - Ok(ExecutionSegmentState { - pc, - is_terminated: did_terminate, - }) + // Check with the execution control policy + self.ctrl.should_suspend(state, &self.chip_complex) } /// Generate ProofInput to prove the segment. Should be called after ::execute + #[instrument(name = "trace_gen", skip_all)] pub fn generate_proof_input( #[allow(unused_mut)] mut self, cached_program: Option>, @@ -348,40 +318,59 @@ impl> ExecutionSegment { VC::Executor: Chip, VC::Periphery: Chip, { - metrics_span("trace_gen_time_ms", || { - self.chip_complex.generate_proof_input( - cached_program, - &self.trace_height_constraints, - #[cfg(feature = "bench-metrics")] - &mut self.metrics, - ) - }) + self.chip_complex.generate_proof_input( + cached_program, + &self.trace_height_constraints, + #[cfg(feature = "bench-metrics")] + &mut self.metrics, + ) } - /// Returns bool of whether to switch to next segment or not. This is called every clock cycle - /// inside of Core trace generation. - fn should_segment(&mut self) -> bool { - if !self.system_config().continuation_enabled { - return false; - } - // Avoid checking segment too often. - if self.since_last_segment_check != SEGMENT_CHECK_INTERVAL { - self.since_last_segment_check += 1; - return false; + #[cfg(feature = "bench-metrics")] + #[allow(unused_variables)] + pub fn update_instruction_metrics( + &mut self, + pc: u32, + opcode: VmOpcode, + dsl_instr: Option, + ) { + self.metrics.cycle_count += 1; + + if self.system_config().profiling { + let executor = self.chip_complex.inventory.get_executor(opcode).unwrap(); + let opcode_name = executor.get_opcode_name(opcode.as_usize()); + self.metrics.update_trace_cells( + &self.air_names, + self.chip_complex.current_trace_cells(), + opcode_name, + dsl_instr, + ); + + #[cfg(feature = "function-span")] + self.metrics.update_current_fn(pc); } - self.since_last_segment_check = 0; - let segmentation_strategy = &self.system_config().segmentation_strategy; - segmentation_strategy.should_segment( - &self.air_names, - &self - .chip_complex - .dynamic_trace_heights() - .collect::>(), - &self.chip_complex.current_trace_cells(), - ) } +} - pub fn current_trace_cells(&self) -> Vec { - self.chip_complex.current_trace_cells() - } +/// Macro for executing with a compile-time span name for better tracing performance +#[macro_export] +macro_rules! execute_spanned { + ($name:literal, $executor:expr, $state:expr) => {{ + #[cfg(feature = "bench-metrics")] + let start = std::time::Instant::now(); + #[cfg(feature = "bench-metrics")] + let start_instret = $state.instret; + + let result = tracing::info_span!($name).in_scope(|| $executor.execute_from_state($state)); + + #[cfg(feature = "bench-metrics")] + { + let elapsed = start.elapsed(); + let insns = $state.instret - start_instret; + metrics::counter!("insns").absolute(insns); + metrics::gauge!(concat!($name, "_insn_mi/s")) + .set(insns as f64 / elapsed.as_micros() as f64); + } + result + }}; } diff --git a/crates/vm/src/arch/segmentation_strategy.rs b/crates/vm/src/arch/segmentation_strategy.rs new file mode 100644 index 0000000000..0336546626 --- /dev/null +++ b/crates/vm/src/arch/segmentation_strategy.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +pub const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; +pub const DEFAULT_MAX_CELLS_IN_SEGMENT: usize = 2_000_000_000; // 2B + +pub trait SegmentationStrategy: + std::fmt::Debug + Send + Sync + std::panic::UnwindSafe + std::panic::RefUnwindSafe +{ + /// Whether the execution should segment based on the trace heights and cells. + /// + /// Air names are provided for debugging purposes. + fn should_segment( + &self, + air_names: &[String], + trace_heights: &[usize], + trace_cells: &[usize], + ) -> bool; + + /// A strategy that segments more aggressively than the current one. + /// + /// Called when `should_segment` results in a segment that is infeasible. Execution will be + /// re-run with the stricter segmentation strategy. + fn stricter_strategy(&self) -> Arc; + + /// Maximum height of any chip in a segment. + fn max_trace_height(&self) -> usize; + + /// Maximum number of cells in a segment. + fn max_cells(&self) -> usize; +} + +/// Default segmentation strategy: segment if any chip's height or cells exceed the limits. +#[derive(Debug, Clone)] +pub struct DefaultSegmentationStrategy { + max_segment_len: usize, + max_cells_in_segment: usize, +} + +impl Default for DefaultSegmentationStrategy { + fn default() -> Self { + Self { + max_segment_len: DEFAULT_MAX_SEGMENT_LEN, + max_cells_in_segment: DEFAULT_MAX_CELLS_IN_SEGMENT, + } + } +} + +impl DefaultSegmentationStrategy { + pub fn new_with_max_segment_len(max_segment_len: usize) -> Self { + Self { + max_segment_len, + max_cells_in_segment: DEFAULT_MAX_CELLS_IN_SEGMENT, + } + } + + pub fn new(max_segment_len: usize, max_cells_in_segment: usize) -> Self { + Self { + max_segment_len, + max_cells_in_segment, + } + } + + pub fn max_segment_len(&self) -> usize { + self.max_segment_len + } +} + +const SEGMENTATION_BACKOFF_FACTOR: usize = 4; + +impl SegmentationStrategy for DefaultSegmentationStrategy { + fn max_trace_height(&self) -> usize { + self.max_segment_len + } + + fn max_cells(&self) -> usize { + self.max_cells_in_segment + } + + fn should_segment( + &self, + air_names: &[String], + trace_heights: &[usize], + trace_cells: &[usize], + ) -> bool { + for (i, &height) in trace_heights.iter().enumerate() { + if height > self.max_segment_len { + tracing::info!( + "Should segment because chip {} (name: {}) has height {}", + i, + air_names[i], + height + ); + return true; + } + } + let total_cells: usize = trace_cells.iter().sum(); + if total_cells > self.max_cells_in_segment { + tracing::info!( + "Should segment because total cells across all chips is {}", + total_cells + ); + return true; + } + false + } + + fn stricter_strategy(&self) -> Arc { + Arc::new(Self { + max_segment_len: self.max_segment_len / SEGMENTATION_BACKOFF_FACTOR, + max_cells_in_segment: self.max_cells_in_segment / SEGMENTATION_BACKOFF_FACTOR, + }) + } +} diff --git a/crates/vm/src/arch/testing/memory/air.rs b/crates/vm/src/arch/testing/memory/air.rs index 8a394c0cce..90c1b4ce49 100644 --- a/crates/vm/src/arch/testing/memory/air.rs +++ b/crates/vm/src/arch/testing/memory/air.rs @@ -1,46 +1,155 @@ -use std::{borrow::Borrow, mem::size_of}; +use std::{mem::size_of, sync::Arc}; -use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, BaseAir}, - p3_matrix::Matrix, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::types::AirProofInput, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + AirRef, Chip, ChipUsageGetter, }; use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress}; -#[derive(Clone, Copy, Debug, AlignedBorrow, derive_new::new)] #[repr(C)] -pub struct DummyMemoryInteractionCols { - pub address: MemoryAddress, - pub data: [T; BLOCK_SIZE], - pub timestamp: T, +#[derive(Clone, Copy)] +pub struct DummyMemoryInteractionColsRef<'a, T> { + pub address: MemoryAddress<&'a T, &'a T>, + pub data: &'a [T], + pub timestamp: &'a T, /// The send frequency. Send corresponds to write. To read, set to negative. - pub count: T, + pub count: &'a T, +} + +#[repr(C)] +pub struct DummyMemoryInteractionColsMut<'a, T> { + pub address: MemoryAddress<&'a mut T, &'a mut T>, + pub data: &'a mut [T], + pub timestamp: &'a mut T, + /// The send frequency. Send corresponds to write. To read, set to negative. + pub count: &'a mut T, +} + +impl<'a, T> DummyMemoryInteractionColsRef<'a, T> { + pub fn from_slice(slice: &'a [T]) -> Self { + let (address, slice) = slice.split_at(size_of::>()); + let (count, slice) = slice.split_last().unwrap(); + let (timestamp, data) = slice.split_last().unwrap(); + Self { + address: MemoryAddress::new(&address[0], &address[1]), + data, + timestamp, + count, + } + } +} + +impl<'a, T> DummyMemoryInteractionColsMut<'a, T> { + pub fn from_mut_slice(slice: &'a mut [T]) -> Self { + let (addr_space, slice) = slice.split_first_mut().unwrap(); + let (ptr, slice) = slice.split_first_mut().unwrap(); + let (count, slice) = slice.split_last_mut().unwrap(); + let (timestamp, data) = slice.split_last_mut().unwrap(); + Self { + address: MemoryAddress::new(addr_space, ptr), + data, + timestamp, + count, + } + } } #[derive(Clone, Copy, Debug, derive_new::new)] -pub struct MemoryDummyAir { +pub struct MemoryDummyAir { pub bus: MemoryBus, + pub block_size: usize, } -impl BaseAirWithPublicValues for MemoryDummyAir {} -impl PartitionedBaseAir for MemoryDummyAir {} -impl BaseAir for MemoryDummyAir { +impl BaseAirWithPublicValues for MemoryDummyAir {} +impl PartitionedBaseAir for MemoryDummyAir {} +impl BaseAir for MemoryDummyAir { fn width(&self) -> usize { - size_of::>() + self.block_size + 4 } } -impl Air for MemoryDummyAir { +impl Air for MemoryDummyAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); let local = main.row_slice(0); - let local: &DummyMemoryInteractionCols = (*local).borrow(); + let local = DummyMemoryInteractionColsRef::from_slice(&local); self.bus - .send(local.address, local.data.to_vec(), local.timestamp) - .eval(builder, local.count); + .send( + MemoryAddress::new(*local.address.address_space, *local.address.pointer), + local.data.to_vec(), + *local.timestamp, + ) + .eval(builder, *local.count); + } +} + +#[derive(Clone)] +pub struct MemoryDummyChip { + pub air: MemoryDummyAir, + pub trace: Vec, +} + +impl MemoryDummyChip { + pub fn new(air: MemoryDummyAir) -> Self { + Self { + air, + trace: Vec::new(), + } + } +} + +impl MemoryDummyChip { + pub fn send(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) { + self.push(addr_space, ptr, data, timestamp, F::ONE); + } + + pub fn receive(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) { + self.push(addr_space, ptr, data, timestamp, F::NEG_ONE); + } + + pub fn push(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32, count: F) { + assert_eq!(data.len(), self.air.block_size); + self.trace.push(F::from_canonical_u32(addr_space)); + self.trace.push(F::from_canonical_u32(ptr)); + self.trace.extend_from_slice(data); + self.trace.push(F::from_canonical_u32(timestamp)); + self.trace.push(count); + } +} + +impl Chip for MemoryDummyChip> +where + Val: PrimeField32, +{ + fn air(&self) -> AirRef { + Arc::new(self.air) + } + + fn generate_air_proof_input(mut self) -> AirProofInput { + let height = self.current_trace_height().next_power_of_two(); + let width = self.trace_width(); + self.trace.resize(height * width, Val::::ZERO); + + AirProofInput::simple_no_pis(RowMajorMatrix::new(self.trace, width)) + } +} + +impl ChipUsageGetter for MemoryDummyChip { + fn air_name(&self) -> String { + format!("MemoryDummyAir<{}>", self.air.block_size) + } + fn current_trace_height(&self) -> usize { + self.trace.len() / self.trace_width() + } + fn trace_width(&self) -> usize { + BaseAir::::width(&self.air) } } diff --git a/crates/vm/src/arch/testing/memory/mod.rs b/crates/vm/src/arch/testing/memory/mod.rs index ae1136bc7f..6ffd31eae0 100644 --- a/crates/vm/src/arch/testing/memory/mod.rs +++ b/crates/vm/src/arch/testing/memory/mod.rs @@ -1,140 +1,105 @@ -use std::{array::from_fn, borrow::BorrowMut as _, cell::RefCell, mem::size_of, rc::Rc, sync::Arc}; +use std::collections::HashMap; -use air::{DummyMemoryInteractionCols, MemoryDummyAir}; +use air::{MemoryDummyAir, MemoryDummyChip}; use openvm_circuit::system::memory::MemoryController; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, -}; -use rand::{seq::SliceRandom, Rng}; - -use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress, RecordId}; +use openvm_stark_backend::p3_field::PrimeField32; +use rand::Rng; pub mod air; -const WORD_SIZE: usize = 1; - /// A dummy testing chip that will add unconstrained messages into the [MemoryBus]. /// Stores a log of raw messages to send/receive to the [MemoryBus]. /// /// It will create a [air::MemoryDummyAir] to add messages to MemoryBus. pub struct MemoryTester { - pub bus: MemoryBus, - pub controller: Rc>>, - /// Log of record ids - pub records: Vec, + /// Map from `block_size` to [MemoryDummyChip] of that block size + pub chip_for_block: HashMap>, + // TODO: make this just TracedMemory? + pub controller: MemoryController, } impl MemoryTester { - pub fn new(controller: Rc>>) -> Self { - let bus = controller.borrow().memory_bus; + pub fn new(controller: MemoryController) -> Self { + let bus = controller.memory_bus; + let mut chip_for_block = HashMap::new(); + for log_block_size in 0..6 { + let block_size = 1 << log_block_size; + let chip = MemoryDummyChip::new(MemoryDummyAir::new(bus, block_size)); + chip_for_block.insert(block_size, chip); + } Self { - bus, + chip_for_block, controller, - records: Vec::new(), - } - } - - /// Returns the cell value at the current timestamp according to `MemoryController`. - pub fn read_cell(&mut self, address_space: usize, pointer: usize) -> F { - let [addr_space, pointer] = [address_space, pointer].map(F::from_canonical_usize); - // core::BorrowMut confuses compiler - let (record_id, value) = - RefCell::borrow_mut(&self.controller).read_cell(addr_space, pointer); - self.records.push(record_id); - value - } - - pub fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) { - let [addr_space, pointer] = [address_space, pointer].map(F::from_canonical_usize); - let (record_id, _) = - RefCell::borrow_mut(&self.controller).write_cell(addr_space, pointer, value); - self.records.push(record_id); - } - - pub fn read(&mut self, address_space: usize, pointer: usize) -> [F; N] { - from_fn(|i| self.read_cell(address_space, pointer + i)) - } - - pub fn write( - &mut self, - address_space: usize, - mut pointer: usize, - cells: [F; N], - ) { - for cell in cells { - self.write_cell(address_space, pointer, cell); - pointer += 1; } } -} -impl Chip for MemoryTester> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(MemoryDummyAir::::new(self.bus)) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let offline_memory = self.controller.borrow().offline_memory(); - let offline_memory = offline_memory.lock().unwrap(); - - let height = self.records.len().next_power_of_two(); - let width = self.trace_width(); - let mut values = Val::::zero_vec(2 * height * width); - // This zip only goes through records. The padding rows between records.len()..height - // are filled with zeros - in particular count = 0 so nothing is added to bus. - for (row, id) in values.chunks_mut(2 * width).zip(self.records) { - let (first, second) = row.split_at_mut(width); - let row: &mut DummyMemoryInteractionCols, WORD_SIZE> = first.borrow_mut(); - let record = offline_memory.record_by_id(id); - row.address = MemoryAddress { - address_space: record.address_space, - pointer: record.pointer, - }; - row.data - .copy_from_slice(record.prev_data_slice().unwrap_or(record.data_slice())); - row.timestamp = Val::::from_canonical_u32(record.prev_timestamp); - row.count = -Val::::ONE; - - let row: &mut DummyMemoryInteractionCols, WORD_SIZE> = second.borrow_mut(); - row.address = MemoryAddress { - address_space: record.address_space, - pointer: record.pointer, + // TODO: change interface by implementing GuestMemory trait after everything works + pub fn read(&mut self, addr_space: usize, ptr: usize) -> [F; N] { + let controller = &mut self.controller; + let t = controller.memory.timestamp(); + // TODO: hack + let (t_prev, data) = if addr_space <= 3 { + let (t_prev, data) = unsafe { + controller + .memory + .read::(addr_space as u32, ptr as u32) }; - row.data.copy_from_slice(record.data_slice()); - row.timestamp = Val::::from_canonical_u32(record.timestamp); - row.count = Val::::ONE; - } - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) + (t_prev, data.map(F::from_canonical_u8)) + } else { + unsafe { + controller + .memory + .read::(addr_space as u32, ptr as u32) + } + }; + self.chip_for_block.get_mut(&N).unwrap().receive( + addr_space as u32, + ptr as u32, + &data, + t_prev, + ); + self.chip_for_block + .get_mut(&N) + .unwrap() + .send(addr_space as u32, ptr as u32, &data, t); + + data } -} -impl ChipUsageGetter for MemoryTester { - fn air_name(&self) -> String { - "MemoryDummyAir".to_string() - } - fn current_trace_height(&self) -> usize { - self.records.len() - } - - fn trace_width(&self) -> usize { - size_of::>() + // TODO: see read + pub fn write(&mut self, addr_space: usize, ptr: usize, data: [F; N]) { + let controller = &mut self.controller; + let t = controller.memory.timestamp(); + // TODO: hack + let (t_prev, data_prev) = if addr_space <= 3 { + let (t_prev, data_prev) = unsafe { + controller.memory.write::( + addr_space as u32, + ptr as u32, + data.map(|x| x.as_canonical_u32() as u8), + ) + }; + (t_prev, data_prev.map(F::from_canonical_u8)) + } else { + unsafe { + controller + .memory + .write::(addr_space as u32, ptr as u32, data) + } + }; + self.chip_for_block.get_mut(&N).unwrap().receive( + addr_space as u32, + ptr as u32, + &data_prev, + t_prev, + ); + self.chip_for_block + .get_mut(&N) + .unwrap() + .send(addr_space as u32, ptr as u32, &data, t); } } -pub fn gen_address_space(rng: &mut R) -> usize -where - R: Rng + ?Sized, -{ - *[1, 2].choose(rng).unwrap() -} - pub fn gen_pointer(rng: &mut R, len: usize) -> usize where R: Rng + ?Sized, diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 44b19177be..c8f44c224b 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -1,20 +1,17 @@ -use std::{ - cell::RefCell, - iter::zip, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{borrow::Borrow, iter::zip}; use openvm_circuit_primitives::var_range::{ SharedVariableRangeCheckerChip, VariableRangeCheckerBus, }; -use openvm_instructions::instruction::Instruction; +use openvm_instructions::{instruction::Instruction, NATIVE_AS}; +use openvm_poseidon2_air::Poseidon2Config; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::VerificationData, - interaction::BusIndex, + interaction::{BusIndex, PermutationCheckBus}, p3_field::PrimeField32, p3_matrix::dense::{DenseMatrix, RowMajorMatrix}, + p3_util::log2_strict_usize, prover::types::AirProofInput, verifier::VerificationError, AirRef, Chip, @@ -32,13 +29,14 @@ use program::ProgramTester; use rand::{rngs::StdRng, RngCore, SeedableRng}; use tracing::Level; -use super::{ExecutionBus, InstructionExecutor, SystemPort}; +use super::{ExecutionBridge, ExecutionBus, InstructionExecutor, SystemPort}; use crate::{ - arch::{ExecutionState, MemoryConfig}, + arch::{ExecutionState, MemoryConfig, Streams}, system::{ memory::{ + interface::MemoryInterface, offline_checker::{MemoryBridge, MemoryBus}, - MemoryController, OfflineMemory, + MemoryController, SharedMemoryHelper, }, poseidon2::Poseidon2PeripheryChip, program::ProgramBus, @@ -48,11 +46,9 @@ use crate::{ pub mod execution; pub mod memory; pub mod program; -pub mod test_adapter; pub use execution::ExecutionTester; pub use memory::MemoryTester; -pub use test_adapter::TestAdapterChip; pub const EXECUTION_BUS: BusIndex = 0; pub const MEMORY_BUS: BusIndex = 1; @@ -63,30 +59,36 @@ pub const BYTE_XOR_BUS: BusIndex = 10; pub const RANGE_TUPLE_CHECKER_BUS: BusIndex = 11; pub const MEMORY_MERKLE_BUS: BusIndex = 12; -const RANGE_CHECKER_BUS: BusIndex = 4; +pub const RANGE_CHECKER_BUS: BusIndex = 4; pub struct VmChipTestBuilder { pub memory: MemoryTester, + pub streams: Streams, + pub rng: StdRng, pub execution: ExecutionTester, pub program: ProgramTester, - rng: StdRng, + internal_rng: StdRng, default_register: usize, default_pointer: usize, } impl VmChipTestBuilder { pub fn new( - memory_controller: Rc>>, + memory_controller: MemoryController, + streams: Streams, + rng: StdRng, execution_bus: ExecutionBus, program_bus: ProgramBus, - rng: StdRng, + internal_rng: StdRng, ) -> Self { setup_tracing_with_log_level(Level::WARN); Self { memory: MemoryTester::new(memory_controller), + streams, + rng, execution: ExecutionTester::new(execution_bus), program: ProgramTester::new(program_bus), - rng, + internal_rng, default_register: 0, default_pointer: 0, } @@ -110,13 +112,15 @@ impl VmChipTestBuilder { ) { let initial_state = ExecutionState { pc: initial_pc, - timestamp: self.memory.controller.borrow().timestamp(), + timestamp: self.memory.controller.timestamp(), }; tracing::debug!(?initial_state.timestamp); let final_state = executor .execute( - &mut *self.memory.controller.borrow_mut(), + &mut self.memory.controller, + &mut self.streams, + &mut self.rng, instruction, initial_state, ) @@ -127,15 +131,7 @@ impl VmChipTestBuilder { } fn next_elem_size_u32(&mut self) -> u32 { - self.rng.next_u32() % (1 << (F::bits() - 2)) - } - - pub fn read_cell(&mut self, address_space: usize, pointer: usize) -> F { - self.memory.read_cell(address_space, pointer) - } - - pub fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) { - self.memory.write_cell(address_space, pointer, value); + self.internal_rng.next_u32() % (1 << (F::bits() - 2)) } pub fn read(&mut self, address_space: usize, pointer: usize) -> [F; N] { @@ -162,9 +158,22 @@ impl VmChipTestBuilder { pointer: usize, writes: Vec<[F; NUM_LIMBS]>, ) { - self.write(1usize, register, [F::from_canonical_usize(pointer)]); - for (i, &write) in writes.iter().enumerate() { - self.write(2usize, pointer + i * NUM_LIMBS, write); + self.write( + 1usize, + register, + pointer.to_le_bytes().map(F::from_canonical_u8), + ); + if NUM_LIMBS.is_power_of_two() { + for (i, &write) in writes.iter().enumerate() { + self.write(2usize, pointer + i * NUM_LIMBS, write); + } + } else { + for (i, &write) in writes.iter().enumerate() { + let ptr = pointer + i * NUM_LIMBS; + for j in (0..NUM_LIMBS).step_by(4) { + self.write::<4>(2usize, ptr + j, write[j..j + 4].try_into().unwrap()); + } + } } } @@ -176,6 +185,10 @@ impl VmChipTestBuilder { } } + pub fn execution_bridge(&self) -> ExecutionBridge { + ExecutionBridge::new(self.execution.bus, self.program.bus) + } + pub fn execution_bus(&self) -> ExecutionBus { self.execution.bus } @@ -185,27 +198,27 @@ impl VmChipTestBuilder { } pub fn memory_bus(&self) -> MemoryBus { - self.memory.bus + self.memory.controller.memory_bus } - pub fn memory_controller(&self) -> Rc>> { - self.memory.controller.clone() + pub fn memory_controller(&self) -> &MemoryController { + &self.memory.controller } pub fn range_checker(&self) -> SharedVariableRangeCheckerChip { - self.memory.controller.borrow().range_checker.clone() + self.memory.controller.range_checker.clone() } pub fn memory_bridge(&self) -> MemoryBridge { - self.memory.controller.borrow().memory_bridge() + self.memory.controller.memory_bridge() } - pub fn address_bits(&self) -> usize { - self.memory.controller.borrow().mem_config.pointer_max_bits + pub fn memory_helper(&self) -> SharedMemoryHelper { + self.memory.controller.helper() } - pub fn offline_memory_mutex_arc(&self) -> Arc>> { - self.memory_controller().borrow().offline_memory().clone() + pub fn address_bits(&self) -> usize { + self.memory.controller.mem_config.pointer_max_bits } pub fn get_default_register(&mut self, increment: usize) -> usize { @@ -247,10 +260,6 @@ type TestSC = BabyBearBlake3Config; impl VmChipTestBuilder { pub fn build(self) -> VmChipTester { - self.memory - .controller - .borrow_mut() - .finalize(None::<&mut Poseidon2PeripheryChip>); let tester = VmChipTester { memory: Some(self.memory), ..Default::default() @@ -259,10 +268,6 @@ impl VmChipTestBuilder { tester.load(self.program) } pub fn build_babybear_poseidon2(self) -> VmChipTester { - self.memory - .controller - .borrow_mut() - .finalize(None::<&mut Poseidon2PeripheryChip>); let tester = VmChipTester { memory: Some(self.memory), ..Default::default() @@ -272,29 +277,84 @@ impl VmChipTestBuilder { } } -impl Default for VmChipTestBuilder { - fn default() -> Self { - let mem_config = MemoryConfig::default(); +impl VmChipTestBuilder { + pub fn default_persistent() -> Self { + let mut mem_config = MemoryConfig::default(); + mem_config.addr_space_sizes[NATIVE_AS as usize] = 0; + Self::persistent(mem_config) + } + + pub fn default_native() -> Self { + Self::volatile(MemoryConfig::aggregation()) + } + + pub fn persistent(mem_config: MemoryConfig) -> Self { + setup_tracing_with_log_level(Level::INFO); let range_checker = SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new( RANGE_CHECKER_BUS, mem_config.decomp, )); - let memory_controller = MemoryController::with_volatile_memory( + let max_access_adapter_n = log2_strict_usize(mem_config.max_access_adapter_n); + let mut memory_controller = MemoryController::with_persistent_memory( MemoryBus::new(MEMORY_BUS), mem_config, range_checker, + PermutationCheckBus::new(MEMORY_MERKLE_BUS), + PermutationCheckBus::new(POSEIDON2_DIRECT_BUS), ); + memory_controller + .memory + .access_adapter_inventory + .set_arena_from_trace_heights(&vec![1 << 16; max_access_adapter_n]); Self { - memory: MemoryTester::new(Rc::new(RefCell::new(memory_controller))), + memory: MemoryTester::new(memory_controller), + streams: Default::default(), + rng: StdRng::seed_from_u64(0), execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)), program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)), + internal_rng: StdRng::seed_from_u64(0), + default_register: 0, + default_pointer: 0, + } + } + + pub fn volatile(mem_config: MemoryConfig) -> Self { + setup_tracing_with_log_level(Level::INFO); + let range_checker = SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new( + RANGE_CHECKER_BUS, + mem_config.decomp, + )); + let max_access_adapter_n = log2_strict_usize(mem_config.max_access_adapter_n); + let mut memory_controller = MemoryController::with_volatile_memory( + MemoryBus::new(MEMORY_BUS), + mem_config, + range_checker, + ); + memory_controller + .memory + .access_adapter_inventory + .set_arena_from_trace_heights(&vec![1 << 16; max_access_adapter_n]); + Self { + memory: MemoryTester::new(memory_controller), + streams: Default::default(), rng: StdRng::seed_from_u64(0), + execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)), + program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)), + internal_rng: StdRng::seed_from_u64(0), default_register: 0, default_pointer: 0, } } } +impl Default for VmChipTestBuilder { + fn default() -> Self { + let mut mem_config = MemoryConfig::default(); + mem_config.addr_space_sizes[NATIVE_AS as usize] = 0; + Self::volatile(mem_config) + } +} + pub struct VmChipTester { pub memory: Option>>, pub air_proof_inputs: Vec<(AirRef, AirProofInput)>, @@ -326,19 +386,47 @@ where pub fn finalize(mut self) -> Self { if let Some(memory_tester) = self.memory.take() { - let memory_controller = memory_tester.controller.clone(); - let range_checker = memory_controller.borrow().range_checker.clone(); - self = self.load(memory_tester); // dummy memory interactions - { - let airs = memory_controller.borrow().airs(); - let air_proof_inputs = Rc::try_unwrap(memory_controller) - .unwrap_or_else(|_| panic!("Memory controller was not dropped")) - .into_inner() - .generate_air_proof_inputs(); - self.air_proof_inputs.extend( - zip(airs, air_proof_inputs).filter(|(_, input)| input.main_trace_height() > 0), - ); - } + // Balance memory boundaries + let mut memory_controller = memory_tester.controller; + let range_checker = memory_controller.range_checker.clone(); + match &memory_controller.interface_chip { + MemoryInterface::Volatile { .. } => { + memory_controller.finalize(None::<&mut Poseidon2PeripheryChip>>); + // dummy memory interactions: + for mem_chip in memory_tester.chip_for_block.into_values() { + self = self.load(mem_chip); + } + { + let airs = memory_controller.borrow().airs(); + let air_proof_inputs = memory_controller.generate_air_proof_inputs(); + self.air_proof_inputs.extend( + zip(airs, air_proof_inputs) + .filter(|(_, input)| input.main_trace_height() > 0), + ); + } + } + MemoryInterface::Persistent { .. } => { + let mut poseidon_chip = Poseidon2PeripheryChip::new( + Poseidon2Config::default(), + POSEIDON2_DIRECT_BUS, + 3, + ); + memory_controller.finalize(Some(&mut poseidon_chip)); + // dummy memory interactions: + for mem_chip in memory_tester.chip_for_block.into_values() { + self = self.load(mem_chip); + } + { + let airs = memory_controller.borrow().airs(); + let air_proof_inputs = memory_controller.generate_air_proof_inputs(); + self.air_proof_inputs.extend( + zip(airs, air_proof_inputs) + .filter(|(_, input)| input.main_trace_height() > 0), + ); + } + self = self.load(poseidon_chip); + } + }; self = self.load(range_checker); // this must be last because other trace generation // mutates its state } diff --git a/crates/vm/src/arch/testing/program/mod.rs b/crates/vm/src/arch/testing/program/mod.rs index 04c4feee60..88c31781fe 100644 --- a/crates/vm/src/arch/testing/program/mod.rs +++ b/crates/vm/src/arch/testing/program/mod.rs @@ -15,7 +15,7 @@ use crate::{ system::program::{ProgramBus, ProgramExecutionCols}, }; -mod air; +pub mod air; #[derive(Debug)] pub struct ProgramTester { diff --git a/crates/vm/src/arch/testing/test_adapter.rs b/crates/vm/src/arch/testing/test_adapter.rs deleted file mode 100644 index bca9eed724..0000000000 --- a/crates/vm/src/arch/testing/test_adapter.rs +++ /dev/null @@ -1,175 +0,0 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - collections::VecDeque, - fmt::Debug, -}; - -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::instruction::Instruction; -use openvm_stark_backend::{ - interaction::InteractionBuilder, - p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, -}; -use serde::{Deserialize, Serialize}; - -use crate::{ - arch::{ - AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, ExecutionBridge, - ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - }, - system::memory::{MemoryController, OfflineMemory}, -}; - -// Replaces A: VmAdapterChip while testing VmCoreChip functionality, as it has no -// constraints and thus cannot cause a failure. -pub struct TestAdapterChip { - /// List of the return values of `preprocess` this chip should provide on each sequential call. - pub prank_reads: VecDeque>, - /// List of `pc_inc` to use in `postprocess` on each sequential call. - /// Defaults to `4` if not provided. - pub prank_pc_inc: VecDeque>, - - pub air: TestAdapterAir, -} - -impl TestAdapterChip { - pub fn new( - prank_reads: Vec>, - prank_pc_inc: Vec>, - execution_bridge: ExecutionBridge, - ) -> Self { - Self { - prank_reads: prank_reads.into(), - prank_pc_inc: prank_pc_inc.into(), - air: TestAdapterAir { execution_bridge }, - } - } -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct TestAdapterRecord { - pub from_pc: u32, - pub operands: [T; 7], -} - -impl VmAdapterChip for TestAdapterChip { - type ReadRecord = (); - type WriteRecord = TestAdapterRecord; - type Air = TestAdapterAir; - type Interface = DynAdapterInterface; - - fn preprocess( - &mut self, - _memory: &mut MemoryController, - _instruction: &Instruction, - ) -> Result<(DynArray, Self::ReadRecord)> { - Ok(( - self.prank_reads - .pop_front() - .expect("Not enough prank reads provided") - .into(), - (), - )) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - _output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let pc_inc = self - .prank_pc_inc - .pop_front() - .map(|x| x.unwrap_or(4)) - .unwrap_or(4); - Ok(( - ExecutionState { - pc: from_state.pc + pc_inc, - timestamp: memory.timestamp(), - }, - TestAdapterRecord { - operands: [ - instruction.a, - instruction.b, - instruction.c, - instruction.d, - instruction.e, - instruction.f, - instruction.g, - ], - from_pc: from_state.pc, - }, - )) - } - - fn generate_trace_row( - &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - _memory: &OfflineMemory, - ) { - let cols: &mut TestAdapterCols = row_slice.borrow_mut(); - cols.from_pc = F::from_canonical_u32(write_record.from_pc); - cols.operands = write_record.operands; - // row_slice[0] = F::from_canonical_u32(write_record.from_pc); - // row_slice[1..].copy_from_slice(&write_record.operands); - } - - fn air(&self) -> &Self::Air { - &self.air - } -} - -#[derive(Clone, Copy, Debug)] -pub struct TestAdapterAir { - pub execution_bridge: ExecutionBridge, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct TestAdapterCols { - pub from_pc: T, - pub operands: [T; 7], -} - -impl BaseAir for TestAdapterAir { - fn width(&self) -> usize { - TestAdapterCols::::width() - } -} - -impl VmAdapterAir for TestAdapterAir { - type Interface = DynAdapterInterface; - - fn eval( - &self, - builder: &mut AB, - local: &[AB::Var], - ctx: AdapterAirContext, - ) { - let processed_instruction: MinimalInstruction = ctx.instruction.into(); - let cols: &TestAdapterCols = local.borrow(); - self.execution_bridge - .execute_and_increment_or_set_pc( - processed_instruction.opcode, - cols.operands.to_vec(), - ExecutionState { - pc: cols.from_pc.into(), - timestamp: AB::Expr::ONE, - }, - AB::Expr::ZERO, - (4, ctx.to_pc), - ) - .eval(builder, processed_instruction.is_valid); - } - - fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { - let cols: &TestAdapterCols = local.borrow(); - cols.from_pc - } -} diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index c9d5cb2ffc..4f9c5056eb 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -2,12 +2,14 @@ use std::{ borrow::Borrow, collections::{HashMap, VecDeque}, marker::PhantomData, - mem, sync::Arc, }; use openvm_circuit::system::program::trace::compute_exe_commit; -use openvm_instructions::exe::VmExe; +use openvm_instructions::{ + exe::{SparseMemoryImage, VmExe}, + program::Program, +}; use openvm_stark_backend::{ config::{Com, Domain, StarkGenericConfig, Val}, engine::StarkEngine, @@ -15,30 +17,42 @@ use openvm_stark_backend::{ p3_commit::PolynomialSpace, p3_field::{FieldAlgebra, PrimeField32}, proof::Proof, - prover::types::{CommittedTraceData, ProofInput}, - utils::metrics_span, + prover::types::ProofInput, verifier::VerificationError, Chip, }; +use rand::{rngs::StdRng, SeedableRng}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::info_span; use super::{ - ExecutionError, VmComplexTraceHeights, VmConfig, CONNECTOR_AIR_ID, MERKLE_AIR_ID, - PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, + execution_mode::{e1::E1Ctx, metered::ctx::DEFAULT_PAGE_BITS}, + ChipId, ExecutionError, InsExecutorE1, MemoryConfig, VmChipComplex, VmComplexTraceHeights, + VmConfig, VmInventoryError, CONNECTOR_AIR_ID, MERKLE_AIR_ID, PROGRAM_AIR_ID, + PROGRAM_CACHED_TRACE_INDEX, PUBLIC_VALUES_AIR_ID, }; #[cfg(feature = "bench-metrics")] use crate::metrics::VmMetrics; use crate::{ - arch::{hasher::poseidon2::vm_poseidon2_hasher, segment::ExecutionSegment}, + arch::{ + execution_mode::{ + metered::{MeteredCtx, Segment}, + tracegen::{TracegenCtx, TracegenExecutionControl}, + }, + hasher::poseidon2::vm_poseidon2_hasher, + interpreter::InterpretedInstance, + VmSegmentExecutor, VmSegmentState, + }, + execute_spanned, system::{ connector::{VmConnectorPvs, DEFAULT_SUSPEND_EXIT_CODE}, memory::{ - merkle::MemoryMerklePvs, - paged_vec::AddressMap, - tree::public_values::{UserPublicValuesProof, UserPublicValuesProofError}, - MemoryImage, CHUNK, + merkle::{ + public_values::{UserPublicValuesProof, UserPublicValuesProofError}, + MemoryMerklePvs, + }, + AddressMap, MemoryImage, CHUNK, }, program::trace::VmCommittedExe, }, @@ -52,9 +66,6 @@ pub enum GenerationError { Execution(#[from] ExecutionError), } -/// VM memory state for continuations. -pub type VmMemoryState = MemoryImage; - /// A trait for key-value store for `Streams`. pub trait KvStore: Send + Sync { fn get(&self, key: &[u8]) -> Option<&[u8]>; @@ -122,38 +133,57 @@ pub enum ExitCode { pub struct VmExecutorResult { pub per_segment: Vec>, /// When VM is running on persistent mode, public values are stored in a special memory space. - pub final_memory: Option>>, + pub final_memory: Option, } -pub struct VmExecutorNextSegmentState { - pub memory: MemoryImage, - pub input: Streams, +pub struct VmState +where + F: PrimeField32, +{ + pub instret: u64, pub pc: u32, + pub memory: MemoryImage, + pub input: Streams, + // TODO(ayush): make generic over SeedableRng + pub rng: StdRng, #[cfg(feature = "bench-metrics")] pub metrics: VmMetrics, } -impl VmExecutorNextSegmentState { - pub fn new(memory: MemoryImage, input: impl Into>, pc: u32) -> Self { +impl VmState { + pub fn new( + instret: u64, + pc: u32, + memory: MemoryImage, + input: impl Into>, + seed: u64, + ) -> Self { Self { + instret, + pc, memory, input: input.into(), - pc, + rng: StdRng::seed_from_u64(seed), #[cfg(feature = "bench-metrics")] metrics: VmMetrics::default(), } } } -pub struct VmExecutorOneSegmentResult> { - pub segment: ExecutionSegment, - pub next_state: Option>, +pub struct VmExecutorOneSegmentResult +where + F: PrimeField32, + VC: VmConfig, +{ + pub segment: VmSegmentExecutor, + pub next_state: Option>, } impl VmExecutor where F: PrimeField32, VC: VmConfig, + VC::Executor: InsExecutorE1, { /// Create a new VM executor with a given config. /// @@ -182,204 +212,269 @@ where self.config.system().continuation_enabled } - /// Executes the program in segments. + pub fn execute_e1( + &self, + exe: impl Into>, + inputs: impl Into>, + num_insns: Option, + ) -> Result, ExecutionError> { + let interpreter = InterpretedInstance::new(self.config.clone(), exe); + + let ctx = E1Ctx::new(num_insns); + let state = interpreter.execute(ctx, inputs)?; + + Ok(VmState { + instret: state.instret, + pc: state.pc, + memory: state.memory.memory, + input: state.streams, + rng: state.rng, + #[cfg(feature = "bench-metrics")] + metrics: VmMetrics::default(), + }) + } + + pub fn execute_metered( + &self, + exe: impl Into>, + input: impl Into>, + interactions: &[usize], + ) -> Result, ExecutionError> { + let interpreter = InterpretedInstance::new(self.config.clone(), exe); + + let chip_complex = self.config.create_chip_complex().unwrap(); + let segmentation_strategy = &self.config.system().segmentation_strategy; + + let ctx: MeteredCtx = + MeteredCtx::new(&chip_complex, interactions.to_vec()) + // TODO(ayush): get rid of segmentation_strategy altogether + .with_max_trace_height(segmentation_strategy.max_trace_height() as u32) + .with_max_cells(segmentation_strategy.max_cells()); + + let state = interpreter.execute_e2(ctx, input)?; + check_termination(state.exit_code)?; + + Ok(state.ctx.into_segments()) + } + + /// Base execution function that operates from a given state /// After each segment is executed, call the provided closure on the execution result. /// Returns the results from each closure, one per segment. /// /// The closure takes `f(segment_idx, segment) -> R`. - pub fn execute_and_then( + pub fn execute_and_then_from_state( &self, - exe: impl Into>, - input: impl Into>, - mut f: impl FnMut(usize, ExecutionSegment) -> Result, + exe: VmExe, + mut state: VmState, + segments: &[Segment], + mut f: impl FnMut(usize, VmSegmentExecutor) -> Result, map_err: impl Fn(ExecutionError) -> E, ) -> Result, E> { - let mem_config = self.config.system().memory_config; - let exe = exe.into(); - let mut segment_results = vec![]; - let memory = AddressMap::from_iter( - mem_config.as_offset, - 1 << mem_config.as_height, - 1 << mem_config.pointer_max_bits, - exe.init_memory.clone(), - ); - let pc = exe.pc_start; - let mut state = VmExecutorNextSegmentState::new(memory, input, pc); - - #[cfg(feature = "bench-metrics")] - { - state.metrics.fn_bounds = exe.fn_bounds.clone(); + // assert that segments are valid + assert_eq!(segments.first().unwrap().instret_start, state.instret); + for (prev, current) in segments.iter().zip(segments.iter().skip(1)) { + assert_eq!(current.instret_start, prev.instret_start + prev.num_insns); } - let mut segment_idx = 0; - - loop { + let mut results = Vec::new(); + for ( + segment_idx, + Segment { + num_insns, + trace_heights, + .. + }, + ) in segments.iter().enumerate() + { let _span = info_span!("execute_segment", segment = segment_idx).entered(); - let one_segment_result = self - .execute_until_segment(exe.clone(), state) - .map_err(&map_err)?; - segment_results.push(f(segment_idx, one_segment_result.segment)?); - if one_segment_result.next_state.is_none() { - break; + let chip_complex = create_and_initialize_chip_complex( + &self.config, + exe.program.clone(), + Some(state.memory), + Some(trace_heights), + ) + .unwrap(); + + let mut segment = VmSegmentExecutor::<_, VC, _>::new( + chip_complex, + self.trace_height_constraints.clone(), + exe.fn_bounds.clone(), + TracegenExecutionControl, + ); + + #[cfg(feature = "bench-metrics")] + { + segment.metrics = state.metrics; } - state = one_segment_result.next_state.unwrap(); - segment_idx += 1; + + let instret_end = state.instret + num_insns; + let ctx = TracegenCtx::new(Some(instret_end)); + let mut exec_state = + VmSegmentState::new(state.instret, state.pc, None, state.input, state.rng, ctx); + execute_spanned!("execute_e3", segment, &mut exec_state).map_err(&map_err)?; + + assert_eq!( + exec_state.pc, + segment.chip_complex.connector_chip().boundary_states[1] + .unwrap() + .pc + ); + + state = VmState { + instret: exec_state.instret, + pc: exec_state.pc, + memory: segment + .chip_complex + .base + .memory_controller + .memory_image() + .clone(), + input: exec_state.streams, + rng: exec_state.rng, + #[cfg(feature = "bench-metrics")] + metrics: segment.metrics.partial_take(), + }; + + results.push(f(segment_idx, segment)?); } - tracing::debug!("Number of continuation segments: {}", segment_results.len()); + tracing::debug!("Number of continuation segments: {}", results.len()); #[cfg(feature = "bench-metrics")] - metrics::counter!("num_segments").absolute(segment_results.len() as u64); + metrics::counter!("num_segments").absolute(results.len() as u64); - Ok(segment_results) + Ok(results) } - pub fn execute_segments( + pub fn execute_and_then( &self, exe: impl Into>, input: impl Into>, - ) -> Result>, ExecutionError> { - self.execute_and_then(exe, input, |_, seg| Ok(seg), |err| err) + segments: &[Segment], + f: impl FnMut(usize, VmSegmentExecutor) -> Result, + map_err: impl Fn(ExecutionError) -> E, + ) -> Result, E> { + let exe = exe.into(); + let state = create_initial_state(&self.config.system().memory_config, &exe, input, 0); + self.execute_and_then_from_state(exe, state, segments, f, map_err) } - /// Executes a program until a segmentation happens. - /// Returns the last segment and the vm state for next segment. - /// This is so that the tracegen and proving of this segment can be immediately started (on a - /// separate machine). - pub fn execute_until_segment( + pub fn execute_from_state( &self, - exe: impl Into>, - from_state: VmExecutorNextSegmentState, - ) -> Result, ExecutionError> { - let exe = exe.into(); - let mut segment = ExecutionSegment::new( - &self.config, - exe.program.clone(), - from_state.input, - Some(from_state.memory), - self.trace_height_constraints.clone(), - exe.fn_bounds.clone(), + exe: VmExe, + state: VmState, + segments: &[Segment], + ) -> Result, ExecutionError> { + let executors = + self.execute_and_then_from_state(exe, state, segments, |_, seg| Ok(seg), |err| err)?; + let last = executors + .last() + .expect("at least one segment must be executed"); + let final_memory = Some( + last.chip_complex + .base + .memory_controller + .memory_image() + .clone(), ); - #[cfg(feature = "bench-metrics")] - { - segment.metrics = from_state.metrics; - } - if let Some(overridden_heights) = self.overridden_heights.as_ref() { - segment.set_override_trace_heights(overridden_heights.clone()); - } - let state = metrics_span("execute_time_ms", || segment.execute_from_pc(from_state.pc))?; - - if state.is_terminated { - return Ok(VmExecutorOneSegmentResult { - segment, - next_state: None, - }); + let end_state = + last.chip_complex.connector_chip().boundary_states[1].expect("end state must be set"); + if end_state.is_terminate != 1 { + return Err(ExecutionError::DidNotTerminate); } - - assert!( - self.continuation_enabled(), - "multiple segments require to enable continuations" - ); - assert_eq!( - state.pc, - segment.chip_complex.connector_chip().boundary_states[1] - .unwrap() - .pc - ); - let final_memory = mem::take(&mut segment.final_memory) - .expect("final memory should be set in continuations segment"); - let streams = segment.chip_complex.take_streams(); - #[cfg(feature = "bench-metrics")] - let metrics = segment.metrics.partial_take(); - Ok(VmExecutorOneSegmentResult { - segment, - next_state: Some(VmExecutorNextSegmentState { - memory: final_memory, - input: streams, - pc: state.pc, - #[cfg(feature = "bench-metrics")] - metrics, - }), - }) + check_exit_code(end_state.exit_code)?; + Ok(final_memory) } pub fn execute( &self, exe: impl Into>, input: impl Into>, - ) -> Result>, ExecutionError> { - let mut last = None; - self.execute_and_then( - exe, - input, - |_, seg| { - last = Some(seg); - Ok(()) - }, - |err| err, - )?; - let last = last.expect("at least one segment must be executed"); - let final_memory = last.final_memory; - let end_state = - last.chip_complex.connector_chip().boundary_states[1].expect("end state must be set"); - if end_state.is_terminate != 1 { - return Err(ExecutionError::DidNotTerminate); - } - if end_state.exit_code != ExitCode::Success as u32 { - return Err(ExecutionError::FailedWithExitCode(end_state.exit_code)); - } - Ok(final_memory) + segments: &[Segment], + ) -> Result, ExecutionError> { + let exe = exe.into(); + let state = create_initial_state(&self.config.system().memory_config, &exe, input, 0); + self.execute_from_state(exe, state, segments) } - pub fn execute_and_generate( + // TODO(ayush): this is required in dummy keygen because it expects heights + // in VmComplexTraceHeights format. should be removed later + pub fn execute_segments( &self, exe: impl Into>, input: impl Into>, + segments: &[Segment], + ) -> Result>, ExecutionError> { + self.execute_and_then(exe, input, segments, |_, seg| Ok(seg), |err| err) + } + + pub fn execute_from_state_and_generate( + &self, + exe: VmExe, + state: VmState, + segments: &[Segment], ) -> Result, GenerationError> where + SC: StarkGenericConfig, Domain: PolynomialSpace, VC::Executor: Chip, VC::Periphery: Chip, { - self.execute_and_generate_impl(exe.into(), None, input) + let mut final_memory = None; + let per_segment = self.execute_and_then_from_state( + exe, + state, + segments, + |seg_idx, seg| { + final_memory = Some(seg.chip_complex.memory_controller().memory_image().clone()); + tracing::info_span!("trace_gen", segment = seg_idx) + .in_scope(|| seg.generate_proof_input(None)) + }, + GenerationError::Execution, + )?; + + Ok(VmExecutorResult { + per_segment, + final_memory, + }) } - pub fn execute_and_generate_with_cached_program( + pub fn execute_and_generate( &self, - committed_exe: Arc>, + exe: impl Into>, input: impl Into>, + segments: &[Segment], ) -> Result, GenerationError> where + SC: StarkGenericConfig, Domain: PolynomialSpace, VC::Executor: Chip, VC::Periphery: Chip, { - self.execute_and_generate_impl( - committed_exe.exe.clone(), - Some(committed_exe.committed_program.clone()), - input, - ) + let exe = exe.into(); + let state = create_initial_state(&self.config.system().memory_config, &exe, input, 0); + self.execute_from_state_and_generate(exe, state, segments) } - fn execute_and_generate_impl( + pub fn execute_and_generate_with_cached_program( &self, - exe: VmExe, - committed_program: Option>, + committed_exe: Arc>, input: impl Into>, + segments: &[Segment], ) -> Result, GenerationError> where Domain: PolynomialSpace, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1, VC::Periphery: Chip, { let mut final_memory = None; let per_segment = self.execute_and_then( - exe, + committed_exe.exe.clone(), input, - |seg_idx, mut seg| { - // Note: this will only be Some on the last segment; otherwise it is - // already moved into next segment state - final_memory = mem::take(&mut seg.final_memory); - tracing::info_span!("trace_gen", segment = seg_idx) - .in_scope(|| seg.generate_proof_input(committed_program.clone())) + segments, + |seg_idx, seg| { + final_memory = Some(seg.chip_complex.memory_controller().memory_image().clone()); + tracing::info_span!("trace_gen", segment = seg_idx).in_scope(|| { + seg.generate_proof_input(Some(committed_exe.committed_program.clone())) + }) }, GenerationError::Execution, )?; @@ -417,6 +512,7 @@ impl SingleSegmentVmExecutor where F: PrimeField32, VC: VmConfig, + VC::Executor: InsExecutorE1, { pub fn new(config: VC) -> Self { Self::new_with_overridden_trace_heights(config, None) @@ -446,29 +542,65 @@ where self.trace_height_constraints = constraints; } - /// Executes a program, compute the trace heights, and returns the public values. - pub fn execute_and_compute_heights( + pub fn execute_metered( &self, - exe: impl Into>, + exe: VmExe, input: impl Into>, - ) -> Result, ExecutionError> { - let segment = { - let mut segment = self.execute_impl(exe.into(), input.into())?; - segment.chip_complex.finalize_memory(); - segment - }; - let air_heights = segment.chip_complex.current_trace_heights(); - let vm_heights = segment.chip_complex.get_internal_trace_heights(); - let public_values = if let Some(pv_chip) = segment.chip_complex.public_values_chip() { - pv_chip.core.get_custom_public_values() - } else { - vec![] - }; - Ok(SingleSegmentVmExecutionResult { - public_values, - air_heights, - vm_heights, - }) + interactions: &[usize], + ) -> Result, ExecutionError> { + let interpreter = InterpretedInstance::new(self.config.clone(), exe); + + let chip_complex = self.config.create_chip_complex().unwrap(); + + let ctx: MeteredCtx = + MeteredCtx::new(&chip_complex, interactions.to_vec()) + .with_segment_check_insns(u64::MAX); + + let state = interpreter.execute_e2(ctx, input)?; + check_termination(state.exit_code)?; + + // Check segment count + let segments = state.ctx.into_segments(); + assert_eq!( + segments.len(), + 1, + "Expected exactly 1 segment, but got {}", + segments.len() + ); + let segment = segments.into_iter().next().unwrap(); + Ok(segment.trace_heights) + } + + fn execute_impl( + &self, + exe: VmExe, + input: impl Into>, + trace_heights: Option<&[u32]>, + ) -> Result, ExecutionError> { + let rng = StdRng::seed_from_u64(0); + let chip_complex = create_and_initialize_chip_complex( + &self.config, + exe.program.clone(), + None, + trace_heights, + ) + .unwrap(); + + let mut segment = VmSegmentExecutor::new( + chip_complex, + self.trace_height_constraints.clone(), + exe.fn_bounds.clone(), + TracegenExecutionControl, + ); + + if let Some(overridden_heights) = self.overridden_heights.as_ref() { + segment.set_override_trace_heights(overridden_heights.clone()); + } + + let ctx = TracegenCtx::default(); + let mut exec_state = VmSegmentState::new(0, exe.pc_start, None, input.into(), rng, ctx); + execute_spanned!("execute_e3", segment, &mut exec_state)?; + Ok(segment) } /// Executes a program and returns its proof input. @@ -476,38 +608,46 @@ where &self, committed_exe: Arc>, input: impl Into>, + max_trace_heights: &[u32], ) -> Result, GenerationError> where Domain: PolynomialSpace, VC::Executor: Chip, VC::Periphery: Chip, { - let segment = self.execute_impl(committed_exe.exe.clone(), input)?; + let segment = + self.execute_impl(committed_exe.exe.clone(), input, Some(max_trace_heights))?; let proof_input = tracing::info_span!("trace_gen").in_scope(|| { segment.generate_proof_input(Some(committed_exe.committed_program.clone())) })?; Ok(proof_input) } - fn execute_impl( + /// Executes a program, compute the trace heights, and returns the public values. + pub fn execute_and_compute_heights( &self, - exe: VmExe, + exe: impl Into>, input: impl Into>, - ) -> Result, ExecutionError> { - let pc_start = exe.pc_start; - let mut segment = ExecutionSegment::new( - &self.config, - exe.program.clone(), - input.into(), - None, - self.trace_height_constraints.clone(), - exe.fn_bounds.clone(), - ); - if let Some(overridden_heights) = self.overridden_heights.as_ref() { - segment.set_override_trace_heights(overridden_heights.clone()); - } - metrics_span("execute_time_ms", || segment.execute_from_pc(pc_start))?; - Ok(segment) + max_trace_heights: &[u32], + ) -> Result, ExecutionError> { + let executor = { + let mut executor = + self.execute_impl(exe.into(), input.into(), Some(max_trace_heights))?; + executor.chip_complex.finalize_memory(); + executor + }; + let air_heights = executor.chip_complex.current_trace_heights(); + let vm_heights = executor.chip_complex.get_internal_trace_heights(); + let public_values = if let Some(pv_chip) = executor.chip_complex.public_values_chip() { + pv_chip.step.get_custom_public_values() + } else { + vec![] + }; + Ok(SingleSegmentVmExecutionResult { + public_values, + air_heights, + vm_heights, + }) } } @@ -559,7 +699,7 @@ where E: StarkEngine, Domain: PolynomialSpace, VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1, VC::Periphery: Chip, { pub fn new(engine: E, config: VC) -> Self { @@ -610,32 +750,31 @@ where Arc::new(VmCommittedExe::commit(exe, self.engine.config().pcs())) } - pub fn execute( + pub fn execute_metered( &self, exe: impl Into>, input: impl Into>, - ) -> Result>, ExecutionError> { - self.executor.execute(exe, input) + interactions: &[usize], + ) -> Result, ExecutionError> { + self.executor.execute_metered(exe, input, interactions) } - pub fn execute_and_generate( + pub fn execute( &self, exe: impl Into>, input: impl Into>, - ) -> Result, GenerationError> { - self.executor.execute_and_generate(exe, input) + segments: &[Segment], + ) -> Result, ExecutionError> { + self.executor.execute(exe, input, segments) } - pub fn execute_and_generate_with_cached_program( + pub fn execute_and_generate( &self, - committed_exe: Arc>, + exe: impl Into>, input: impl Into>, - ) -> Result, GenerationError> - where - Domain: PolynomialSpace, - { - self.executor - .execute_and_generate_with_cached_program(committed_exe, input) + segments: &[Segment], + ) -> Result, GenerationError> { + self.executor.execute_and_generate(exe, input, segments) } pub fn prove_single( @@ -886,3 +1025,128 @@ where } } } + +pub fn create_memory_image( + memory_config: &MemoryConfig, + init_memory: SparseMemoryImage, +) -> MemoryImage { + AddressMap::from_sparse(memory_config.addr_space_sizes.clone(), init_memory) +} + +pub fn create_initial_state( + memory_config: &MemoryConfig, + exe: &VmExe, + input: impl Into>, + seed: u64, +) -> VmState +where + F: PrimeField32, +{ + let memory = create_memory_image(memory_config, exe.init_memory.clone()); + #[cfg(feature = "bench-metrics")] + let mut state = VmState::new(0, exe.pc_start, memory, input, seed); + #[cfg(not(feature = "bench-metrics"))] + let state = VmState::new(0, exe.pc_start, memory, input, seed); + #[cfg(feature = "bench-metrics")] + { + state.metrics.fn_bounds = exe.fn_bounds.clone(); + } + state +} + +/// Create and initialize a chip complex with program, streams, optional memory, and optional trace +/// heights +pub fn create_and_initialize_chip_complex( + config: &VC, + program: Program, + initial_memory: Option, + max_trace_heights: Option<&[u32]>, +) -> Result, VmInventoryError> +where + F: PrimeField32, + VC: VmConfig, + VC::Executor: InsExecutorE1, +{ + let mut chip_complex = config.create_chip_complex()?; + + // Strip debug info if profiling is disabled + let program = if !config.system().profiling { + program.strip_debug_infos() + } else { + program + }; + + chip_complex.set_program(program); + + if let Some(initial_memory) = initial_memory { + chip_complex.set_initial_memory(initial_memory); + } + + if let Some(max_trace_heights) = max_trace_heights { + let executor_chip_offset = if chip_complex.config().has_public_values_chip() { + PUBLIC_VALUES_AIR_ID + 1 + chip_complex.memory_controller().num_airs() + } else { + PUBLIC_VALUES_AIR_ID + chip_complex.memory_controller().num_airs() + }; + + // Calculate adapter offset the same way as in MeteredCtx + // TODO: extract + reuse this logic instead of maintaining this copy-paste + let boundary_idx = if chip_complex.config().has_public_values_chip() { + PUBLIC_VALUES_AIR_ID + 1 + } else { + PUBLIC_VALUES_AIR_ID + }; + + let adapter_offset = if chip_complex.config().continuation_enabled { + boundary_idx + 2 + } else { + boundary_idx + 1 + }; + + // Set trace heights for memory adapters + let num_access_adapters = chip_complex + .memory_controller() + .memory + .access_adapter_inventory + .num_access_adapters(); + chip_complex.set_adapter_heights( + &max_trace_heights[adapter_offset..adapter_offset + num_access_adapters], + ); + + for (i, chip_id) in chip_complex + .inventory + .insertion_order + .iter() + .rev() + .enumerate() + { + if let ChipId::Executor(exec_id) = chip_id { + if let Some(height_index) = executor_chip_offset.checked_add(i) { + if let Some(&height) = max_trace_heights.get(height_index) { + if let Some(executor) = chip_complex.inventory.executors.get_mut(*exec_id) { + // TODO(ayush): remove conversion + executor.set_trace_height(height.next_power_of_two() as usize); + } + } + } + } + } + } + + Ok(chip_complex) +} + +fn check_exit_code(exit_code: u32) -> Result<(), ExecutionError> { + if exit_code != ExitCode::Success as u32 { + return Err(ExecutionError::FailedWithExitCode(exit_code)); + } + Ok(()) +} + +fn check_termination(exit_code: Result, ExecutionError>) -> Result<(), ExecutionError> { + let exit_code = exit_code?; + match exit_code { + Some(code) => check_exit_code(code), + None => Err(ExecutionError::DidNotTerminate), + } +} diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs index 916e8251ac..c36e04eac2 100644 --- a/crates/vm/src/metrics/mod.rs +++ b/crates/vm/src/metrics/mod.rs @@ -2,13 +2,7 @@ use std::{collections::BTreeMap, mem}; use cycle_tracker::CycleTracker; use metrics::counter; -use openvm_instructions::{ - exe::{FnBound, FnBounds}, - VmOpcode, -}; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::arch::{ExecutionSegment, InstructionExecutor, VmConfig}; +use openvm_instructions::exe::{FnBound, FnBounds}; pub mod cycle_tracker; @@ -30,39 +24,8 @@ pub struct VmMetrics { pub(crate) current_trace_cells: Vec, } -impl ExecutionSegment -where - F: PrimeField32, - VC: VmConfig, -{ - /// Update metrics that increment per instruction - #[allow(unused_variables)] - pub fn update_instruction_metrics( - &mut self, - pc: u32, - opcode: VmOpcode, - dsl_instr: Option, - ) { - self.metrics.cycle_count += 1; - - if self.system_config().profiling { - let executor = self.chip_complex.inventory.get_executor(opcode).unwrap(); - let opcode_name = executor.get_opcode_name(opcode.as_usize()); - self.metrics.update_trace_cells( - &self.air_names, - self.current_trace_cells(), - opcode_name, - dsl_instr, - ); - - #[cfg(feature = "function-span")] - self.metrics.update_current_fn(pc); - } - } -} - impl VmMetrics { - fn update_trace_cells( + pub fn update_trace_cells( &mut self, air_names: &[String], now_trace_cells: Vec, @@ -105,7 +68,7 @@ impl VmMetrics { } #[cfg(feature = "function-span")] - fn update_current_fn(&mut self, pc: u32) { + pub(super) fn update_current_fn(&mut self, pc: u32) { if self.fn_bounds.is_empty() { return; } diff --git a/crates/vm/src/system/connector/tests.rs b/crates/vm/src/system/connector/tests.rs index f3ded1812c..bc6f3fdb71 100644 --- a/crates/vm/src/system/connector/tests.rs +++ b/crates/vm/src/system/connector/tests.rs @@ -71,6 +71,7 @@ fn test_impl( let engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(3)); let vm = VirtualMachine::new(engine, vm_config.clone()); let pk = vm.keygen(); + let vk = pk.get_vk(); { let instructions = vec![Instruction::from_isize( @@ -88,8 +89,9 @@ fn test_impl( vm.engine.config.pcs(), )); let single_vm = SingleSegmentVmExecutor::new(vm_config); + let max_trace_heights = vec![0; vk.total_widths().len()]; let mut proof_input = single_vm - .execute_and_generate(committed_exe, vec![]) + .execute_and_generate(committed_exe, vec![], &max_trace_heights) .unwrap(); let connector_air_input = proof_input .per_air diff --git a/crates/vm/src/system/memory/adapter/mod.rs b/crates/vm/src/system/memory/adapter/mod.rs index 64e79a920b..82690d33a6 100644 --- a/crates/vm/src/system/memory/adapter/mod.rs +++ b/crates/vm/src/system/memory/adapter/mod.rs @@ -1,4 +1,9 @@ -use std::{borrow::BorrowMut, cmp::max, sync::Arc}; +use std::{ + borrow::{Borrow, BorrowMut}, + marker::PhantomData, + ptr::copy_nonoverlapping, + sync::Arc, +}; pub use air::*; pub use columns::*; @@ -8,31 +13,43 @@ use openvm_circuit_primitives::{ var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::NATIVE_AS; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig, Val}, p3_air::BaseAir, p3_commit::PolynomialSpace, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, p3_util::log2_strict_usize, prover::types::AirProofInput, AirRef, Chip, ChipUsageGetter, }; -use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress}; +use crate::{ + arch::{CustomBorrow, DenseRecordArena, RecordArena, SizedRecord}, + system::memory::{ + adapter::records::{ + size_by_layout, AccessLayout, AccessRecordHeader, AccessRecordMut, + MERGE_AND_NOT_SPLIT_FLAG, + }, + offline_checker::MemoryBus, + MemoryAddress, + }, +}; mod air; mod columns; +pub mod records; #[cfg(test)] mod tests; pub struct AccessAdapterInventory { chips: Vec>, + pub arena: DenseRecordArena, air_names: Vec, } -impl AccessAdapterInventory { +impl AccessAdapterInventory { pub fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, @@ -55,41 +72,55 @@ impl AccessAdapterInventory { .flatten() .collect(); let air_names = (0..chips.len()).map(|i| air_name(1 << (i + 1))).collect(); - Self { chips, air_names } + Self { + chips, + arena: DenseRecordArena::with_capacity(0), + air_names, + } } + pub fn num_access_adapters(&self) -> usize { self.chips.len() } + pub fn set_override_trace_heights(&mut self, overridden_heights: Vec) { - assert_eq!(overridden_heights.len(), self.chips.len()); + self.set_arena_from_trace_heights( + &overridden_heights + .iter() + .map(|&h| h as u32) + .collect::>(), + ); for (chip, oh) in self.chips.iter_mut().zip(overridden_heights) { - chip.set_override_trace_heights(oh); + chip.set_override_trace_height(oh); } } - pub fn add_record(&mut self, record: AccessAdapterRecord) { - let n = record.data.len(); - let idx = log2_strict_usize(n) - 1; - let chip = &mut self.chips[idx]; - debug_assert!(chip.n() == n); - chip.add_record(record); - } - pub fn extend_records(&mut self, records: Vec>) { - for record in records { - self.add_record(record); - } - } - - #[cfg(test)] - pub fn records_for_n(&self, n: usize) -> &[AccessAdapterRecord] { - let idx = log2_strict_usize(n) - 1; - let chip = &self.chips[idx]; - chip.records() - } - - #[cfg(test)] - pub fn total_records(&self) -> usize { - self.chips.iter().map(|chip| chip.records().len()).sum() + pub fn set_arena_from_trace_heights(&mut self, trace_heights: &[u32]) { + assert_eq!(trace_heights.len(), self.chips.len()); + // At the very worst, each row in `Adapter` + // corresponds to a unique record of `block_size` being `2 * N`, + // and its `lowest_block_size` is at least 1 and `type_size` is at most 4. + let size_bound = trace_heights + .iter() + .enumerate() + .map(|(i, &h)| { + size_by_layout(&AccessLayout { + block_size: 1 << (i + 1), + lowest_block_size: 1, + type_size: 4, + }) * h as usize + }) + .sum::(); + assert!(self + .chips + .iter() + .all(|chip| chip.overridden_trace_height().is_none())); + tracing::debug!( + "Allocating {} bytes for memory adapters arena from heights {:?}", + size_bound, + trace_heights + ); + self.arena.set_capacity(size_bound); } pub fn get_heights(&self) -> Vec { @@ -100,7 +131,10 @@ impl AccessAdapterInventory { } #[allow(dead_code)] pub fn get_widths(&self) -> Vec { - self.chips.iter().map(|chip| chip.trace_width()).collect() + self.chips + .iter() + .map(|chip: &GenericAccessAdapterChip| chip.trace_width()) + .collect() } pub fn get_cells(&self) -> Vec { self.chips @@ -118,14 +152,143 @@ impl AccessAdapterInventory { pub fn air_names(&self) -> Vec { self.air_names.clone() } - pub fn generate_air_proof_inputs(self) -> Vec> + pub fn compute_trace_heights(&mut self) { + let num_adapters = self.chips.len(); + let mut heights = vec![0; num_adapters]; + + self.compute_heights_from_arena(&mut heights); + self.apply_overridden_heights(&mut heights); + for (chip, height) in self.chips.iter_mut().zip(heights) { + chip.set_computed_trace_height(height); + } + } + + fn compute_heights_from_arena(&mut self, heights: &mut [usize]) { + let bytes = self.arena.allocated_mut(); + tracing::debug!( + "Computing heights from memory adapters arena: used {} bytes", + bytes.len() + ); + let mut ptr = 0; + while ptr < bytes.len() { + let header: &AccessRecordHeader = bytes[ptr..].borrow(); + let layout: AccessLayout = unsafe { bytes[ptr..].extract_layout() }; + ptr += as SizedRecord>::size(&layout); + + let log_max_block_size = log2_strict_usize(header.block_size as usize); + for (i, h) in heights + .iter_mut() + .enumerate() + .take(log_max_block_size) + .skip(log2_strict_usize(header.lowest_block_size as usize)) + { + *h += 1 << (log_max_block_size - i - 1); + } + } + tracing::debug!("Computed heights from memory adapters arena: {:?}", heights); + } + + fn apply_overridden_heights(&mut self, heights: &mut [usize]) { + for (i, h) in heights.iter_mut().enumerate() { + if let Some(oh) = self.chips[i].overridden_trace_height() { + assert!( + oh >= *h, + "Overridden height {oh} is less than the required height {}", + *h + ); + *h = oh; + } + *h = next_power_of_two_or_zero(*h); + } + } + + pub fn generate_air_proof_inputs(mut self) -> Vec> where F: PrimeField32, Domain: PolynomialSpace, { - self.chips + let num_adapters = self.chips.len(); + + let mut heights = vec![0; num_adapters]; + self.compute_heights_from_arena(&mut heights); + self.apply_overridden_heights(&mut heights); + + let widths = self + .chips + .iter() + .map(|chip| chip.trace_width()) + .collect::>(); + let mut traces = widths + .iter() + .zip(heights.iter()) + .map(|(&width, &height)| RowMajorMatrix::new(vec![F::ZERO; width * height], width)) + .collect::>(); + + let mut trace_ptrs = vec![0; num_adapters]; + + let bytes = self.arena.allocated_mut(); + let mut ptr = 0; + while ptr < bytes.len() { + let layout: AccessLayout = unsafe { bytes[ptr..].extract_layout() }; + let record: AccessRecordMut<'_> = bytes[ptr..].custom_borrow(layout.clone()); + ptr += as SizedRecord>::size(&layout); + + let log_min_block_size = log2_strict_usize(record.header.lowest_block_size as usize); + let log_max_block_size = log2_strict_usize(record.header.block_size as usize); + + if record.header.timestamp_and_mask & MERGE_AND_NOT_SPLIT_FLAG != 0 { + for i in log_min_block_size..log_max_block_size { + let data_len = layout.type_size << i; + let ts_len = 1 << (i - log_min_block_size); + for j in 0..record.data.len() / (2 * data_len) { + let row_slice = + &mut traces[i].values[trace_ptrs[i]..trace_ptrs[i] + widths[i]]; + trace_ptrs[i] += widths[i]; + self.chips[i].fill_trace_row( + row_slice, + false, + MemoryAddress::new( + record.header.address_space, + record.header.pointer + (j << (i + 1)) as u32, + ), + &record.data[j * 2 * data_len..(j + 1) * 2 * data_len], + *record.timestamps[2 * j * ts_len..(2 * j + 1) * ts_len] + .iter() + .max() + .unwrap(), + *record.timestamps[(2 * j + 1) * ts_len..(2 * j + 2) * ts_len] + .iter() + .max() + .unwrap(), + ); + } + } + } else { + let timestamp = record.header.timestamp_and_mask; + for i in log_min_block_size..log_max_block_size { + let data_len = layout.type_size << i; + for j in 0..record.data.len() / (2 * data_len) { + let row_slice = + &mut traces[i].values[trace_ptrs[i]..trace_ptrs[i] + widths[i]]; + trace_ptrs[i] += widths[i]; + self.chips[i].fill_trace_row( + row_slice, + true, + MemoryAddress::new( + record.header.address_space, + record.header.pointer + (j << (i + 1)) as u32, + ), + &record.data[j * 2 * data_len..(j + 1) * 2 * data_len], + timestamp, + timestamp, + ); + } + } + } + } + traces .into_iter() - .map(|chip| chip.generate_air_proof_input()) + .map(|trace| AirProofInput::simple_no_pis(trace)) .collect() } @@ -134,7 +297,10 @@ impl AccessAdapterInventory { memory_bus: MemoryBus, clk_max_bits: usize, max_access_adapter_n: usize, - ) -> Option> { + ) -> Option> + where + F: Clone + Send + Sync, + { if N <= max_access_adapter_n { Some(GenericAccessAdapterChip::new::( range_checker, @@ -145,33 +311,27 @@ impl AccessAdapterInventory { None } } -} -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum AccessAdapterRecordKind { - Split, - Merge { - left_timestamp: u32, - right_timestamp: u32, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct AccessAdapterRecord { - pub timestamp: u32, - pub address_space: T, - pub start_index: T, - pub data: Vec, - pub kind: AccessAdapterRecordKind, + pub(crate) fn alloc_record(&mut self, layout: AccessLayout) -> AccessRecordMut { + self.arena.alloc(layout) + } } #[enum_dispatch] -pub trait GenericAccessAdapterChipTrait { - fn set_override_trace_heights(&mut self, overridden_height: usize); - fn add_record(&mut self, record: AccessAdapterRecord); - fn n(&self) -> usize; - fn generate_trace(self) -> RowMajorMatrix - where +pub(crate) trait GenericAccessAdapterChipTrait { + fn set_override_trace_height(&mut self, overridden_height: usize); + fn overridden_trace_height(&self) -> Option; + fn set_computed_trace_height(&mut self, height: usize); + + fn fill_trace_row( + &self, + row: &mut [F], + is_split: bool, + address: MemoryAddress, + values: &[u8], + left_timestamp: u32, + right_timestamp: u32, + ) where F: PrimeField32; } @@ -186,7 +346,7 @@ enum GenericAccessAdapterChip { N32(AccessAdapterChip), } -impl GenericAccessAdapterChip { +impl GenericAccessAdapterChip { fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, @@ -204,25 +364,17 @@ impl GenericAccessAdapterChip { _ => panic!("Only supports N in (2, 4, 8, 16, 32)"), } } - - #[cfg(test)] - fn records(&self) -> &[AccessAdapterRecord] { - match &self { - GenericAccessAdapterChip::N2(chip) => &chip.records, - GenericAccessAdapterChip::N4(chip) => &chip.records, - GenericAccessAdapterChip::N8(chip) => &chip.records, - GenericAccessAdapterChip::N16(chip) => &chip.records, - GenericAccessAdapterChip::N32(chip) => &chip.records, - } - } } -pub struct AccessAdapterChip { + +pub(crate) struct AccessAdapterChip { air: AccessAdapterAir, range_checker: SharedVariableRangeCheckerChip, - pub records: Vec>, overridden_height: Option, + computed_trace_height: Option, + _marker: PhantomData, } -impl AccessAdapterChip { + +impl AccessAdapterChip { pub fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, @@ -232,67 +384,63 @@ impl AccessAdapterChip { Self { air: AccessAdapterAir:: { memory_bus, lt_air }, range_checker, - records: vec![], overridden_height: None, + computed_trace_height: None, + _marker: PhantomData, } } } impl GenericAccessAdapterChipTrait for AccessAdapterChip { - fn set_override_trace_heights(&mut self, overridden_height: usize) { + fn set_override_trace_height(&mut self, overridden_height: usize) { self.overridden_height = Some(overridden_height); } - fn add_record(&mut self, record: AccessAdapterRecord) { - self.records.push(record); + + fn overridden_trace_height(&self) -> Option { + self.overridden_height } - fn n(&self) -> usize { - N + + fn set_computed_trace_height(&mut self, height: usize) { + self.computed_trace_height = Some(height); } - fn generate_trace(self) -> RowMajorMatrix - where + + fn fill_trace_row( + &self, + row: &mut [F], + is_split: bool, + address: MemoryAddress, + values: &[u8], + left_timestamp: u32, + right_timestamp: u32, + ) where F: PrimeField32, { - let width = BaseAir::::width(&self.air); - let height = if let Some(oh) = self.overridden_height { - assert!( - oh >= self.records.len(), - "Overridden height is less than the required height" - ); - oh + let row: &mut AccessAdapterCols = row.borrow_mut(); + row.is_valid = F::ONE; + row.is_split = F::from_bool(is_split); + row.address = MemoryAddress::new( + F::from_canonical_u32(address.address_space), + F::from_canonical_u32(address.pointer), + ); + // TODO: normal way + if address.address_space < NATIVE_AS { + for (dst, src) in row.values.iter_mut().zip(values.iter()) { + *dst = F::from_canonical_u8(*src); + } } else { - self.records.len() - }; - let height = next_power_of_two_or_zero(height); - let mut values = F::zero_vec(height * width); - - values - .par_chunks_mut(width) - .zip(self.records.into_par_iter()) - .for_each(|(row, record)| { - let row: &mut AccessAdapterCols = row.borrow_mut(); - - row.is_valid = F::ONE; - row.values = record.data.try_into().unwrap(); - row.address = MemoryAddress::new(record.address_space, record.start_index); - - let (left_timestamp, right_timestamp) = match record.kind { - AccessAdapterRecordKind::Split => (record.timestamp, record.timestamp), - AccessAdapterRecordKind::Merge { - left_timestamp, - right_timestamp, - } => (left_timestamp, right_timestamp), - }; - debug_assert_eq!(max(left_timestamp, right_timestamp), record.timestamp); - - row.left_timestamp = F::from_canonical_u32(left_timestamp); - row.right_timestamp = F::from_canonical_u32(right_timestamp); - row.is_split = F::from_bool(record.kind == AccessAdapterRecordKind::Split); - - self.air.lt_air.generate_subrow( - (self.range_checker.as_ref(), left_timestamp, right_timestamp), - (&mut row.lt_aux, &mut row.is_right_larger), + unsafe { + copy_nonoverlapping( + values.as_ptr(), + row.values.as_mut_ptr() as *mut u8, + N * size_of::(), ); - }); - RowMajorMatrix::new(values, width) + } + } + row.left_timestamp = F::from_canonical_u32(left_timestamp); + row.right_timestamp = F::from_canonical_u32(right_timestamp); + self.air.lt_air.generate_subrow( + (self.range_checker.as_ref(), left_timestamp, right_timestamp), + (&mut row.lt_aux, &mut row.is_right_larger), + ); } } @@ -305,8 +453,7 @@ where } fn generate_air_proof_input(self) -> AirProofInput { - let trace = self.generate_trace(); - AirProofInput::simple_no_pis(trace) + unreachable!("AccessAdapterInventory should take care of adapters' trace generation") } } @@ -316,7 +463,7 @@ impl ChipUsageGetter for AccessAdapterChip { } fn current_trace_height(&self) -> usize { - self.records.len() + self.computed_trace_height.unwrap_or(0) } fn trace_width(&self) -> usize { @@ -328,3 +475,14 @@ impl ChipUsageGetter for AccessAdapterChip { fn air_name(n: usize) -> String { format!("AccessAdapter<{}>", n) } + +#[inline(always)] +pub fn get_chip_index(block_size: usize) -> usize { + assert!( + block_size.is_power_of_two() && block_size >= 2, + "Invalid block size {}", + block_size + ); + let index = block_size.trailing_zeros() - 1; + index as usize +} diff --git a/crates/vm/src/system/memory/adapter/records.rs b/crates/vm/src/system/memory/adapter/records.rs new file mode 100644 index 0000000000..f0918aab9b --- /dev/null +++ b/crates/vm/src/system/memory/adapter/records.rs @@ -0,0 +1,120 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::{align_of, size_of}, +}; + +use openvm_circuit_primitives::AlignedBytesBorrow; + +use crate::arch::{CustomBorrow, DenseRecordArena, RecordArena, SizedRecord}; + +#[repr(C)] +#[derive(Debug, Clone, Copy, AlignedBytesBorrow, PartialEq, Eq, PartialOrd, Ord)] +pub struct AccessRecordHeader { + /// Iff we need to merge before, this has the `MERGE_AND_NOT_SPLIT_FLAG` bit set + pub timestamp_and_mask: u32, + pub address_space: u32, + pub pointer: u32, + // TODO: these three are easily mergeable into a single u32 + pub block_size: u32, + pub lowest_block_size: u32, + pub type_size: u32, +} + +#[repr(C)] +#[derive(Debug)] +pub struct AccessRecordMut<'a> { + pub header: &'a mut AccessRecordHeader, + // TODO(AG): optimize with some `Option` serialization stuff + pub timestamps: &'a mut [u32], // len is block_size / lowest_block_size + pub data: &'a mut [u8], // len is block_size * type_size +} + +#[derive(Debug, Clone)] +pub struct AccessLayout { + /// The size of the block in elements. + pub block_size: usize, + /// The size of the minimal block we may split into/merge from (usually 1 or 4) + pub lowest_block_size: usize, + /// The size of the type in bytes (1 for u8, 4 for F). + pub type_size: usize, +} + +impl AccessLayout { + pub(crate) fn from_record_header(header: &AccessRecordHeader) -> Self { + Self { + block_size: header.block_size as usize, + lowest_block_size: header.lowest_block_size as usize, + type_size: header.type_size as usize, + } + } +} + +pub(crate) const MERGE_AND_NOT_SPLIT_FLAG: u32 = 1 << 31; + +pub(crate) fn size_by_layout(layout: &AccessLayout) -> usize { + size_of::() // header struct + + (layout.block_size / layout.lowest_block_size) * size_of::() // timestamps + + (layout.block_size * layout.type_size).next_multiple_of(4) // data +} + +impl SizedRecord for AccessRecordMut<'_> { + fn size(layout: &AccessLayout) -> usize { + size_by_layout(layout) + } + + fn alignment(_: &AccessLayout) -> usize { + align_of::() + } +} + +impl<'a> CustomBorrow<'a, AccessRecordMut<'a>, AccessLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: AccessLayout) -> AccessRecordMut<'a> { + // header: AccessRecordHeader (using trivial borrowing) + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let header = header_buf.borrow_mut(); + + let mut offset = 0; + + // timestamps: [u32] (block_size / cell_size * 4 bytes) + let timestamps = unsafe { + std::slice::from_raw_parts_mut( + rest.as_mut_ptr().add(offset) as *mut u32, + layout.block_size / layout.lowest_block_size, + ) + }; + offset += layout.block_size / layout.lowest_block_size * size_of::(); + + // data: [u8] (block_size * type_size bytes) + let data = unsafe { + std::slice::from_raw_parts_mut( + rest.as_mut_ptr().add(offset), + layout.block_size * layout.type_size, + ) + }; + + AccessRecordMut { + header, + data, + timestamps, + } + } + + unsafe fn extract_layout(&self) -> AccessLayout { + let header: &AccessRecordHeader = self.borrow(); + AccessLayout { + block_size: header.block_size as usize, + lowest_block_size: header.lowest_block_size as usize, + type_size: header.type_size as usize, + } + } +} + +impl<'a> RecordArena<'a, AccessLayout, AccessRecordMut<'a>> for DenseRecordArena { + fn alloc(&'a mut self, layout: AccessLayout) -> AccessRecordMut<'a> { + let bytes = self.alloc_bytes( as SizedRecord>::size( + &layout, + )); + <[u8] as CustomBorrow, AccessLayout>>::custom_borrow(bytes, layout) + } +} diff --git a/crates/vm/src/system/memory/controller/dimensions.rs b/crates/vm/src/system/memory/controller/dimensions.rs index 1082d3adf0..77345c2e82 100644 --- a/crates/vm/src/system/memory/controller/dimensions.rs +++ b/crates/vm/src/system/memory/controller/dimensions.rs @@ -2,23 +2,24 @@ use derive_new::new; use openvm_stark_backend::p3_util::log2_strict_usize; use serde::{Deserialize, Serialize}; -use crate::{arch::MemoryConfig, system::memory::CHUNK}; +use crate::{ + arch::{MemoryConfig, ADDR_SPACE_OFFSET}, + system::memory::CHUNK, +}; -// indicates that there are 2^`as_height` address spaces numbered starting from `as_offset`, +// indicates that there are 2^`addr_space_height` address spaces numbered starting from 1, // and that each address space has 2^`address_height` addresses numbered starting from 0 #[derive(Clone, Copy, Debug, Serialize, Deserialize, new)] pub struct MemoryDimensions { /// Address space height - pub as_height: usize, + pub addr_space_height: usize, /// Pointer height pub address_height: usize, - /// Address space offset - pub as_offset: u32, } impl MemoryDimensions { pub fn overall_height(&self) -> usize { - self.as_height + self.address_height + self.addr_space_height + self.address_height } /// Convert an address label (address space, block id) to its index in the memory merkle tree. /// @@ -27,17 +28,29 @@ impl MemoryDimensions { /// This function is primarily for internal use for accessing the memory merkle tree. /// Users should use a higher-level API when possible. pub fn label_to_index(&self, (addr_space, block_id): (u32, u32)) -> u64 { - debug_assert!(block_id < (1 << self.address_height)); - (((addr_space - self.as_offset) as u64) << self.address_height) + block_id as u64 + debug_assert!( + block_id < (1 << self.address_height), + "block_id={block_id} exceeds address_height={}", + self.address_height + ); + (((addr_space - ADDR_SPACE_OFFSET) as u64) << self.address_height) + block_id as u64 + } + + /// Convert an index in the memory merkle tree to an address label (address space, block id). + /// + /// This function performs the inverse operation of `label_to_index`. + pub fn index_to_label(&self, index: u64) -> (u32, u32) { + let block_id = (index & ((1 << self.address_height) - 1)) as u32; + let addr_space = (index >> self.address_height) as u32 + ADDR_SPACE_OFFSET; + (addr_space, block_id) } } impl MemoryConfig { pub fn memory_dimensions(&self) -> MemoryDimensions { MemoryDimensions { - as_height: self.as_height, + addr_space_height: self.addr_space_height, address_height: self.pointer_max_bits - log2_strict_usize(CHUNK), - as_offset: self.as_offset, } } } diff --git a/crates/vm/src/system/memory/controller/interface.rs b/crates/vm/src/system/memory/controller/interface.rs index b51e960a32..5a06e3cfbc 100644 --- a/crates/vm/src/system/memory/controller/interface.rs +++ b/crates/vm/src/system/memory/controller/interface.rs @@ -13,25 +13,11 @@ pub enum MemoryInterface { Persistent { boundary_chip: PersistentBoundaryChip, merkle_chip: MemoryMerkleChip, - initial_memory: MemoryImage, + initial_memory: MemoryImage, }, } impl MemoryInterface { - pub fn touch_range(&mut self, addr_space: u32, pointer: u32, len: u32) { - match self { - MemoryInterface::Volatile { .. } => {} - MemoryInterface::Persistent { - boundary_chip, - merkle_chip, - .. - } => { - boundary_chip.touch_range(addr_space, pointer, len); - merkle_chip.touch_range(addr_space, pointer, len); - } - } - } - pub fn compression_bus(&self) -> Option { match self { MemoryInterface::Volatile { .. } => None, diff --git a/crates/vm/src/system/memory/controller/mod.rs b/crates/vm/src/system/memory/controller/mod.rs index 680a03ab8e..36a75de8c1 100644 --- a/crates/vm/src/system/memory/controller/mod.rs +++ b/crates/vm/src/system/memory/controller/mod.rs @@ -1,20 +1,15 @@ -use std::{ - array, - collections::BTreeMap, - iter, - marker::PhantomData, - mem, - sync::{Arc, Mutex}, -}; +use std::{array::from_fn, collections::BTreeMap, fmt::Debug, iter, marker::PhantomData}; use getset::{Getters, MutGetters}; use openvm_circuit_primitives::{ assert_less_than::{AssertLtSubAir, LessThanAuxCols}, - is_zero::IsZeroSubAir, utils::next_power_of_two_or_zero, - var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip, + }, TraceSubRowGenerator, }; +use openvm_instructions::NATIVE_AS; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, interaction::PermutationCheckBus, @@ -26,26 +21,19 @@ use openvm_stark_backend::{ AirRef, Chip, ChipUsageGetter, }; use serde::{Deserialize, Serialize}; +use tracing::instrument; use self::interface::MemoryInterface; -use super::{ - paged_vec::{AddressMap, PAGE_SIZE}, - volatile::VolatileBoundaryChip, -}; +use super::{online::INITIAL_TIMESTAMP, volatile::VolatileBoundaryChip, AddressMap, MemoryAddress}; use crate::{ - arch::{hasher::HasherChip, MemoryConfig}, + arch::{hasher::HasherChip, MemoryConfig, ADDR_SPACE_OFFSET}, system::memory::{ - adapter::AccessAdapterInventory, + adapter::records::AccessRecordHeader, dimensions::MemoryDimensions, merkle::{MemoryMerkleChip, SerialReceiver}, - offline::{MemoryRecord, OfflineMemory, INITIAL_TIMESTAMP}, - offline_checker::{ - MemoryBaseAuxCols, MemoryBridge, MemoryBus, MemoryReadAuxCols, - MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, AUX_LEN, - }, - online::{Memory, MemoryLogEntry}, + offline_checker::{MemoryBaseAuxCols, MemoryBridge, MemoryBus, AUX_LEN}, + online::{AccessMetadata, TracingMemory}, persistent::PersistentBoundaryChip, - tree::MemoryNode, }, }; @@ -53,16 +41,13 @@ pub mod dimensions; pub mod interface; pub const CHUNK: usize = 8; + /// The offset of the Merkle AIR in AIRs of MemoryController. pub const MERKLE_AIR_OFFSET: usize = 1; /// The offset of the boundary AIR in AIRs of MemoryController. pub const BOUNDARY_AIR_OFFSET: usize = 0; -#[repr(C)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -pub struct RecordId(pub usize); - -pub type MemoryImage = AddressMap; +pub type MemoryImage = AddressMap; #[repr(C)] #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -71,14 +56,11 @@ pub struct TimestampedValues { pub values: [T; N], } -/// An equipartition of memory, with timestamps and values. +/// A sorted equipartition of memory, with timestamps and values. /// -/// The key is a pair `(address_space, label)`, where `label` is the index of the block in the +/// The "key" is a pair `(address_space, label)`, where `label` is the index of the block in the /// partition. I.e., the starting address of the block is `(address_space, label * N)`. -/// -/// If a key is not present in the map, then the block is uninitialized (and therefore zero). -pub type TimestampedEquipartition = - BTreeMap<(u32, u32), TimestampedValues>; +pub type TimestampedEquipartition = Vec<((u32, u32), TimestampedValues)>; /// An equipartition of memory values. /// @@ -98,29 +80,7 @@ pub struct MemoryController { // Store separately to avoid smart pointer reference each time range_checker_bus: VariableRangeCheckerBus, // addr_space -> Memory data structure - memory: Memory, - /// A reference to the `OfflineMemory`. Will be populated after `finalize()`. - offline_memory: Arc>>, - pub access_adapters: AccessAdapterInventory, - // Filled during finalization. - final_state: Option>, -} - -#[allow(clippy::large_enum_variant)] -#[derive(Debug)] -enum FinalState { - Volatile(VolatileFinalState), - #[allow(dead_code)] - Persistent(PersistentFinalState), -} -#[derive(Debug, Default)] -struct VolatileFinalState { - _marker: PhantomData, -} -#[allow(dead_code)] -#[derive(Debug)] -struct PersistentFinalState { - final_memory: Equipartition, + pub memory: TracingMemory, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -226,15 +186,18 @@ impl MemoryController { range_checker: SharedVariableRangeCheckerChip, ) -> Self { let range_checker_bus = range_checker.bus(); - let initial_memory = AddressMap::from_mem_config(&mem_config); assert!(mem_config.pointer_max_bits <= F::bits() - 2); - assert!(mem_config.as_height < F::bits() - 2); + assert!(mem_config + .addr_space_sizes + .iter() + .all(|&x| x <= (1 << mem_config.pointer_max_bits))); + assert!(mem_config.addr_space_height < F::bits() - 2); let addr_space_max_bits = log2_ceil_usize( - (mem_config.as_offset + 2u32.pow(mem_config.as_height as u32)) as usize, + (ADDR_SPACE_OFFSET + 2u32.pow(mem_config.addr_space_height as u32)) as usize, ); Self { memory_bus, - mem_config, + mem_config: mem_config.clone(), interface_chip: MemoryInterface::Volatile { boundary_chip: VolatileBoundaryChip::new( memory_bus, @@ -243,23 +206,9 @@ impl MemoryController { range_checker.clone(), ), }, - memory: Memory::new(&mem_config), - offline_memory: Arc::new(Mutex::new(OfflineMemory::new( - initial_memory, - 1, - memory_bus, - range_checker.clone(), - mem_config, - ))), - access_adapters: AccessAdapterInventory::new( - range_checker.clone(), - memory_bus, - mem_config.clk_max_bits, - mem_config.max_access_adapter_n, - ), + memory: TracingMemory::new(&mem_config, range_checker.clone(), memory_bus, 1), range_checker, range_checker_bus, - final_state: None, } } @@ -273,11 +222,9 @@ impl MemoryController { merkle_bus: PermutationCheckBus, compression_bus: PermutationCheckBus, ) -> Self { - assert_eq!(mem_config.as_offset, 1); let memory_dims = MemoryDimensions { - as_height: mem_config.as_height, + addr_space_height: mem_config.addr_space_height, address_height: mem_config.pointer_max_bits - log2_strict_usize(CHUNK), - as_offset: 1, }; let range_checker_bus = range_checker.bus(); let interface_chip = MemoryInterface::Persistent { @@ -292,30 +239,17 @@ impl MemoryController { }; Self { memory_bus, - mem_config, + mem_config: mem_config.clone(), interface_chip, - memory: Memory::new(&mem_config), // it is expected that the memory will be set later - offline_memory: Arc::new(Mutex::new(OfflineMemory::new( - AddressMap::from_mem_config(&mem_config), - CHUNK, - memory_bus, - range_checker.clone(), - mem_config, - ))), - access_adapters: AccessAdapterInventory::new( - range_checker.clone(), - memory_bus, - mem_config.clk_max_bits, - mem_config.max_access_adapter_n, - ), + memory: TracingMemory::new(&mem_config, range_checker.clone(), memory_bus, CHUNK), /* it is expected that the memory will be + * set later */ range_checker, range_checker_bus, - final_state: None, } } - pub fn memory_image(&self) -> &MemoryImage { - &self.memory.data + pub fn memory_image(&self) -> &MemoryImage { + &self.memory.data.memory } pub fn set_override_trace_heights(&mut self, overridden_heights: MemoryTraceHeights) { @@ -323,7 +257,8 @@ impl MemoryController { MemoryInterface::Volatile { boundary_chip } => match overridden_heights { MemoryTraceHeights::Volatile(oh) => { boundary_chip.set_overridden_height(oh.boundary); - self.access_adapters + self.memory + .access_adapter_inventory .set_override_trace_heights(oh.access_adapters); } _ => panic!("Expect overridden_heights to be MemoryTraceHeights::Volatile"), @@ -336,7 +271,8 @@ impl MemoryController { MemoryTraceHeights::Persistent(oh) => { boundary_chip.set_overridden_height(oh.boundary); merkle_chip.set_overridden_height(oh.merkle); - self.access_adapters + self.memory + .access_adapter_inventory .set_override_trace_heights(oh.access_adapters); } _ => panic!("Expect overridden_heights to be MemoryTraceHeights::Persistent"), @@ -344,26 +280,29 @@ impl MemoryController { } } - pub fn set_initial_memory(&mut self, memory: MemoryImage) { + // TODO[jpw]: change MemoryImage interface here + pub fn set_initial_memory(&mut self, memory: MemoryImage) { if self.timestamp() > INITIAL_TIMESTAMP + 1 { panic!("Cannot set initial memory after first timestamp"); } - let mut offline_memory = self.offline_memory.lock().unwrap(); - offline_memory.set_initial_memory(memory.clone(), self.mem_config); - - self.memory = Memory::from_image(memory.clone(), self.mem_config.access_capacity); match &mut self.interface_chip { MemoryInterface::Volatile { .. } => { - assert!( - memory.is_empty(), - "Cannot set initial memory for volatile memory" - ); + // Skip initialization for volatile memory + return; } MemoryInterface::Persistent { initial_memory, .. } => { - *initial_memory = memory; + *initial_memory = memory.clone(); } } + + self.memory = TracingMemory::new( + &self.mem_config, + self.range_checker.clone(), + self.memory_bus, + CHUNK, + ) + .with_image(memory); } pub fn memory_bridge(&self) -> MemoryBridge { @@ -374,69 +313,19 @@ impl MemoryController { ) } - pub fn read_cell(&mut self, address_space: F, pointer: F) -> (RecordId, F) { - let (record_id, [data]) = self.read(address_space, pointer); - (record_id, data) - } - - pub fn read(&mut self, address_space: F, pointer: F) -> (RecordId, [F; N]) { - let address_space_u32 = address_space.as_canonical_u32(); - let ptr_u32 = pointer.as_canonical_u32(); - assert!( - address_space == F::ZERO || ptr_u32 < (1 << self.mem_config.pointer_max_bits), - "memory out of bounds: {ptr_u32:?}", - ); - - let (record_id, values) = self.memory.read::(address_space_u32, ptr_u32); - - (record_id, values) - } - - /// Reads a word directly from memory without updating internal state. - /// - /// Any value returned is unconstrained. - pub fn unsafe_read_cell(&self, addr_space: F, ptr: F) -> F { - self.unsafe_read::<1>(addr_space, ptr)[0] - } - - /// Reads a word directly from memory without updating internal state. - /// - /// Any value returned is unconstrained. - pub fn unsafe_read(&self, addr_space: F, ptr: F) -> [F; N] { - let addr_space = addr_space.as_canonical_u32(); - let ptr = ptr.as_canonical_u32(); - array::from_fn(|i| self.memory.get(addr_space, ptr + i as u32)) - } - - /// Writes `data` to the given cell. - /// - /// Returns the `RecordId` and previous data. - pub fn write_cell(&mut self, address_space: F, pointer: F, data: F) -> (RecordId, F) { - let (record_id, [data]) = self.write(address_space, pointer, [data]); - (record_id, data) - } - - pub fn write( - &mut self, - address_space: F, - pointer: F, - data: [F; N], - ) -> (RecordId, [F; N]) { - assert_ne!(address_space, F::ZERO); - let address_space_u32 = address_space.as_canonical_u32(); - let ptr_u32 = pointer.as_canonical_u32(); - assert!( - ptr_u32 < (1 << self.mem_config.pointer_max_bits), - "memory out of bounds: {ptr_u32:?}", - ); - - self.memory.write(address_space_u32, ptr_u32, data) + pub fn helper(&self) -> SharedMemoryHelper { + let range_bus = self.range_checker.bus(); + SharedMemoryHelper { + range_checker: self.range_checker.clone(), + timestamp_lt_air: AssertLtSubAir::new(range_bus, self.mem_config.clk_max_bits), + _marker: Default::default(), + } } pub fn aux_cols_factory(&self) -> MemoryAuxColsFactory { let range_bus = self.range_checker.bus(); MemoryAuxColsFactory { - range_checker: self.range_checker.clone(), + range_checker: self.range_checker.as_ref(), timestamp_lt_air: AssertLtSubAir::new(range_bus, self.mem_config.clk_max_bits), _marker: Default::default(), } @@ -454,106 +343,188 @@ impl MemoryController { self.memory.timestamp() } - fn replay_access_log(&mut self) { - let log = mem::take(&mut self.memory.log); - if log.is_empty() { - // Online memory logs may be empty, but offline memory may be replayed from external - // sources. In these cases, we skip the calls to replay access logs because - // `set_log_capacity` would panic. - tracing::debug!("skipping replay_access_log"); - return; - } + /// Returns the equipartition of the touched blocks. + /// Modifies records and adds new to account for the initial/final segments. + fn touched_blocks_to_equipartition( + &mut self, + touched_blocks: Vec<((u32, u32), AccessMetadata)>, + ) -> TimestampedEquipartition { + // [perf] We can `.with_capacity()` if we keep track of the number of segments we initialize + let mut final_memory = Vec::new(); - let mut offline_memory = self.offline_memory.lock().unwrap(); - offline_memory.set_log_capacity(log.len()); + debug_assert!(touched_blocks.is_sorted_by_key(|(addr, _)| addr)); + let (bytes, fs): (Vec<_>, Vec<_>) = touched_blocks + .into_iter() + .partition(|((addr_sp, _), _)| *addr_sp < NATIVE_AS); // TODO: normal way - for entry in log { - Self::replay_access( - entry, - &mut offline_memory, - &mut self.interface_chip, - &mut self.access_adapters, - ); - } + self.handle_touched_blocks::(&mut final_memory, bytes, 4, |x| { + F::from_canonical_u8(x) + }); + self.handle_touched_blocks::(&mut final_memory, fs, 1, |x| x); + + debug_assert!(final_memory.is_sorted_by_key(|(key, _)| *key)); + final_memory } - /// Low-level API to replay a single memory access log entry and populate the [OfflineMemory], - /// [MemoryInterface], and `AccessAdapterInventory`. - pub fn replay_access( - entry: MemoryLogEntry, - offline_memory: &mut OfflineMemory, - interface_chip: &mut MemoryInterface, - adapter_records: &mut AccessAdapterInventory, + fn handle_touched_blocks( + &mut self, + final_memory: &mut Vec<((u32, u32), TimestampedValues)>, + touched_blocks: Vec<((u32, u32), AccessMetadata)>, + min_block_size: usize, + convert: impl Fn(T) -> F, ) { - match entry { - MemoryLogEntry::Read { - address_space, - pointer, - len, - } => { - if address_space != 0 { - interface_chip.touch_range(address_space, pointer, len as u32); - } - offline_memory.read(address_space, pointer, len, adapter_records); + let mut current_values = [T::default(); CHUNK]; + let mut current_cnt = 0; + let mut current_address = MemoryAddress::new(0, 0); + let mut current_timestamps = vec![0; CHUNK]; + for ((addr_space, ptr), metadata) in touched_blocks { + let AccessMetadata { + start_ptr, + timestamp, + block_size, + } = metadata; + assert!( + current_cnt == 0 + || (current_address.address_space == addr_space + && current_address.pointer + current_cnt as u32 == ptr), + "The union of all touched blocks must consist of blocks with sizes divisible by `CHUNK`" + ); + debug_assert!(block_size >= min_block_size as u32); + debug_assert!(ptr % min_block_size as u32 == 0); + + if current_cnt == 0 { + assert_eq!( + ptr & (CHUNK as u32 - 1), + 0, + "The union of all touched blocks must consist of `CHUNK`-aligned blocks" + ); + current_address = MemoryAddress::new(addr_space, ptr); } - MemoryLogEntry::Write { - address_space, - pointer, - data, - } => { - if address_space != 0 { - interface_chip.touch_range(address_space, pointer, data.len() as u32); - } - offline_memory.write(address_space, pointer, data, adapter_records); + + if block_size > min_block_size as u32 { + self.memory.add_split_record(AccessRecordHeader { + timestamp_and_mask: timestamp, + address_space: addr_space, + pointer: start_ptr, + block_size, + lowest_block_size: min_block_size as u32, + type_size: size_of::() as u32, + }); } - MemoryLogEntry::IncrementTimestampBy(amount) => { - offline_memory.increment_timestamp_by(amount); + if min_block_size > CHUNK { + assert_eq!(current_cnt, 0); + for i in (0..block_size).step_by(min_block_size) { + self.memory.add_split_record(AccessRecordHeader { + timestamp_and_mask: timestamp, + address_space: addr_space, + pointer: start_ptr + i, + block_size: min_block_size as u32, + lowest_block_size: CHUNK as u32, + type_size: size_of::() as u32, + }); + } + let values = unsafe { + self.memory + .data + .memory + .get_slice::((addr_space, ptr), block_size as usize) + }; + for i in (0..block_size).step_by(CHUNK) { + final_memory.push(( + (addr_space, ptr + i), + TimestampedValues { + timestamp, + values: from_fn(|j| convert(values[i as usize + j])), + }, + )); + } + } else { + for i in 0..block_size { + current_values[current_cnt] = + unsafe { self.memory.data.memory.get((addr_space, ptr + i)) }; + if current_cnt & (min_block_size - 1) == 0 { + current_timestamps[current_cnt / min_block_size] = timestamp; + } + current_cnt += 1; + if current_cnt == CHUNK { + let timestamp = *current_timestamps[..CHUNK / min_block_size] + .iter() + .max() + .unwrap(); + self.memory.add_merge_record( + AccessRecordHeader { + timestamp_and_mask: timestamp, + address_space: addr_space, + pointer: current_address.pointer, + block_size: CHUNK as u32, + lowest_block_size: min_block_size as u32, + type_size: size_of::() as u32, + }, + ¤t_values, + ¤t_timestamps[..CHUNK / min_block_size], + ); + final_memory.push(( + (current_address.address_space, current_address.pointer), + TimestampedValues { + timestamp, + values: from_fn(|i| convert(current_values[i])), + }, + )); + current_address.pointer += current_cnt as u32; + current_cnt = 0; + } + } } - }; + } + assert_eq!(current_cnt, 0, "The union of all touched blocks must consist of blocks with sizes divisible by `CHUNK`"); } - /// Returns the final memory state if persistent. + /// Finalize the boundary and merkle chips. + #[instrument(name = "memory_finalize", skip_all)] pub fn finalize(&mut self, hasher: Option<&mut H>) where H: HasherChip + Sync + for<'a> SerialReceiver<&'a [F]>, { - if self.final_state.is_some() { - return; - } + let touched_blocks = self.memory.touched_blocks(); + + // Compute trace heights for access adapter chips and update their stored heights + self.memory.access_adapter_inventory.compute_trace_heights(); - self.replay_access_log(); - let mut offline_memory = self.offline_memory.lock().unwrap(); + let mut final_memory_volatile = None; + let mut final_memory_persistent = None; + + match &self.interface_chip { + MemoryInterface::Volatile { .. } => { + final_memory_volatile = + Some(self.touched_blocks_to_equipartition::<1>(touched_blocks)); + } + MemoryInterface::Persistent { .. } => { + final_memory_persistent = + Some(self.touched_blocks_to_equipartition::(touched_blocks)); + } + } match &mut self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { - let final_memory = offline_memory.finalize::<1>(&mut self.access_adapters); + let final_memory = final_memory_volatile.unwrap(); boundary_chip.finalize(final_memory); - self.final_state = Some(FinalState::Volatile(VolatileFinalState::default())); } MemoryInterface::Persistent { - merkle_chip, boundary_chip, + merkle_chip, initial_memory, } => { - let hasher = hasher.unwrap(); - let final_partition = offline_memory.finalize::(&mut self.access_adapters); + let final_memory = final_memory_persistent.unwrap(); - boundary_chip.finalize(initial_memory, &final_partition, hasher); - let final_memory_values = final_partition + let hasher = hasher.unwrap(); + boundary_chip.finalize(initial_memory, &final_memory, hasher); + let final_memory_values = final_memory .into_par_iter() .map(|(key, value)| (key, value.values)) .collect(); - let initial_node = MemoryNode::tree_from_memory( - merkle_chip.air.memory_dimensions, - initial_memory, - hasher, - ); - merkle_chip.finalize(&initial_node, &final_memory_values, hasher); - self.final_state = Some(FinalState::Persistent(PersistentFinalState { - final_memory: final_memory_values.clone(), - })); + merkle_chip.finalize(initial_memory, &final_memory_values, hasher); } - }; + } } pub fn generate_air_proof_inputs(self) -> Vec> @@ -562,12 +533,8 @@ impl MemoryController { { let mut ret = Vec::new(); - let Self { - interface_chip, - access_adapters, - .. - } = self; - match interface_chip { + let access_adapters = self.memory.access_adapter_inventory; + match self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { ret.push(boundary_chip.generate_air_proof_input()); } @@ -608,7 +575,7 @@ impl MemoryController { airs.push(merkle_chip.air()); } } - airs.extend(self.access_adapters.airs()); + airs.extend(self.memory.access_adapter_inventory.airs()); airs } @@ -619,7 +586,7 @@ impl MemoryController { if self.continuation_enabled() { num_airs += 1; } - num_airs += self.access_adapters.num_access_adapters(); + num_airs += self.memory.access_adapter_inventory.num_access_adapters(); num_airs } @@ -628,7 +595,7 @@ impl MemoryController { if self.continuation_enabled() { air_names.push("Merkle".to_string()); } - air_names.extend(self.access_adapters.air_names()); + air_names.extend(self.memory.access_adapter_inventory.air_names()); air_names } @@ -637,7 +604,7 @@ impl MemoryController { } pub fn get_memory_trace_heights(&self) -> MemoryTraceHeights { - let access_adapters = self.access_adapters.get_heights(); + let access_adapters = self.memory.access_adapter_inventory.get_heights(); match &self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { MemoryTraceHeights::Volatile(VolatileMemoryTraceHeights { @@ -658,7 +625,7 @@ impl MemoryController { } pub fn get_dummy_memory_trace_heights(&self) -> MemoryTraceHeights { - let access_adapters = vec![1; self.access_adapters.num_access_adapters()]; + let access_adapters = vec![1; self.memory.access_adapter_inventory.num_access_adapters()]; match &self.interface_chip { MemoryInterface::Volatile { .. } => { MemoryTraceHeights::Volatile(VolatileMemoryTraceHeights { @@ -676,6 +643,23 @@ impl MemoryController { } } + pub fn get_memory_trace_widths(&self) -> Vec { + let access_adapter_widths = self.memory.access_adapter_inventory.get_widths(); + match &self.interface_chip { + MemoryInterface::Volatile { boundary_chip } => { + vec![boundary_chip.trace_width()] + } + MemoryInterface::Persistent { + boundary_chip, + merkle_chip, + .. + } => [boundary_chip.trace_width(), merkle_chip.trace_width()] + .into_iter() + .chain(access_adapter_widths) + .collect(), + } + } + pub fn current_trace_cells(&self) -> Vec { let mut ret = Vec::new(); match &self.interface_chip { @@ -691,77 +675,41 @@ impl MemoryController { ret.push(merkle_chip.current_trace_cells()); } } - ret.extend(self.access_adapters.get_cells()); + ret.extend(self.memory.access_adapter_inventory.get_cells()); ret } - - /// Returns a reference to the offline memory. - /// - /// Until `finalize` is called, the `OfflineMemory` does not contain useful state, and should - /// therefore not be used by any chip during execution. However, to obtain a reference to the - /// offline memory that will be useful in trace generation, a chip can call `offline_memory()` - /// and store the returned reference for later use. - pub fn offline_memory(&self) -> Arc>> { - self.offline_memory.clone() - } - pub fn get_memory_logs(&self) -> &Vec> { - &self.memory.log - } - pub fn set_memory_logs(&mut self, logs: Vec>) { - self.memory.log = logs; - } - pub fn take_memory_logs(&mut self) -> Vec> { - std::mem::take(&mut self.memory.log) - } } -pub struct MemoryAuxColsFactory { +/// Owned version of [MemoryAuxColsFactory]. +pub struct SharedMemoryHelper { pub(crate) range_checker: SharedVariableRangeCheckerChip, pub(crate) timestamp_lt_air: AssertLtSubAir, pub(crate) _marker: PhantomData, } +/// A helper for generating trace values in auxiliary memory columns related to the offline memory +/// argument. +pub struct MemoryAuxColsFactory<'a, T> { + pub(crate) range_checker: &'a VariableRangeCheckerChip, + pub(crate) timestamp_lt_air: AssertLtSubAir, + pub(crate) _marker: PhantomData, +} + // NOTE[jpw]: The `make_*_aux_cols` functions should be thread-safe so they can be used in // parallelized trace generation. -impl MemoryAuxColsFactory { - pub fn generate_read_aux(&self, read: &MemoryRecord, buffer: &mut MemoryReadAuxCols) { - assert!( - !read.address_space.is_zero(), - "cannot make `MemoryReadAuxCols` for address space 0" - ); - self.generate_base_aux(read, &mut buffer.base); +impl MemoryAuxColsFactory<'_, F> { + /// Fill the trace assuming `prev_timestamp` is already provided in `buffer`. + pub fn fill(&self, prev_timestamp: u32, timestamp: u32, buffer: &mut MemoryBaseAuxCols) { + self.generate_timestamp_lt(prev_timestamp, timestamp, &mut buffer.timestamp_lt_aux); + // Safety: even if prev_timestamp were obtained by transmute_ref from + // `buffer.prev_timestamp`, this should still work because it is a direct assignment + buffer.prev_timestamp = F::from_canonical_u32(prev_timestamp); } - pub fn generate_read_or_immediate_aux( - &self, - read: &MemoryRecord, - buffer: &mut MemoryReadOrImmediateAuxCols, - ) { - IsZeroSubAir.generate_subrow( - read.address_space, - (&mut buffer.is_zero_aux, &mut buffer.is_immediate), - ); - self.generate_base_aux(read, &mut buffer.base); - } - - pub fn generate_write_aux( - &self, - write: &MemoryRecord, - buffer: &mut MemoryWriteAuxCols, - ) { - buffer - .prev_data - .copy_from_slice(write.prev_data_slice().unwrap()); - self.generate_base_aux(write, &mut buffer.base); - } - - pub fn generate_base_aux(&self, record: &MemoryRecord, buffer: &mut MemoryBaseAuxCols) { - buffer.prev_timestamp = F::from_canonical_u32(record.prev_timestamp); - self.generate_timestamp_lt( - record.prev_timestamp, - record.timestamp, - &mut buffer.timestamp_lt_aux, - ); + /// # Safety + /// We assume that `F::ZERO` has underlying memory equivalent to `mem::zeroed()`. + pub fn fill_zero(&self, buffer: &mut MemoryBaseAuxCols) { + *buffer = unsafe { std::mem::zeroed() }; } fn generate_timestamp_lt( @@ -770,50 +718,24 @@ impl MemoryAuxColsFactory { timestamp: u32, buffer: &mut LessThanAuxCols, ) { - debug_assert!(prev_timestamp < timestamp); + debug_assert!( + prev_timestamp < timestamp, + "prev_timestamp {prev_timestamp} >= timestamp {timestamp}" + ); self.timestamp_lt_air.generate_subrow( - (self.range_checker.as_ref(), prev_timestamp, timestamp), + (self.range_checker, prev_timestamp, timestamp), &mut buffer.lower_decomp, ); } +} - /// In general, prefer `generate_read_aux` which writes in-place rather than this function. - pub fn make_read_aux_cols(&self, read: &MemoryRecord) -> MemoryReadAuxCols { - assert!( - !read.address_space.is_zero(), - "cannot make `MemoryReadAuxCols` for address space 0" - ); - MemoryReadAuxCols::new( - read.prev_timestamp, - self.generate_timestamp_lt_cols(read.prev_timestamp, read.timestamp), - ) - } - - /// In general, prefer `generate_write_aux` which writes in-place rather than this function. - pub fn make_write_aux_cols( - &self, - write: &MemoryRecord, - ) -> MemoryWriteAuxCols { - let prev_data = write.prev_data_slice().unwrap(); - MemoryWriteAuxCols::new( - prev_data.try_into().unwrap(), - F::from_canonical_u32(write.prev_timestamp), - self.generate_timestamp_lt_cols(write.prev_timestamp, write.timestamp), - ) - } - - fn generate_timestamp_lt_cols( - &self, - prev_timestamp: u32, - timestamp: u32, - ) -> LessThanAuxCols { - debug_assert!(prev_timestamp < timestamp); - let mut decomp = [F::ZERO; AUX_LEN]; - self.timestamp_lt_air.generate_subrow( - (self.range_checker.as_ref(), prev_timestamp, timestamp), - &mut decomp, - ); - LessThanAuxCols::new(decomp) +impl SharedMemoryHelper { + pub fn as_borrowed(&self) -> MemoryAuxColsFactory<'_, T> { + MemoryAuxColsFactory { + range_checker: self.range_checker.as_ref(), + timestamp_lt_air: self.timestamp_lt_air, + _marker: PhantomData, + } } } @@ -824,7 +746,7 @@ mod tests { }; use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{prelude::SliceRandom, thread_rng, Rng}; + use rand::{thread_rng, Rng}; use super::MemoryController; use crate::{ @@ -843,27 +765,37 @@ mod tests { let range_bus = VariableRangeCheckerBus::new(RANGE_CHECKER_BUS, memory_config.decomp); let range_checker = SharedVariableRangeCheckerChip::new(range_bus); - let mut memory_controller = MemoryController::with_volatile_memory( + let mut memory_controller = MemoryController::::with_volatile_memory( memory_bus, - memory_config, + memory_config.clone(), range_checker.clone(), ); let mut rng = thread_rng(); for _ in 0..1000 { - let address_space = F::from_canonical_u32(*[1, 2].choose(&mut rng).unwrap()); - let pointer = - F::from_canonical_u32(rng.gen_range(0..1 << memory_config.pointer_max_bits)); + // TODO[jpw]: test other address spaces? + let address_space = 4u32; + let pointer = rng.gen_range(0..1 << memory_config.pointer_max_bits); if rng.gen_bool(0.5) { let data = F::from_canonical_u32(rng.gen_range(0..1 << 30)); - memory_controller.write(address_space, pointer, [data]); + // address space is 4 so cell type is `F` + unsafe { + memory_controller + .memory + .write::(address_space, pointer, [data]); + } } else { - memory_controller.read::<1>(address_space, pointer); + unsafe { + memory_controller + .memory + .read::(address_space, pointer); + } } } assert!(memory_controller - .access_adapters + .memory + .access_adapter_inventory .get_heights() .iter() .all(|&h| h == 0)); diff --git a/crates/vm/src/system/memory/merkle/mod.rs b/crates/vm/src/system/memory/merkle/mod.rs index 74f8951bc4..ebef286baf 100644 --- a/crates/vm/src/system/memory/merkle/mod.rs +++ b/crates/vm/src/system/memory/merkle/mod.rs @@ -1,27 +1,36 @@ -use openvm_stark_backend::{interaction::PermutationCheckBus, p3_field::PrimeField32}; -use rustc_hash::FxHashSet; +use std::array; + +use openvm_stark_backend::{ + interaction::PermutationCheckBus, p3_field::PrimeField32, p3_maybe_rayon::prelude::*, +}; + +use super::{controller::dimensions::MemoryDimensions, online::LinearMemory, MemoryImage}; +use crate::system::memory::online::PAGE_SIZE; -use super::controller::dimensions::MemoryDimensions; mod air; mod columns; +pub mod public_values; mod trace; +mod tree; pub use air::*; pub use columns::*; pub(super) use trace::SerialReceiver; +pub use tree::*; -#[cfg(test)] -mod tests; +// TODO: add back +// #[cfg(test)] +// mod tests; pub struct MemoryMerkleChip { pub air: MemoryMerkleAir, - touched_nodes: FxHashSet<(usize, u32, u32)>, - num_touched_nonleaves: usize, final_state: Option>, + // TODO(AG): how are these two different? Doesn't one just end up being copied to the other? + trace_height: Option, overridden_height: Option, } #[derive(Debug)] -struct FinalState { +pub struct FinalState { rows: Vec>, init_root: [F; CHUNK], final_root: [F; CHUNK], @@ -35,46 +44,92 @@ impl MemoryMerkleChip { merkle_bus: PermutationCheckBus, compression_bus: PermutationCheckBus, ) -> Self { - assert!(memory_dimensions.as_height > 0); + assert!(memory_dimensions.addr_space_height > 0); assert!(memory_dimensions.address_height > 0); - let mut touched_nodes = FxHashSet::default(); - touched_nodes.insert((memory_dimensions.overall_height(), 0, 0)); Self { air: MemoryMerkleAir { memory_dimensions, merkle_bus, compression_bus, }, - touched_nodes, - num_touched_nonleaves: 1, final_state: None, + trace_height: None, overridden_height: None, } } pub fn set_overridden_height(&mut self, override_height: usize) { self.overridden_height = Some(override_height); } +} - fn touch_node(&mut self, height: usize, as_label: u32, address_label: u32) { - if self.touched_nodes.insert((height, as_label, address_label)) { - assert_ne!(height, self.air.memory_dimensions.overall_height()); - if height != 0 { - self.num_touched_nonleaves += 1; - } - if height >= self.air.memory_dimensions.address_height { - self.touch_node(height + 1, as_label / 2, address_label); +#[tracing::instrument(level = "info", skip_all)] +fn memory_to_vec_partition( + memory: &MemoryImage, + md: &MemoryDimensions, +) -> Vec<(u64, [F; N])> { + (0..memory.mem.len()) + .into_par_iter() + .map(move |as_idx| { + let space_mem = memory.mem[as_idx].as_slice(); + let cell_size = memory.cell_size[as_idx]; + debug_assert_eq!(PAGE_SIZE % (cell_size * N), 0); + + let num_nonzero_pages = space_mem + .par_chunks(PAGE_SIZE) + .enumerate() + .flat_map(|(idx, page)| { + if page.iter().any(|x| *x != 0) { + Some(idx + 1) + } else { + None + } + }) + .max() + .unwrap_or(0); + + let space_mem = &space_mem[..(num_nonzero_pages * PAGE_SIZE).min(space_mem.len())]; + let mut num_elements = space_mem.len() / (cell_size * N); + // virtual memory may be larger than dimensions due to rounding up to page size + num_elements = num_elements.min(1 << md.address_height); + + // TODO: handle different cell sizes better + if cell_size == 1 { + (0..num_elements) + .into_par_iter() + .map(move |idx| { + let byte_index = idx * cell_size * N; + unsafe { + let ptr = space_mem.as_ptr(); + let src = ptr.add(byte_index); + ( + md.label_to_index((as_idx as u32, idx as u32)), + array::from_fn(|i| { + F::from_canonical_u8(core::ptr::read(src.add(i))) + }), + ) + } + }) + .collect::>() } else { - self.touch_node(height + 1, as_label, address_label / 2); + assert_eq!(cell_size, 4); + (0..num_elements) + .into_par_iter() + .map(move |idx| { + let byte_index = idx * cell_size * N; + unsafe { + let ptr = space_mem.as_ptr(); + let src = ptr.add(byte_index) as *const F; + ( + md.label_to_index((as_idx as u32, idx as u32)), + array::from_fn(|i| core::ptr::read(src.add(i))), + ) + } + }) + .collect::>() } - } - } - - pub fn touch_range(&mut self, address_space: u32, address: u32, len: u32) { - let as_label = address_space - self.air.memory_dimensions.as_offset; - let first_address_label = address / CHUNK as u32; - let last_address_label = (address + len - 1) / CHUNK as u32; - for address_label in first_address_label..=last_address_label { - self.touch_node(0, as_label, address_label); - } - } + }) + .collect::>() + .into_iter() + .flatten() + .collect::>() } diff --git a/crates/vm/src/system/memory/tree/public_values.rs b/crates/vm/src/system/memory/merkle/public_values.rs similarity index 66% rename from crates/vm/src/system/memory/tree/public_values.rs rename to crates/vm/src/system/memory/merkle/public_values.rs index 1c6866b959..43b3867a88 100644 --- a/crates/vm/src/system/memory/tree/public_values.rs +++ b/crates/vm/src/system/memory/merkle/public_values.rs @@ -1,17 +1,16 @@ -use std::{collections::BTreeMap, sync::Arc}; - use openvm_stark_backend::{p3_field::PrimeField32, p3_util::log2_strict_usize}; use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::{ - arch::hasher::Hasher, + arch::{hasher::Hasher, ADDR_SPACE_OFFSET}, system::memory::{ - dimensions::MemoryDimensions, paged_vec::Address, tree::MemoryNode, MemoryImage, + dimensions::MemoryDimensions, merkle::tree::MerkleTree, online::LinearMemory, MemoryImage, }, }; -pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = 2; +pub const PUBLIC_VALUES_AS: u32 = 3; +pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET; /// Merkle proof for user public values in the memory state. #[derive(Clone, Debug, Serialize, Deserialize)] @@ -47,11 +46,13 @@ impl UserPublicValuesProof { /// Computes the proof of the public values from the final memory state. /// Assumption: /// - `num_public_values` is a power of two * CHUNK. It cannot be 0. + // PERF[jpw]: this currently reconstructs the merkle tree from final memory; we should avoid + // this pub fn compute( memory_dimensions: MemoryDimensions, num_public_values: usize, hasher: &(impl Hasher + Sync), - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Self { let proof = compute_merkle_proof_to_user_public_values_root( memory_dimensions, @@ -59,8 +60,7 @@ impl UserPublicValuesProof { hasher, final_memory, ); - let public_values = - extract_public_values(&memory_dimensions, num_public_values, final_memory); + let public_values = extract_public_values(num_public_values, final_memory); let public_values_commit = hasher.merkle_root(&public_values); UserPublicValuesProof { proof, @@ -81,7 +81,7 @@ impl UserPublicValuesProof { // 2. Compare user public values commitment with Merkle root of user public values. let pv_commit = self.public_values_commit; // 0. - let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; + let pv_as = PUBLIC_VALUES_AS; let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0)); let pvs = &self.public_values; if pvs.len() % CHUNK != 0 || !(pvs.len() / CHUNK).is_power_of_two() { @@ -121,14 +121,14 @@ fn compute_merkle_proof_to_user_public_values_root + Sync), - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Vec<[F; CHUNK]> { assert_eq!( num_public_values % CHUNK, 0, "num_public_values must be a multiple of memory chunk {CHUNK}" ); - let root = MemoryNode::tree_from_memory(memory_dimensions, final_memory, hasher); + let tree = MerkleTree::::from_memory(final_memory, &memory_dimensions, hasher); let num_pv_chunks: usize = num_public_values / CHUNK; // This enforces the number of public values cannot be 0. assert!( @@ -138,63 +138,48 @@ fn compute_merkle_proof_to_user_public_values_root( - memory_dimensions: &MemoryDimensions, num_public_values: usize, - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Vec { - // All (addr, value) pairs in the public value address space. - let f_as_start = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; - let f_as_end = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset + 1; - - // This clones the entire memory. Ideally this should run in time proportional to - // the size of the PV address space, not entire memory. - let final_memory: BTreeMap = final_memory.items().collect(); + let mut public_values: Vec = { + // TODO: make constant for public values cell size + assert_eq!(final_memory.cell_size[PUBLIC_VALUES_AS as usize], 1); + final_memory.mem[PUBLIC_VALUES_AS as usize] + .as_slice() + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect() + }; - let used_pvs: Vec<_> = final_memory - .range((f_as_start, 0)..(f_as_end, 0)) - .map(|(&(_, pointer), &value)| (pointer as usize, value)) - .collect(); - if let Some(&last_pv) = used_pvs.last() { - assert!( - last_pv.0 < num_public_values || last_pv.1 == F::ZERO, - "Last public value is out of bounds" - ); - } - let mut public_values = F::zero_vec(num_public_values); - for (i, pv) in used_pvs { - if i < num_public_values { - public_values[i] = pv; - } - } + assert!( + public_values.len() >= num_public_values, + "Public values address space has {} elements, but configuration has num_public_values={}", + public_values.len(), + num_public_values + ); + public_values.truncate(num_public_values); public_values } @@ -203,27 +188,30 @@ mod tests { use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::p3_baby_bear::BabyBear; - use super::{UserPublicValuesProof, PUBLIC_VALUES_ADDRESS_SPACE_OFFSET}; + use super::UserPublicValuesProof; use crate::{ arch::{hasher::poseidon2::vm_poseidon2_hasher, SystemConfig}, - system::memory::{paged_vec::AddressMap, tree::MemoryNode, CHUNK}, + system::memory::{ + merkle::{public_values::PUBLIC_VALUES_AS, tree::MerkleTree}, + online::GuestMemory, + AddressMap, CHUNK, + }, }; type F = BabyBear; #[test] fn test_public_value_happy_path() { let mut vm_config = SystemConfig::default(); - vm_config.memory_config.as_height = 4; + vm_config.memory_config.addr_space_height = 4; vm_config.memory_config.pointer_max_bits = 5; let memory_dimensions = vm_config.memory_config.memory_dimensions(); - let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; let num_public_values = 16; - let memory = AddressMap::from_iter( - memory_dimensions.as_offset, - 1 << memory_dimensions.as_height, - 1 << memory_dimensions.address_height, - [((pv_as, 15), F::ONE)], - ); + let mut memory = GuestMemory { + memory: AddressMap::new(vec![0, 0, 0, num_public_values]), + }; + unsafe { + memory.write::(PUBLIC_VALUES_AS, 12, [0, 0, 0, 1]); + } let mut expected_pvs = F::zero_vec(num_public_values); expected_pvs[15] = F::ONE; @@ -232,12 +220,13 @@ mod tests { memory_dimensions, num_public_values, &hasher, - &memory, + &memory.memory, ); assert_eq!(pv_proof.public_values, expected_pvs); - let final_memory_root = MemoryNode::tree_from_memory(memory_dimensions, &memory, &hasher); + let final_memory_root = + MerkleTree::from_memory(&memory.memory, &memory_dimensions, &hasher).root(); pv_proof - .verify(&hasher, memory_dimensions, final_memory_root.hash()) + .verify(&hasher, memory_dimensions, final_memory_root) .unwrap(); } } diff --git a/crates/vm/src/system/memory/merkle/tests/mod.rs b/crates/vm/src/system/memory/merkle/tests/mod.rs index 05c966dc23..ecb48d1130 100644 --- a/crates/vm/src/system/memory/merkle/tests/mod.rs +++ b/crates/vm/src/system/memory/merkle/tests/mod.rs @@ -7,7 +7,7 @@ use std::{ use openvm_stark_backend::{ interaction::{PermutationCheckBus, PermutationInteractionType}, - p3_field::FieldAlgebra, + p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::dense::RowMajorMatrix, prover::types::AirProofInput, Chip, ChipUsageGetter, @@ -19,14 +19,18 @@ use openvm_stark_sdk::{ }; use rand::RngCore; +use super::memory_to_partition; use crate::{ - arch::testing::{MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS}, + arch::{ + ADDR_SPACE_OFFSET, + testing::{MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS}, + }, system::memory::{ merkle::{ columns::MemoryMerkleCols, tests::util::HashTestChip, MemoryDimensions, MemoryMerkleChip, }, - paged_vec::{AddressMap, PAGE_SIZE}, + AddressMap, tree::MemoryNode, Equipartition, MemoryImage, }, @@ -39,42 +43,42 @@ const COMPRESSION_BUS: PermutationCheckBus = PermutationCheckBus::new(POSEIDON2_ fn test( memory_dimensions: MemoryDimensions, - initial_memory: &MemoryImage, + initial_memory: &MemoryImage, touched_labels: BTreeSet<(u32, u32)>, - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) { let MemoryDimensions { - as_height, + addr_space_height, address_height, - as_offset, } = memory_dimensions; let merkle_bus = PermutationCheckBus::new(MEMORY_MERKLE_BUS); // checking validity of test data - for ((address_space, pointer), value) in final_memory.items() { + for ((address_space, pointer), value) in final_memory.items::() { let label = pointer / CHUNK as u32; - assert!(address_space - as_offset < (1 << as_height)); - assert!(pointer < ((CHUNK << address_height).div_ceil(PAGE_SIZE) * PAGE_SIZE) as u32); - if initial_memory.get(&(address_space, pointer)) != Some(&value) { + assert!(address_space - ADDR_SPACE_OFFSET < (1 << addr_space_height)); + assert!(pointer < (CHUNK << address_height) as u32); + if unsafe { initial_memory.get::((address_space, pointer)) } != value { assert!(touched_labels.contains(&(address_space, label))); } } - for key in initial_memory.items().map(|(key, _)| key) { - assert!(final_memory.get(&key).is_some()); - } - for &(address_space, label) in touched_labels.iter() { - let mut contains_some_key = false; - for i in 0..CHUNK { - if final_memory - .get(&(address_space, label * CHUNK as u32 + i as u32)) - .is_some() - { - contains_some_key = true; - break; - } - } - assert!(contains_some_key); - } + // for key in initial_memory.items().map(|(key, _)| key) { + // assert!(unsafe { final_memory.get(key).is_some() }); + // } + // for &(address_space, label) in touched_labels.iter() { + // let mut contains_some_key = false; + // for i in 0..CHUNK { + // if unsafe { + // final_memory + // .get((address_space, label * CHUNK as u32 + i as u32)) + // .is_some() + // } { + // contains_some_key = true; + // break; + // } + // } + // assert!(contains_some_key); + // } let mut hash_test_chip = HashTestChip::new(); @@ -126,13 +130,12 @@ fn test( }; for (address_space, address_label) in touched_labels { - let initial_values = array::from_fn(|i| { - initial_memory - .get(&(address_space, address_label * CHUNK as u32 + i as u32)) - .copied() - .unwrap_or_default() - }); - let as_label = address_space - as_offset; + let initial_values = unsafe { + array::from_fn(|i| { + initial_memory.get((address_space, address_label * CHUNK as u32 + i as u32)) + }) + }; + let as_label = address_space; interaction( PermutationInteractionType::Send, false, @@ -180,20 +183,6 @@ fn test( .expect("Verification failed"); } -fn memory_to_partition( - memory: &MemoryImage, -) -> Equipartition { - let mut memory_partition = Equipartition::new(); - for ((address_space, pointer), value) in memory.items() { - let label = (address_space, pointer / N as u32); - let chunk = memory_partition - .entry(label) - .or_insert_with(|| [F::default(); N]); - chunk[(pointer % N as u32) as usize] = value; - } - memory_partition -} - fn random_test( height: usize, max_value: u32, @@ -203,8 +192,12 @@ fn random_test( let mut rng = create_seeded_rng(); let mut next_u32 = || rng.next_u64() as u32; - let mut initial_memory = AddressMap::new(1, 2, CHUNK << height); - let mut final_memory = AddressMap::new(1, 2, CHUNK << height); + let as_cnt = 2; + let mut initial_memory = AddressMap::new(vec![CHUNK << height; as_cnt]); + let mut final_memory = AddressMap::new(vec![CHUNK << height; as_cnt]); + // TEMP[jpw]: override so address space uses field element + initial_memory.cell_size = vec![4; as_cnt]; + final_memory.cell_size = vec![4; as_cnt]; let mut seen = HashSet::new(); let mut touched_labels = BTreeSet::new(); @@ -221,15 +214,19 @@ fn random_test( if is_initial && num_initial_addresses != 0 { num_initial_addresses -= 1; let value = BabyBear::from_canonical_u32(next_u32() % max_value); - initial_memory.insert(&(address_space, pointer), value); - final_memory.insert(&(address_space, pointer), value); + unsafe { + initial_memory.insert((address_space, pointer), value); + final_memory.insert((address_space, pointer), value); + } } if is_touched && num_touched_addresses != 0 { num_touched_addresses -= 1; touched_labels.insert((address_space, label)); if value_changes || !is_initial { let value = BabyBear::from_canonical_u32(next_u32() % max_value); - final_memory.insert(&(address_space, pointer), value); + unsafe { + final_memory.insert((address_space, pointer), value); + } } } } @@ -237,9 +234,8 @@ fn random_test( test::( MemoryDimensions { - as_height: 1, + addr_space_height: 1, address_height: height, - as_offset: 1, }, &initial_memory, touched_labels, @@ -265,16 +261,13 @@ fn expand_test_2() { #[test] fn expand_test_no_accesses() { let memory_dimensions = MemoryDimensions { - as_height: 2, + addr_space_height: 2, address_height: 1, - as_offset: 7, }; let mut hash_test_chip = HashTestChip::new(); let memory = AddressMap::new( - memory_dimensions.as_offset, - 1 << memory_dimensions.as_height, - 1 << memory_dimensions.address_height, + vec![1 << memory_dimensions.address_height; 1 + (1 << memory_dimensions.addr_space_height)], ); let tree = MemoryNode::::tree_from_memory( memory_dimensions, @@ -304,17 +297,14 @@ fn expand_test_no_accesses() { #[should_panic] fn expand_test_negative() { let memory_dimensions = MemoryDimensions { - as_height: 2, + addr_space_height: 2, address_height: 1, - as_offset: 7, }; let mut hash_test_chip = HashTestChip::new(); let memory = AddressMap::new( - memory_dimensions.as_offset, - 1 << memory_dimensions.as_height, - 1 << memory_dimensions.address_height, + vec![1 << memory_dimensions.address_height; 1 + (1 << memory_dimensions.addr_space_height)], ); let tree = MemoryNode::::tree_from_memory( memory_dimensions, diff --git a/crates/vm/src/system/memory/merkle/trace.rs b/crates/vm/src/system/memory/merkle/trace.rs index 52609f259a..cfc9cf24f0 100644 --- a/crates/vm/src/system/memory/merkle/trace.rs +++ b/crates/vm/src/system/memory/merkle/trace.rs @@ -1,6 +1,5 @@ use std::{ borrow::BorrowMut, - cmp::Reverse, sync::{atomic::AtomicU32, Arc}, }; @@ -11,16 +10,14 @@ use openvm_stark_backend::{ prover::types::AirProofInput, AirRef, Chip, ChipUsageGetter, }; -use rustc_hash::FxHashSet; +use tracing::instrument; use crate::{ arch::hasher::HasherChip, system::{ memory::{ - controller::dimensions::MemoryDimensions, - merkle::{FinalState, MemoryMerkleChip, MemoryMerkleCols}, - tree::MemoryNode::{self, NonLeaf}, - Equipartition, + merkle::{tree::MerkleTree, FinalState, MemoryMerkleChip, MemoryMerkleCols}, + Equipartition, MemoryImage, }, poseidon2::{ Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH, @@ -29,39 +26,17 @@ use crate::{ }; impl MemoryMerkleChip { + #[instrument(name = "merkle_finalize", skip_all)] pub fn finalize( &mut self, - initial_tree: &MemoryNode, + initial_memory: &MemoryImage, final_memory: &Equipartition, hasher: &mut impl HasherChip, ) { assert!(self.final_state.is_none(), "Merkle chip already finalized"); - // there needs to be a touched node with `height_section` = 0 - // shouldn't be a leaf because - // trace generation will expect an interaction from MemoryInterfaceChip in that case - if self.touched_nodes.len() == 1 { - self.touch_node(1, 0, 0); - } - - let mut rows = vec![]; - let mut tree_helper = TreeHelper { - memory_dimensions: self.air.memory_dimensions, - final_memory, - touched_nodes: &self.touched_nodes, - trace_rows: &mut rows, - }; - let final_tree = tree_helper.recur( - self.air.memory_dimensions.overall_height(), - initial_tree, - 0, - 0, - hasher, - ); - self.final_state = Some(FinalState { - rows, - init_root: initial_tree.hash(), - final_root: final_tree.hash(), - }); + let mut tree = MerkleTree::from_memory(initial_memory, &self.air.memory_dimensions, hasher); + self.final_state = Some(tree.finalize(hasher, final_memory, &self.air.memory_dimensions)); + self.trace_height = Some(self.final_state.as_ref().unwrap().rows.len()); } } @@ -85,7 +60,8 @@ where } = self.final_state.unwrap(); // important that this sort be stable, // because we need the initial root to be first and the final root to be second - rows.sort_by_key(|row| Reverse(row.parent_height)); + rows.reverse(); + rows.swap(0, 1); let width = MemoryMerkleCols::, CHUNK>::width(); let mut height = rows.len().next_power_of_two(); @@ -114,7 +90,8 @@ impl ChipUsageGetter for MemoryMerkleChip usize { - 2 * self.num_touched_nonleaves + // TODO is it ok? + self.trace_height.unwrap_or(0) } fn trace_width(&self) -> usize { @@ -122,136 +99,6 @@ impl ChipUsageGetter for MemoryMerkleChip { - memory_dimensions: MemoryDimensions, - final_memory: &'a Equipartition, - touched_nodes: &'a FxHashSet<(usize, u32, u32)>, - trace_rows: &'a mut Vec>, -} - -impl TreeHelper<'_, CHUNK, F> { - fn recur( - &mut self, - height: usize, - initial_node: &MemoryNode, - as_label: u32, - address_label: u32, - hasher: &mut impl HasherChip, - ) -> MemoryNode { - if height == 0 { - let address_space = as_label + self.memory_dimensions.as_offset; - let leaf_values = *self - .final_memory - .get(&(address_space, address_label)) - .unwrap_or(&[F::ZERO; CHUNK]); - MemoryNode::new_leaf(hasher.hash(&leaf_values)) - } else if let NonLeaf { - left: initial_left_node, - right: initial_right_node, - .. - } = initial_node.clone() - { - // Tell the hasher about this hash. - hasher.compress_and_record(&initial_left_node.hash(), &initial_right_node.hash()); - - let is_as_section = height > self.memory_dimensions.address_height; - - let (left_as_label, right_as_label) = if is_as_section { - (2 * as_label, 2 * as_label + 1) - } else { - (as_label, as_label) - }; - let (left_address_label, right_address_label) = if is_as_section { - (address_label, address_label) - } else { - (2 * address_label, 2 * address_label + 1) - }; - - let left_is_final = - !self - .touched_nodes - .contains(&(height - 1, left_as_label, left_address_label)); - - let final_left_node = if left_is_final { - initial_left_node - } else { - Arc::new(self.recur( - height - 1, - &initial_left_node, - left_as_label, - left_address_label, - hasher, - )) - }; - - let right_is_final = - !self - .touched_nodes - .contains(&(height - 1, right_as_label, right_address_label)); - - let final_right_node = if right_is_final { - initial_right_node - } else { - Arc::new(self.recur( - height - 1, - &initial_right_node, - right_as_label, - right_address_label, - hasher, - )) - }; - - let final_node = MemoryNode::new_nonleaf(final_left_node, final_right_node, hasher); - self.add_trace_row(height, as_label, address_label, initial_node, None); - self.add_trace_row( - height, - as_label, - address_label, - &final_node, - Some([left_is_final, right_is_final]), - ); - final_node - } else { - panic!("Leaf {:?} found at nonzero height {}", initial_node, height); - } - } - - /// Expects `node` to be NonLeaf - fn add_trace_row( - &mut self, - parent_height: usize, - as_label: u32, - address_label: u32, - node: &MemoryNode, - direction_changes: Option<[bool; 2]>, - ) { - let [left_direction_change, right_direction_change] = - direction_changes.unwrap_or([false; 2]); - let cols = if let NonLeaf { hash, left, right } = node { - MemoryMerkleCols { - expand_direction: if direction_changes.is_none() { - F::ONE - } else { - F::NEG_ONE - }, - height_section: F::from_bool(parent_height > self.memory_dimensions.address_height), - parent_height: F::from_canonical_usize(parent_height), - is_root: F::from_bool(parent_height == self.memory_dimensions.overall_height()), - parent_as_label: F::from_canonical_u32(as_label), - parent_address_label: F::from_canonical_u32(address_label), - parent_hash: *hash, - left_child_hash: left.hash(), - right_child_hash: right.hash(), - left_direction_different: F::from_bool(left_direction_change), - right_direction_different: F::from_bool(right_direction_change), - } - } else { - panic!("trace_rows expects node = {:?} to be NonLeaf", node); - }; - self.trace_rows.push(cols); - } -} - pub trait SerialReceiver { fn receive(&mut self, msg: T); } diff --git a/crates/vm/src/system/memory/merkle/tree.rs b/crates/vm/src/system/memory/merkle/tree.rs new file mode 100644 index 0000000000..599cc5a919 --- /dev/null +++ b/crates/vm/src/system/memory/merkle/tree.rs @@ -0,0 +1,267 @@ +use openvm_stark_backend::{ + p3_field::PrimeField32, + p3_maybe_rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}, +}; +use rustc_hash::FxHashMap; + +use super::{FinalState, MemoryMerkleCols}; +use crate::{ + arch::hasher::{Hasher, HasherChip}, + system::memory::{ + dimensions::MemoryDimensions, merkle::memory_to_vec_partition, AddressMap, Equipartition, + }, +}; + +#[derive(Debug)] +pub struct MerkleTree { + /// Height of the tree -- the root is the only node at height `height`, + /// and the leaves are at height `0`. + height: usize, + /// Nodes corresponding to all zeroes. + zero_nodes: Vec<[F; CHUNK]>, + /// Nodes in the tree that have ever been touched. + nodes: FxHashMap, +} + +impl MerkleTree { + pub fn new(height: usize, hasher: &impl Hasher) -> Self { + Self { + height, + zero_nodes: (0..height + 1) + .scan(hasher.hash(&[F::ZERO; CHUNK]), |acc, _| { + let result = Some(*acc); + *acc = hasher.compress(acc, acc); + result + }) + .collect(), + nodes: FxHashMap::default(), + } + } + + pub fn root(&self) -> [F; CHUNK] { + self.get_node(1) + } + + pub fn get_node(&self, index: u64) -> [F; CHUNK] { + self.nodes + .get(&index) + .cloned() + .unwrap_or(self.zero_nodes[self.height - index.ilog2() as usize]) + } + + #[allow(clippy::type_complexity)] + /// Shared logic for both from_memory and finalize. + fn process_layers( + &mut self, + layer: Vec<(u64, [F; CHUNK])>, + md: &MemoryDimensions, + mut rows: Option<&mut Vec>>, + compress: CompressFn, + ) where + CompressFn: Fn(&[F; CHUNK], &[F; CHUNK]) -> [F; CHUNK] + Send + Sync, + { + let mut new_entries = layer; + let mut layer = new_entries + .par_iter() + .map(|(index, values)| { + let old_values = self.nodes.get(index).unwrap_or(&self.zero_nodes[0]); + (*index, *values, *old_values) + }) + .collect::>(); + for height in 1..=self.height { + let new_layer = layer + .iter() + .enumerate() + .filter_map(|(i, (index, values, old_values))| { + if i > 0 && layer[i - 1].0 ^ 1 == *index { + return None; + } + + let par_index = index >> 1; + + if i + 1 < layer.len() && layer[i + 1].0 == index ^ 1 { + let (_, sibling_values, sibling_old_values) = &layer[i + 1]; + Some(( + par_index, + Some((values, old_values)), + Some((sibling_values, sibling_old_values)), + )) + } else if index & 1 == 0 { + Some((par_index, Some((values, old_values)), None)) + } else { + Some((par_index, None, Some((values, old_values)))) + } + }) + .collect::>(); + + match rows { + None => { + layer = new_layer + .into_par_iter() + .map(|(par_index, left, right)| { + let left = if let Some(left) = left { + left.0 + } else { + &self.get_node(2 * par_index) + }; + let right = if let Some(right) = right { + right.0 + } else { + &self.get_node(2 * par_index + 1) + }; + let combined = compress(left, right); + let par_old_values = self.get_node(par_index); + (par_index, combined, par_old_values) + }) + .collect(); + } + Some(ref mut rows) => { + let label_section_height = md.address_height.saturating_sub(height); + let (tmp, new_rows): (Vec<(u64, [F; CHUNK], [F; CHUNK])>, Vec<[_; 2]>) = + new_layer + .into_par_iter() + .map(|(par_index, left, right)| { + let parent_address_label = + (par_index & ((1 << label_section_height) - 1)) as u32; + let parent_as_label = ((par_index & !(1 << (self.height - height))) + >> label_section_height) + as u32; + let left_node; + let (left, old_left, changed_left) = match left { + Some((left, old_left)) => (left, old_left, true), + None => { + left_node = self.get_node(2 * par_index); + (&left_node, &left_node, false) + } + }; + let right_node; + let (right, old_right, changed_right) = match right { + Some((right, old_right)) => (right, old_right, true), + None => { + right_node = self.get_node(2 * par_index + 1); + (&right_node, &right_node, false) + } + }; + let combined = compress(left, right); + // This is a hacky way to say: + // "and we also want to record the old values" + compress(old_left, old_right); + let par_old_values = self.get_node(par_index); + ( + (par_index, combined, par_old_values), + [ + MemoryMerkleCols { + expand_direction: F::ONE, + height_section: F::from_bool( + height > md.address_height, + ), + parent_height: F::from_canonical_usize(height), + is_root: F::from_bool(height == md.overall_height()), + parent_as_label: F::from_canonical_u32(parent_as_label), + parent_address_label: F::from_canonical_u32( + parent_address_label, + ), + parent_hash: par_old_values, + left_child_hash: *old_left, + right_child_hash: *old_right, + left_direction_different: F::ZERO, + right_direction_different: F::ZERO, + }, + MemoryMerkleCols { + expand_direction: F::NEG_ONE, + height_section: F::from_bool( + height > md.address_height, + ), + parent_height: F::from_canonical_usize(height), + is_root: F::from_bool(height == md.overall_height()), + parent_as_label: F::from_canonical_u32(parent_as_label), + parent_address_label: F::from_canonical_u32( + parent_address_label, + ), + parent_hash: combined, + left_child_hash: *left, + right_child_hash: *right, + left_direction_different: F::from_bool(!changed_left), + right_direction_different: F::from_bool(!changed_right), + }, + ], + ) + }) + .unzip(); + rows.extend(new_rows.into_iter().flatten()); + layer = tmp; + } + } + new_entries.extend(layer.iter().map(|(idx, values, _)| (*idx, *values))); + } + + if self.nodes.is_empty() { + // This, for example, should happen in every `from_memory` call + self.nodes = FxHashMap::from_iter(new_entries); + } else { + self.nodes.extend(new_entries); + } + } + + pub fn from_memory( + memory: &AddressMap, + md: &MemoryDimensions, + hasher: &(impl Hasher + Sync), + ) -> Self { + let mut tree = Self::new(md.overall_height(), hasher); + let layer: Vec<_> = memory_to_vec_partition(memory, md) + .par_iter() + .map(|(idx, v)| ((1 << tree.height) + idx, hasher.hash(v))) + .collect(); + tree.process_layers(layer, md, None, |left, right| hasher.compress(left, right)); + tree + } + + pub fn finalize( + &mut self, + hasher: &mut impl HasherChip, + touched: &Equipartition, + md: &MemoryDimensions, + ) -> FinalState { + let init_root = self.get_node(1); + let layer: Vec<_> = if !touched.is_empty() { + touched + .iter() + .map(|((addr_sp, ptr), v)| { + ( + (1 << self.height) + md.label_to_index((*addr_sp, *ptr / CHUNK as u32)), + hasher.hash(v), + ) + }) + .collect() + } else { + let index = 1 << self.height; + vec![(index, self.get_node(index))] + }; + let mut rows = Vec::with_capacity(if layer.is_empty() { + 0 + } else { + layer + .iter() + .zip(layer.iter().skip(1)) + .fold(md.overall_height(), |acc, ((lhs, _), (rhs, _))| { + acc + (lhs ^ rhs).ilog2() as usize + }) + }); + self.process_layers(layer, md, Some(&mut rows), |left, right| { + hasher.compress_and_record(left, right) + }); + if touched.is_empty() { + // If we made an artificial touch, we need to change the direction changes for the + // leaves + rows[1].left_direction_different = F::ONE; + rows[1].right_direction_different = F::ONE; + } + let final_root = self.get_node(1); + FinalState { + rows, + init_root, + final_root, + } + } +} diff --git a/crates/vm/src/system/memory/mod.rs b/crates/vm/src/system/memory/mod.rs index ac6a7d85cf..a776d99a73 100644 --- a/crates/vm/src/system/memory/mod.rs +++ b/crates/vm/src/system/memory/mod.rs @@ -1,21 +1,20 @@ use openvm_circuit_primitives_derive::AlignedBorrow; -mod adapter; +pub mod adapter; mod controller; pub mod merkle; -mod offline; pub mod offline_checker; pub mod online; -pub mod paged_vec; -mod persistent; -#[cfg(test)] -mod tests; -pub mod tree; -mod volatile; +pub mod persistent; +// TODO: add back +// #[cfg(test)] +// mod tests; +pub mod volatile; pub use controller::*; -pub use offline::*; -pub use paged_vec::*; +pub use online::{Address, AddressMap, INITIAL_TIMESTAMP}; + +pub const POINTER_MAX_BITS: usize = 29; #[derive(PartialEq, Copy, Clone, Debug, Eq)] pub enum OpType { diff --git a/crates/vm/src/system/memory/offline.rs b/crates/vm/src/system/memory/offline.rs deleted file mode 100644 index 74bb238811..0000000000 --- a/crates/vm/src/system/memory/offline.rs +++ /dev/null @@ -1,1070 +0,0 @@ -use std::{array, cmp::max}; - -use openvm_circuit_primitives::{ - assert_less_than::AssertLtSubAir, var_range::SharedVariableRangeCheckerChip, -}; -use openvm_stark_backend::p3_field::PrimeField32; -use rustc_hash::FxHashSet; - -use super::{AddressMap, PagedVec, PAGE_SIZE}; -use crate::{ - arch::MemoryConfig, - system::memory::{ - adapter::{AccessAdapterInventory, AccessAdapterRecord, AccessAdapterRecordKind}, - offline_checker::{MemoryBridge, MemoryBus}, - MemoryAuxColsFactory, MemoryImage, RecordId, TimestampedEquipartition, TimestampedValues, - }, -}; - -pub const INITIAL_TIMESTAMP: u32 = 0; - -#[repr(C)] -#[derive(Clone, Default, PartialEq, Eq, Debug)] -struct BlockData { - pointer: u32, - timestamp: u32, - size: usize, -} - -struct BlockMap { - /// Block ids. 0 is a special value standing for the default block. - id: AddressMap, - /// The place where non-default blocks are stored. - storage: Vec, - initial_block_size: usize, -} - -impl BlockMap { - pub fn from_mem_config(mem_config: &MemoryConfig, initial_block_size: usize) -> Self { - assert!(initial_block_size.is_power_of_two()); - Self { - id: AddressMap::from_mem_config(mem_config), - storage: vec![], - initial_block_size, - } - } - - fn initial_block_data(pointer: u32, initial_block_size: usize) -> BlockData { - let aligned_pointer = (pointer / initial_block_size as u32) * initial_block_size as u32; - BlockData { - pointer: aligned_pointer, - size: initial_block_size, - timestamp: INITIAL_TIMESTAMP, - } - } - - pub fn get_without_adding(&self, address: &(u32, u32)) -> BlockData { - let idx = self.id.get(address).unwrap_or(&0); - if idx == &0 { - Self::initial_block_data(address.1, self.initial_block_size) - } else { - self.storage[idx - 1].clone() - } - } - - pub fn get(&mut self, address: &(u32, u32)) -> &BlockData { - let (address_space, pointer) = *address; - let idx = self.id.get(&(address_space, pointer)).unwrap_or(&0); - if idx == &0 { - // `initial_block_size` is a power of two, as asserted in `from_mem_config`. - let pointer = pointer & !(self.initial_block_size as u32 - 1); - self.set_range( - &(address_space, pointer), - self.initial_block_size, - Self::initial_block_data(pointer, self.initial_block_size), - ); - self.storage.last().unwrap() - } else { - &self.storage[idx - 1] - } - } - - pub fn get_mut(&mut self, address: &(u32, u32)) -> &mut BlockData { - let (address_space, pointer) = *address; - let idx = self.id.get(&(address_space, pointer)).unwrap_or(&0); - if idx == &0 { - let pointer = pointer - pointer % self.initial_block_size as u32; - self.set_range( - &(address_space, pointer), - self.initial_block_size, - Self::initial_block_data(pointer, self.initial_block_size), - ); - self.storage.last_mut().unwrap() - } else { - &mut self.storage[idx - 1] - } - } - - pub fn set_range(&mut self, address: &(u32, u32), len: usize, block: BlockData) { - let (address_space, pointer) = address; - self.storage.push(block); - for i in 0..len { - self.id - .insert(&(*address_space, pointer + i as u32), self.storage.len()); - } - } - - pub fn items(&self) -> impl Iterator + '_ { - self.id - .items() - .filter(|(_, idx)| *idx > 0) - .map(|(address, idx)| (address, &self.storage[idx - 1])) - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct MemoryRecord { - pub address_space: T, - pub pointer: T, - pub timestamp: u32, - pub prev_timestamp: u32, - data: Vec, - /// None if a read. - prev_data: Option>, -} - -impl MemoryRecord { - pub fn data_slice(&self) -> &[T] { - self.data.as_slice() - } - - pub fn prev_data_slice(&self) -> Option<&[T]> { - self.prev_data.as_deref() - } -} - -impl MemoryRecord { - pub fn data_at(&self, index: usize) -> T { - self.data[index] - } -} - -pub struct OfflineMemory { - block_data: BlockMap, - data: Vec>, - as_offset: u32, - timestamp: u32, - timestamp_max_bits: usize, - - memory_bus: MemoryBus, - range_checker: SharedVariableRangeCheckerChip, - - log: Vec>>, -} - -impl OfflineMemory { - /// Creates a new partition with the given initial block size. - /// - /// Panics if the initial block size is not a power of two. - pub fn new( - initial_memory: MemoryImage, - initial_block_size: usize, - memory_bus: MemoryBus, - range_checker: SharedVariableRangeCheckerChip, - config: MemoryConfig, - ) -> Self { - assert_eq!(initial_memory.as_offset, config.as_offset); - Self { - block_data: BlockMap::from_mem_config(&config, initial_block_size), - data: initial_memory.paged_vecs, - as_offset: config.as_offset, - timestamp: INITIAL_TIMESTAMP + 1, - timestamp_max_bits: config.clk_max_bits, - memory_bus, - range_checker, - log: vec![], - } - } - - pub fn set_initial_memory(&mut self, initial_memory: MemoryImage, config: MemoryConfig) { - assert_eq!(self.timestamp, INITIAL_TIMESTAMP + 1); - assert_eq!(initial_memory.as_offset, config.as_offset); - self.as_offset = config.as_offset; - self.data = initial_memory.paged_vecs; - } - - pub(super) fn set_log_capacity(&mut self, access_capacity: usize) { - assert!(self.log.is_empty()); - self.log = Vec::with_capacity(access_capacity); - } - - pub fn memory_bridge(&self) -> MemoryBridge { - MemoryBridge::new( - self.memory_bus, - self.timestamp_max_bits, - self.range_checker.bus(), - ) - } - - pub fn timestamp(&self) -> u32 { - self.timestamp - } - - /// Increments the current timestamp by one and returns the new value. - pub fn increment_timestamp(&mut self) { - self.increment_timestamp_by(1) - } - - /// Increments the current timestamp by a specified delta and returns the new value. - pub fn increment_timestamp_by(&mut self, delta: u32) { - self.log.push(None); - self.timestamp += delta; - } - - /// Writes an array of values to the memory at the specified address space and start index. - pub fn write( - &mut self, - address_space: u32, - pointer: u32, - values: Vec, - records: &mut AccessAdapterInventory, - ) { - let len = values.len(); - assert!(len.is_power_of_two()); - assert_ne!(address_space, 0); - - let prev_timestamp = self.access_updating_timestamp(address_space, pointer, len, records); - - debug_assert!(prev_timestamp < self.timestamp); - - let pointer = pointer as usize; - let prev_data = self.data[(address_space - self.as_offset) as usize] - .set_range(pointer..pointer + len, &values); - - let record = MemoryRecord { - address_space: F::from_canonical_u32(address_space), - pointer: F::from_canonical_usize(pointer), - timestamp: self.timestamp, - prev_timestamp, - data: values, - prev_data: Some(prev_data), - }; - self.log.push(Some(record)); - self.timestamp += 1; - } - - /// Reads an array of values from the memory at the specified address space and start index. - pub fn read( - &mut self, - address_space: u32, - pointer: u32, - len: usize, - adapter_records: &mut AccessAdapterInventory, - ) { - assert!(len.is_power_of_two()); - if address_space == 0 { - let pointer = F::from_canonical_u32(pointer); - self.log.push(Some(MemoryRecord { - address_space: F::ZERO, - pointer, - timestamp: self.timestamp, - prev_timestamp: 0, - data: vec![pointer], - prev_data: None, - })); - self.timestamp += 1; - return; - } - - let prev_timestamp = - self.access_updating_timestamp(address_space, pointer, len, adapter_records); - - debug_assert!(prev_timestamp < self.timestamp); - - let values = self.range_vec(address_space, pointer, len); - - self.log.push(Some(MemoryRecord { - address_space: F::from_canonical_u32(address_space), - pointer: F::from_canonical_u32(pointer), - timestamp: self.timestamp, - prev_timestamp, - data: values, - prev_data: None, - })); - self.timestamp += 1; - } - - pub fn record_by_id(&self, id: RecordId) -> &MemoryRecord { - self.log[id.0].as_ref().unwrap() - } - - pub fn finalize( - &mut self, - adapter_records: &mut AccessAdapterInventory, - ) -> TimestampedEquipartition { - // First make sure the partition we maintain in self.block_data is an equipartition. - // Grab all aligned pointers that need to be re-accessed. - let to_access: FxHashSet<_> = self - .block_data - .items() - .map(|((address_space, pointer), _)| (address_space, (pointer / N as u32) * N as u32)) - .collect(); - - for &(address_space, pointer) in to_access.iter() { - let block = self.block_data.get(&(address_space, pointer)); - if block.pointer != pointer || block.size != N { - self.access(address_space, pointer, N, adapter_records); - } - } - - let mut equipartition = TimestampedEquipartition::::new(); - for (address_space, pointer) in to_access { - let block = self.block_data.get(&(address_space, pointer)); - - debug_assert_eq!(block.pointer % N as u32, 0); - debug_assert_eq!(block.size, N); - - equipartition.insert( - (address_space, pointer / N as u32), - TimestampedValues { - timestamp: block.timestamp, - values: self.range_array::(address_space, pointer), - }, - ); - } - equipartition - } - - // Modifies the partition to ensure that there is a block starting at (address_space, query). - fn split_to_make_boundary( - &mut self, - address_space: u32, - query: u32, - records: &mut AccessAdapterInventory, - ) { - let lim = (self.data[(address_space - self.as_offset) as usize].memory_size()) as u32; - if query == lim { - return; - } - assert!(query < lim); - let original_block = self.block_containing(address_space, query); - if original_block.pointer == query { - return; - } - - let data = self.range_vec(address_space, original_block.pointer, original_block.size); - - let timestamp = original_block.timestamp; - - let mut cur_ptr = original_block.pointer; - let mut cur_size = original_block.size; - while cur_size > 0 { - // Split. - records.add_record(AccessAdapterRecord { - timestamp, - address_space: F::from_canonical_u32(address_space), - start_index: F::from_canonical_u32(cur_ptr), - data: data[(cur_ptr - original_block.pointer) as usize - ..(cur_ptr - original_block.pointer) as usize + cur_size] - .to_vec(), - kind: AccessAdapterRecordKind::Split, - }); - - let half_size = cur_size / 2; - let half_size_u32 = half_size as u32; - let mid_ptr = cur_ptr + half_size_u32; - - if query <= mid_ptr { - // The right is finalized; add it to the partition. - let block = BlockData { - pointer: mid_ptr, - size: half_size, - timestamp, - }; - self.block_data - .set_range(&(address_space, mid_ptr), half_size, block); - } - if query >= cur_ptr + half_size_u32 { - // The left is finalized; add it to the partition. - let block = BlockData { - pointer: cur_ptr, - size: half_size, - timestamp, - }; - self.block_data - .set_range(&(address_space, cur_ptr), half_size, block); - } - if mid_ptr <= query { - cur_ptr = mid_ptr; - } - if cur_ptr == query { - break; - } - cur_size = half_size; - } - } - - fn access_updating_timestamp( - &mut self, - address_space: u32, - pointer: u32, - size: usize, - records: &mut AccessAdapterInventory, - ) -> u32 { - self.access(address_space, pointer, size, records); - - let mut prev_timestamp = None; - - let mut i = 0; - while i < size as u32 { - let block = self.block_data.get_mut(&(address_space, pointer + i)); - debug_assert!(i == 0 || prev_timestamp == Some(block.timestamp)); - prev_timestamp = Some(block.timestamp); - block.timestamp = self.timestamp; - i = block.pointer + block.size as u32; - } - prev_timestamp.unwrap() - } - - fn access( - &mut self, - address_space: u32, - pointer: u32, - size: usize, - records: &mut AccessAdapterInventory, - ) { - self.split_to_make_boundary(address_space, pointer, records); - self.split_to_make_boundary(address_space, pointer + size as u32, records); - - let block_data = self.block_containing(address_space, pointer); - - if block_data.pointer == pointer && block_data.size == size { - return; - } - assert!(size > 1); - - // Now recursively access left and right blocks to ensure they are in the partition. - let half_size = size / 2; - self.access(address_space, pointer, half_size, records); - self.access( - address_space, - pointer + half_size as u32, - half_size, - records, - ); - - self.merge_block_with_next(address_space, pointer, records); - } - - /// Merges the two adjacent blocks starting at (address_space, pointer). - /// - /// Panics if there is no block starting at (address_space, pointer) or if the two blocks - /// do not have the same size. - fn merge_block_with_next( - &mut self, - address_space: u32, - pointer: u32, - records: &mut AccessAdapterInventory, - ) { - let left_block = self.block_data.get(&(address_space, pointer)); - - let left_timestamp = left_block.timestamp; - let size = left_block.size; - - let right_timestamp = self - .block_data - .get(&(address_space, pointer + size as u32)) - .timestamp; - - let timestamp = max(left_timestamp, right_timestamp); - self.block_data.set_range( - &(address_space, pointer), - 2 * size, - BlockData { - pointer, - size: 2 * size, - timestamp, - }, - ); - records.add_record(AccessAdapterRecord { - timestamp, - address_space: F::from_canonical_u32(address_space), - start_index: F::from_canonical_u32(pointer), - data: self.range_vec(address_space, pointer, 2 * size), - kind: AccessAdapterRecordKind::Merge { - left_timestamp, - right_timestamp, - }, - }); - } - - fn block_containing(&mut self, address_space: u32, pointer: u32) -> BlockData { - self.block_data - .get_without_adding(&(address_space, pointer)) - } - - pub fn get(&self, address_space: u32, pointer: u32) -> F { - self.data[(address_space - self.as_offset) as usize] - .get(pointer as usize) - .cloned() - .unwrap_or_default() - } - - fn range_array(&self, address_space: u32, pointer: u32) -> [F; N] { - array::from_fn(|i| self.get(address_space, pointer + i as u32)) - } - - fn range_vec(&self, address_space: u32, pointer: u32, len: usize) -> Vec { - let pointer = pointer as usize; - self.data[(address_space - self.as_offset) as usize].range_vec(pointer..pointer + len) - } - - pub fn aux_cols_factory(&self) -> MemoryAuxColsFactory { - let range_bus = self.range_checker.bus(); - MemoryAuxColsFactory { - range_checker: self.range_checker.clone(), - timestamp_lt_air: AssertLtSubAir::new(range_bus, self.timestamp_max_bits), - _marker: Default::default(), - } - } - - // just for unit testing - #[cfg(test)] - fn last_record(&self) -> &MemoryRecord { - self.log.last().unwrap().as_ref().unwrap() - } -} - -#[cfg(test)] -mod tests { - use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, - }; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - - use super::{BlockData, MemoryRecord, OfflineMemory}; - use crate::{ - arch::MemoryConfig, - system::memory::{ - adapter::{AccessAdapterInventory, AccessAdapterRecord, AccessAdapterRecordKind}, - offline_checker::MemoryBus, - paged_vec::AddressMap, - MemoryImage, TimestampedValues, - }, - }; - - macro_rules! bb { - ($x:expr) => { - BabyBear::from_canonical_u32($x) - }; - } - - macro_rules! bba { - [$($x:expr),*] => { - [$(BabyBear::from_canonical_u32($x)),*] - } - } - - macro_rules! bbvec { - [$($x:expr),*] => { - vec![$(BabyBear::from_canonical_u32($x)),*] - } - } - - fn setup_test( - initial_memory: MemoryImage, - initial_block_size: usize, - ) -> (OfflineMemory, AccessAdapterInventory) { - let memory_bus = MemoryBus::new(0); - let range_checker = - SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)); - let mem_config = MemoryConfig { - as_offset: initial_memory.as_offset, - ..Default::default() - }; - let memory = OfflineMemory::new( - initial_memory, - initial_block_size, - memory_bus, - range_checker.clone(), - mem_config, - ); - let access_adapter_inventory = AccessAdapterInventory::new( - range_checker, - memory_bus, - mem_config.clk_max_bits, - mem_config.max_access_adapter_n, - ); - (memory, access_adapter_inventory) - } - - #[test] - fn test_partition() { - let initial_memory = AddressMap::new(0, 1, 16); - let (mut memory, _) = setup_test(initial_memory, 8); - assert_eq!( - memory.block_containing(1, 13), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 8), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 15), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 16), - BlockData { - pointer: 16, - size: 8, - timestamp: 0, - } - ); - } - - #[test] - fn test_write_read_initial_block_len_1() { - let (mut memory, mut access_adapters) = setup_test(MemoryImage::default(), 1); - let address_space = 1; - - memory.write(address_space, 0, bbvec![1, 2, 3, 4], &mut access_adapters); - - memory.read(address_space, 0, 2, &mut access_adapters); - let read_record = memory.last_record(); - assert_eq!(read_record.data, bba![1, 2]); - - memory.write(address_space, 2, bbvec![100], &mut access_adapters); - - memory.read(address_space, 0, 4, &mut access_adapters); - let read_record = memory.last_record(); - assert_eq!(read_record.data, bba![1, 2, 100, 4]); - } - - #[test] - fn test_records_initial_block_len_1() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 1); - - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - - // Above write first causes merge of [0:1] and [1:2] into [0:2]. - assert_eq!( - adapter_records.records_for_n(2)[0], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // then merge [2:3] and [3:4] into [2:4]. - assert_eq!( - adapter_records.records_for_n(2)[1], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(2), - data: bbvec![0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // then merge [0:2] and [2:4] into [0:4]. - assert_eq!( - adapter_records.records_for_n(4)[0], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0, 0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // At time 1 we write [0:4]. - let write_record = memory.last_record(); - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 1, - prev_timestamp: 0, - data: bbvec![1, 2, 3, 4], - prev_data: Some(bbvec![0, 0, 0, 0]), - } - ); - assert_eq!(memory.timestamp(), 2); - assert_eq!(adapter_records.total_records(), 3); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - // At time 2 we read [0:4]. - assert_eq!(adapter_records.total_records(), 3); - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 2, - prev_timestamp: 1, - data: bbvec![1, 2, 3, 4], - prev_data: None, - } - ); - assert_eq!(memory.timestamp(), 3); - - memory.write(1, 0, bbvec![10, 11], &mut adapter_records); - let write_record = memory.last_record(); - // write causes split [0:4] into [0:2] and [2:4] (to prepare for write to [0:2]). - assert_eq!(adapter_records.total_records(), 4); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 2, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![1, 2, 3, 4], - kind: AccessAdapterRecordKind::Split, - } - ); - - // At time 3 we write [10, 11] into [0, 2]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 3, - prev_timestamp: 2, - data: bbvec![10, 11], - prev_data: Some(bbvec![1, 2]), - } - ); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - assert_eq!(adapter_records.total_records(), 5); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 3, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![10, 11, 3, 4], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 3, - right_timestamp: 2 - }, - } - ); - // At time 9 we read [0:4]. - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 4, - prev_timestamp: 3, - data: bbvec![10, 11, 3, 4], - prev_data: None, - } - ); - } - - #[test] - fn test_records_initial_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - let write_record = memory.last_record(); - - // Above write first causes split of [0:8] into [0:4] and [4:8]. - assert_eq!(adapter_records.total_records(), 1); - assert_eq!( - adapter_records.records_for_n(8).last().unwrap(), - &AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0, 0, 0, 0, 0, 0, 0], - kind: AccessAdapterRecordKind::Split, - } - ); - // At time 1 we write [0:4]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 1, - prev_timestamp: 0, - data: bbvec![1, 2, 3, 4], - prev_data: Some(bbvec![0, 0, 0, 0]), - } - ); - assert_eq!(memory.timestamp(), 2); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - // At time 2 we read [0:4]. - assert_eq!(adapter_records.total_records(), 1); - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 2, - prev_timestamp: 1, - data: bbvec![1, 2, 3, 4], - prev_data: None, - } - ); - assert_eq!(memory.timestamp(), 3); - - memory.write(1, 0, bbvec![10, 11], &mut adapter_records); - let write_record = memory.last_record(); - // write causes split [0:4] into [0:2] and [2:4] (to prepare for write to [0:2]). - assert_eq!(adapter_records.total_records(), 2); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 2, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![1, 2, 3, 4], - kind: AccessAdapterRecordKind::Split, - } - ); - - // At time 3 we write [10, 11] into [0, 2]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 3, - prev_timestamp: 2, - data: bbvec![10, 11], - prev_data: Some(bbvec![1, 2]), - } - ); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - assert_eq!(adapter_records.total_records(), 3); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 3, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![10, 11, 3, 4], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 3, - right_timestamp: 2 - }, - } - ); - // At time 9 we read [0:4]. - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 4, - prev_timestamp: 3, - data: bbvec![10, 11, 3, 4], - prev_data: None, - } - ); - } - - #[test] - fn test_get_initial_block_len_1() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 1); - - memory.write(2, 0, bbvec![4, 3, 2, 1], &mut adapter_records); - - assert_eq!(memory.get(2, 0), BabyBear::from_canonical_u32(4)); - assert_eq!(memory.get(2, 1), BabyBear::from_canonical_u32(3)); - assert_eq!(memory.get(2, 2), BabyBear::from_canonical_u32(2)); - assert_eq!(memory.get(2, 3), BabyBear::from_canonical_u32(1)); - assert_eq!(memory.get(2, 5), BabyBear::ZERO); - - assert_eq!(memory.get(1, 0), BabyBear::ZERO); - } - - #[test] - fn test_get_initial_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - - memory.write(2, 0, bbvec![4, 3, 2, 1], &mut adapter_records); - - assert_eq!(memory.get(2, 0), BabyBear::from_canonical_u32(4)); - assert_eq!(memory.get(2, 1), BabyBear::from_canonical_u32(3)); - assert_eq!(memory.get(2, 2), BabyBear::from_canonical_u32(2)); - assert_eq!(memory.get(2, 3), BabyBear::from_canonical_u32(1)); - assert_eq!(memory.get(2, 5), BabyBear::ZERO); - assert_eq!(memory.get(2, 9), BabyBear::ZERO); - assert_eq!(memory.get(1, 0), BabyBear::ZERO); - } - - #[test] - fn test_finalize_empty() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 4); - - let memory = memory.finalize::<4>(&mut adapter_records); - assert_eq!(memory.len(), 0); - assert_eq!(adapter_records.total_records(), 0); - } - - #[test] - fn test_finalize_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - // Make block 0:4 in address space 1 active. - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - - // Make block 16:32 in address space 1 active. - memory.write( - 1, - 16, - bbvec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - &mut adapter_records, - ); - - // Make block 64:72 in address space 2 active. - memory.write(2, 64, bbvec![8, 7, 6, 5, 4, 3, 2, 1], &mut adapter_records); - - let num_records_before_finalize = adapter_records.total_records(); - - // Finalize to a partition of size 8. - let final_memory = memory.finalize::<8>(&mut adapter_records); - assert_eq!(final_memory.len(), 4); - assert_eq!( - final_memory.get(&(1, 0)), - Some(&TimestampedValues { - values: bba![1, 2, 3, 4, 0, 0, 0, 0], - timestamp: 1, - }) - ); - // start_index = 16 corresponds to label = 2 - assert_eq!( - final_memory.get(&(1, 2)), - Some(&TimestampedValues { - values: bba![1, 1, 1, 1, 1, 1, 1, 1], - timestamp: 2, - }) - ); - // start_index = 24 corresponds to label = 3 - assert_eq!( - final_memory.get(&(1, 3)), - Some(&TimestampedValues { - values: bba![1, 1, 1, 1, 1, 1, 1, 1], - timestamp: 2, - }) - ); - // start_index = 64 corresponds to label = 8 - assert_eq!( - final_memory.get(&(2, 8)), - Some(&TimestampedValues { - values: bba![8, 7, 6, 5, 4, 3, 2, 1], - timestamp: 3, - }) - ); - - // We need to do 1 + 1 + 0 = 2 adapters. - assert_eq!( - adapter_records.total_records() - num_records_before_finalize, - 2 - ); - } - - #[test] - fn test_write_read_initial_block_len_8_initial_memory() { - type F = BabyBear; - - // Initialize initial memory with blocks at indices 0 and 2 - let mut initial_memory = MemoryImage::default(); - for i in 0..8 { - initial_memory.insert(&(1, i), F::from_canonical_u32(i + 1)); - initial_memory.insert(&(1, 16 + i), F::from_canonical_u32(i + 1)); - } - - let (mut memory, mut adapter_records) = setup_test(initial_memory, 8); - - // Verify initial state of block 0 (pointers 0–8) - memory.read(1, 0, 8, &mut adapter_records); - let initial_read_record_0 = memory.last_record(); - assert_eq!(initial_read_record_0.data, bbvec![1, 2, 3, 4, 5, 6, 7, 8]); - - // Verify initial state of block 2 (pointers 16–24) - memory.read(1, 16, 8, &mut adapter_records); - let initial_read_record_2 = memory.last_record(); - assert_eq!(initial_read_record_2.data, bbvec![1, 2, 3, 4, 5, 6, 7, 8]); - - // Test: Write a partial block to block 0 (pointer 0) and read back partially and fully - memory.write(1, 0, bbvec![9, 9, 9, 9], &mut adapter_records); - memory.read(1, 0, 2, &mut adapter_records); - let partial_read_record = memory.last_record(); - assert_eq!(partial_read_record.data, bbvec![9, 9]); - - memory.read(1, 0, 8, &mut adapter_records); - let full_read_record_0 = memory.last_record(); - assert_eq!(full_read_record_0.data, bbvec![9, 9, 9, 9, 5, 6, 7, 8]); - - // Test: Write a single element to pointer 2 and verify read in different lengths - memory.write(1, 2, bbvec![100], &mut adapter_records); - memory.read(1, 1, 4, &mut adapter_records); - let read_record_4 = memory.last_record(); - assert_eq!(read_record_4.data, bbvec![9, 100, 9, 5]); - - memory.read(1, 2, 8, &mut adapter_records); - let full_read_record_2 = memory.last_record(); - assert_eq!(full_read_record_2.data, bba![100, 9, 5, 6, 7, 8, 0, 0]); - - // Test: Write and read at the last pointer in block 2 (pointer 23, part of key (1, 2)) - memory.write(1, 23, bbvec![77], &mut adapter_records); - memory.read(1, 23, 2, &mut adapter_records); - let boundary_read_record = memory.last_record(); - assert_eq!(boundary_read_record.data, bba![77, 0]); // Last byte modified, ensuring boundary check - - // Test: Reading from an uninitialized block (should default to 0) - memory.read(1, 10, 4, &mut adapter_records); - let default_read_record = memory.last_record(); - assert_eq!(default_read_record.data, bba![0, 0, 0, 0]); - - memory.read(1, 100, 4, &mut adapter_records); - let default_read_record = memory.last_record(); - assert_eq!(default_read_record.data, bba![0, 0, 0, 0]); - - // Test: Overwrite entire memory pointer 16–24 and verify - memory.write( - 1, - 16, - bbvec![50, 50, 50, 50, 50, 50, 50, 50], - &mut adapter_records, - ); - memory.read(1, 16, 8, &mut adapter_records); - let overwrite_read_record = memory.last_record(); - assert_eq!( - overwrite_read_record.data, - bba![50, 50, 50, 50, 50, 50, 50, 50] - ); // Verify entire block overwrite - } -} diff --git a/crates/vm/src/system/memory/offline_checker/bridge.rs b/crates/vm/src/system/memory/offline_checker/bridge.rs index 2c7e180cfb..3174309454 100644 --- a/crates/vm/src/system/memory/offline_checker/bridge.rs +++ b/crates/vm/src/system/memory/offline_checker/bridge.rs @@ -21,7 +21,7 @@ use crate::system::memory::{ /// be decomposed into) for the `AssertLtSubAir` in the `MemoryOfflineChecker`. /// Warning: This requires that (clk_max_bits + decomp - 1) / decomp = AUX_LEN /// in MemoryOfflineChecker (or whenever AssertLtSubAir is used) -pub(crate) const AUX_LEN: usize = 2; +pub const AUX_LEN: usize = 2; /// The [MemoryBridge] is used within AIR evaluation functions to constrain logical memory /// operations (read/write). It adds all necessary constraints and interactions. diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index 5a27b3e433..630774c639 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -9,37 +9,27 @@ use crate::system::memory::offline_checker::bridge::AUX_LEN; // repr(C) is needed to make sure that the compiler does not reorder the fields // we assume the order of the fields when using borrow or borrow_mut -#[repr(C)] /// Base structure for auxiliary memory columns. +#[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryBaseAuxCols { /// The previous timestamps in which the cells were accessed. - pub(in crate::system::memory) prev_timestamp: T, + pub prev_timestamp: T, /// The auxiliary columns to perform the less than check. - pub(in crate::system::memory) timestamp_lt_aux: LessThanAuxCols, + pub timestamp_lt_aux: LessThanAuxCols, +} + +impl MemoryBaseAuxCols { + pub fn set_prev(&mut self, prev_timestamp: F) { + self.prev_timestamp = prev_timestamp; + } } #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryWriteAuxCols { - pub(in crate::system::memory) base: MemoryBaseAuxCols, - pub(in crate::system::memory) prev_data: [T; N], -} - -impl MemoryWriteAuxCols { - pub(in crate::system::memory) fn new( - prev_data: [T; N], - prev_timestamp: T, - lt_aux: LessThanAuxCols, - ) -> Self { - Self { - base: MemoryBaseAuxCols { - prev_timestamp, - timestamp_lt_aux: lt_aux, - }, - prev_data, - } - } + pub base: MemoryBaseAuxCols, + pub prev_data: [T; N], } impl MemoryWriteAuxCols { @@ -54,6 +44,11 @@ impl MemoryWriteAuxCols { pub fn prev_data(&self) -> &[T; N] { &self.prev_data } + + /// Sets the previous data **without** updating the less than auxiliary columns. + pub fn set_prev_data(&mut self, data: [T; N]) { + self.prev_data = data; + } } /// The auxiliary columns for a memory read operation with block size `N`. @@ -67,10 +62,7 @@ pub struct MemoryReadAuxCols { } impl MemoryReadAuxCols { - pub(in crate::system::memory) fn new( - prev_timestamp: u32, - timestamp_lt_aux: LessThanAuxCols, - ) -> Self { + pub fn new(prev_timestamp: u32, timestamp_lt_aux: LessThanAuxCols) -> Self { Self { base: MemoryBaseAuxCols { prev_timestamp: F::from_canonical_u32(prev_timestamp), @@ -82,14 +74,19 @@ impl MemoryReadAuxCols { pub fn get_base(self) -> MemoryBaseAuxCols { self.base } + + /// Sets the previous timestamp **without** updating the less than auxiliary columns. + pub fn set_prev(&mut self, timestamp: F) { + self.base.prev_timestamp = timestamp; + } } #[repr(C)] #[derive(Clone, Debug, AlignedBorrow)] pub struct MemoryReadOrImmediateAuxCols { - pub(crate) base: MemoryBaseAuxCols, - pub(crate) is_immediate: T, - pub(crate) is_zero_aux: T, + pub base: MemoryBaseAuxCols, + pub is_immediate: T, + pub is_zero_aux: T, } impl AsRef> for MemoryWriteAuxCols { @@ -102,3 +99,21 @@ impl AsRef> for MemoryWriteAuxCols unsafe { &*(self as *const MemoryWriteAuxCols as *const MemoryReadAuxCols) } } } + +impl AsMut> for MemoryWriteAuxCols { + fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { + &mut self.base + } +} + +impl AsMut> for MemoryReadAuxCols { + fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { + &mut self.base + } +} + +impl AsMut> for MemoryReadOrImmediateAuxCols { + fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { + &mut self.base + } +} diff --git a/crates/vm/src/system/memory/offline_checker/mod.rs b/crates/vm/src/system/memory/offline_checker/mod.rs index ac9f32dc18..8b15328185 100644 --- a/crates/vm/src/system/memory/offline_checker/mod.rs +++ b/crates/vm/src/system/memory/offline_checker/mod.rs @@ -5,3 +5,18 @@ mod columns; pub use bridge::*; pub use bus::*; pub use columns::*; + +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MemoryReadAuxRecord { + pub prev_timestamp: u32, +} + +#[repr(C)] +#[derive(Debug, Clone)] +pub struct MemoryWriteAuxRecord { + pub prev_timestamp: u32, + pub prev_data: [T; NUM_LIMBS], +} + +pub type MemoryWriteBytesAuxRecord = MemoryWriteAuxRecord; diff --git a/crates/vm/src/system/memory/online.rs b/crates/vm/src/system/memory/online.rs index a5bf663e4c..08d38e5383 100644 --- a/crates/vm/src/system/memory/online.rs +++ b/crates/vm/src/system/memory/online.rs @@ -1,151 +1,841 @@ -use std::fmt::Debug; +use std::{fmt::Debug, slice::from_raw_parts}; -use openvm_stark_backend::p3_field::PrimeField32; -use serde::{Deserialize, Serialize}; +use getset::Getters; +use itertools::{izip, zip_eq}; +use openvm_circuit_primitives::var_range::SharedVariableRangeCheckerChip; +use openvm_instructions::{exe::SparseMemoryImage, NATIVE_AS}; +use openvm_stark_backend::{ + p3_field::PrimeField32, p3_maybe_rayon::prelude::*, p3_util::log2_strict_usize, +}; -use super::paged_vec::{AddressMap, PAGE_SIZE}; +use super::{adapter::AccessAdapterInventory, offline_checker::MemoryBus}; use crate::{ arch::MemoryConfig, - system::memory::{offline::INITIAL_TIMESTAMP, MemoryImage, RecordId}, + system::memory::{ + adapter::records::{AccessLayout, AccessRecordHeader, MERGE_AND_NOT_SPLIT_FLAG}, + MemoryImage, + }, }; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum MemoryLogEntry { - Read { - address_space: u32, - pointer: u32, +mod basic; +#[cfg(any(unix, windows))] +mod memmap; +mod paged_vec; + +#[cfg(not(any(unix, windows)))] +pub use basic::*; +#[cfg(any(unix, windows))] +pub use memmap::*; +pub use paged_vec::PagedVec; + +#[cfg(all(any(unix, windows), not(feature = "basic-memory")))] +pub type MemoryBackend = memmap::MmapMemory; +#[cfg(any(not(any(unix, windows)), feature = "basic-memory"))] +pub type MemoryBackend = basic::BasicMemory; + +pub const INITIAL_TIMESTAMP: u32 = 0; +/// Default mmap page size. Change this if using THB. +pub const PAGE_SIZE: usize = 4096; + +/// (address_space, pointer) +pub type Address = (u32, u32); + +/// API for any memory implementation that allocates a contiguous region of memory. +pub trait LinearMemory { + /// Create instance of `Self` with `size` bytes. + fn new(size: usize) -> Self; + /// Allocated size of the memory in bytes. + fn size(&self) -> usize; + /// Returns the entire memory as a raw byte slice. + fn as_slice(&self) -> &[u8]; + /// Returns the entire memory as a raw byte slice. + fn as_mut_slice(&mut self) -> &mut [u8]; + /// Read `BLOCK` from `self` at `from` address without moving it. + /// + /// Panics or segfaults if `from..from + size_of::()` is out of bounds. + /// + /// # Safety + /// - `BLOCK` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - See [`core::ptr::read`] for similar considerations. + /// - Memory at `from` must be properly aligned for `BLOCK`. Use [`Self::read_unaligned`] if + /// alignment is not guaranteed. + unsafe fn read(&self, from: usize) -> BLOCK; + /// Read `BLOCK` from `self` at `from` address without moving it. + /// Same as [`Self::read`] except that it does not require alignment. + /// + /// Panics or segfaults if `from..from + size_of::()` is out of bounds. + /// + /// # Safety + /// - `BLOCK` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - See [`core::ptr::read`] for similar considerations. + unsafe fn read_unaligned(&self, from: usize) -> BLOCK; + /// Write `BLOCK` to `self` at `start` address without reading the old value. Does not drop + /// `values`. Semantically, `values` is moved into the location pointed to by `start`. + /// + /// Panics or segfaults if `start..start + size_of::()` is out of bounds. + /// + /// # Safety + /// - See [`core::ptr::write`] for similar considerations. + /// - Memory at `start` must be properly aligned for `BLOCK`. Use [`Self::write_unaligned`] if + /// alignment is not guaranteed. + unsafe fn write(&mut self, start: usize, values: BLOCK); + /// Write `BLOCK` to `self` at `start` address without reading the old value. Does not drop + /// `values`. Semantically, `values` is moved into the location pointed to by `start`. + /// Same as [`Self::write`] but without alignment requirement. + /// + /// Panics or segfaults if `start..start + size_of::()` is out of bounds. + /// + /// # Safety + /// - See [`core::ptr::write`] for similar considerations. + unsafe fn write_unaligned(&mut self, start: usize, values: BLOCK); + /// Swaps `values` with memory at `start..start + size_of::()`. + /// + /// Panics or segfaults if `start..start + size_of::()` is out of bounds. + /// + /// # Safety + /// - `BLOCK` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - Memory at `start` must be properly aligned for `BLOCK`. + /// - The data in `values` should not overlap with memory in `self`. + unsafe fn swap(&mut self, start: usize, values: &mut BLOCK); + /// Copies `data` into memory at `to` address. + /// + /// Panics or segfaults if `to..to + size_of_val(data)` is out of bounds. + /// + /// # Safety + /// - `T` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - The underlying memory of `data` should not overlap with `self`. + /// - The starting pointer of `self` should be aligned to `T`. + /// - The memory pointer at `to` should be aligned to `T`. + unsafe fn copy_nonoverlapping(&mut self, to: usize, data: &[T]); + /// Returns a slice `&[T]` for the memory region `start..start + len`. + /// + /// Panics or segfaults if `start..start + len * size_of::()` is out of bounds. + /// + /// # Safety + /// - `T` should be "plain old data" (see [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html)). + /// We do not add a trait bound due to Plonky3 types not implementing the trait. + /// - Memory at `start` must be properly aligned for `T`. + unsafe fn get_aligned_slice(&self, start: usize, len: usize) -> &[T]; +} + +/// Map from address space to linear memory. +/// The underlying memory is typeless, stored as raw bytes, but usage implicitly assumes that each +/// address space has memory cells of a fixed type (e.g., `u8, F`). We do not use a typemap for +/// performance reasons, and it is up to the user to enforce types. Needless to say, this is a very +/// `unsafe` API. +#[derive(Debug, Clone)] +pub struct AddressMap { + pub mem: Vec, + /// byte size of cells per address space + pub cell_size: Vec, // TODO: move to MmapWrapper +} + +impl Default for AddressMap { + fn default() -> Self { + Self::from_mem_config(&MemoryConfig::default()) + } +} + +impl AddressMap { + /// `mem_size` is the number of **cells** in each address space. It is required that + /// `mem_size[0] = 0`. + pub fn new(mem_size: Vec) -> Self { + // TMP: hardcoding for now + let mut cell_size = vec![1; 4]; + cell_size.resize(mem_size.len(), 4); + let mem = zip_eq(&cell_size, &mem_size) + .map(|(cell_size, mem_size)| M::new(mem_size.checked_mul(*cell_size).unwrap())) + .collect(); + Self { mem, cell_size } + } + + pub fn from_mem_config(mem_config: &MemoryConfig) -> Self { + Self::new(mem_config.addr_space_sizes.clone()) + } + + #[inline(always)] + pub fn get_memory(&self) -> &Vec { + &self.mem + } + + #[inline(always)] + pub fn get_memory_mut(&mut self) -> &mut Vec { + &mut self.mem + } + + pub fn get_f(&self, addr_space: u32, ptr: u32) -> F { + debug_assert_ne!(addr_space, 0); + // TODO: fix this + unsafe { + if self.cell_size[addr_space as usize] == 1 { + F::from_canonical_u8(self.get::((addr_space, ptr))) + } else { + debug_assert_eq!(self.cell_size[addr_space as usize], 4); + self.get::((addr_space, ptr)) + } + } + } + + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn get(&self, (addr_space, ptr): Address) -> T { + debug_assert_eq!(size_of::(), self.cell_size[addr_space as usize]); + // SAFETY: + // - alignment is automatic since we multiply by `size_of::()` + self.mem + .get_unchecked(addr_space as usize) + .read((ptr as usize) * size_of::()) + } + + /// Panics or segfaults if `ptr..ptr + len` is out of bounds + /// + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn get_slice( + &self, + (addr_space, ptr): Address, len: usize, - }, - Write { - address_space: u32, - pointer: u32, - data: Vec, - }, - IncrementTimestampBy(u32), + ) -> &[T] { + debug_assert_eq!(size_of::(), self.cell_size[addr_space as usize]); + let start = (ptr as usize) * size_of::(); + let mem = self.mem.get_unchecked(addr_space as usize); + // SAFETY: + // - alignment is automatic since we multiply by `size_of::()` + mem.get_aligned_slice(start, len) + } + + /// Panics or segfaults if `ptr..ptr + len` is out of bounds + /// + /// # Safety + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn get_u8_slice(&self, addr_space: usize, start: usize, len: usize) -> &[u8] { + let mem = self.mem.get_unchecked(addr_space); + mem.get_aligned_slice(start, len) + } + + /// Copies `data` into the memory at `(addr_space, ptr)`. + /// + /// Panics or segfaults if `ptr + size_of_val(data)` is out of bounds. + /// + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - The linear memory in `addr_space` is aligned to `T`. + pub unsafe fn copy_slice_nonoverlapping( + &mut self, + (addr_space, ptr): Address, + data: &[T], + ) { + let start = (ptr as usize) * size_of::(); + // SAFETY: + // - Linear memory is aligned to `T` and `start` is multiple of `size_of::()` so + // alignment is satisfied. + // - `data` and `self.mem` are non-overlapping + self.mem + .get_unchecked_mut(addr_space as usize) + .copy_nonoverlapping(start, data); + } + + // TODO[jpw]: stabilize the boundary memory image format and how to construct + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub fn from_sparse(mem_size: Vec, sparse_map: SparseMemoryImage) -> Self { + let mut vec = Self::new(mem_size); + for ((addr_space, index), data_byte) in sparse_map.into_iter() { + // SAFETY: + // - safety assumptions in function doc comments + unsafe { + vec.mem + .get_unchecked_mut(addr_space as usize) + .write_unaligned(index as usize, data_byte); + } + } + vec + } +} + +/// API for guest memory conforming to OpenVM ISA +// @dev Note we don't make this a trait because phantom executors currently need a concrete type for +// guest memory +#[derive(Debug, Clone)] +pub struct GuestMemory { + pub memory: AddressMap, +} + +impl GuestMemory { + pub fn new(addr: AddressMap) -> Self { + Self { memory: addr } + } + /// Returns `[pointer:BLOCK_SIZE]_{address_space}` + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + pub unsafe fn read( + &self, + addr_space: u32, + ptr: u32, + ) -> [T; BLOCK_SIZE] + where + T: Copy + Debug, + { + debug_assert_eq!(size_of::(), self.memory.cell_size[addr_space as usize]); + // SAFETY: + // - `T` should be "plain old data" + // - alignment for `[T; BLOCK_SIZE]` is automatic since we multiply by `size_of::()` + self.memory + .get_memory() + .get_unchecked(addr_space as usize) + .read((ptr as usize) * size_of::()) + } + + /// Writes `values` to `[pointer:BLOCK_SIZE]_{address_space}` + /// + /// # Safety + /// See [`GuestMemory::read`]. + pub unsafe fn write( + &mut self, + addr_space: u32, + ptr: u32, + values: [T; BLOCK_SIZE], + ) where + T: Copy + Debug, + { + debug_assert_eq!(size_of::(), self.memory.cell_size[addr_space as usize]); + // SAFETY: + // - alignment for `[T; BLOCK_SIZE]` is automatic since we multiply by `size_of::()` + self.memory + .get_memory_mut() + .get_unchecked_mut(addr_space as usize) + .write((ptr as usize) * size_of::(), values); + } + + /// Swaps `values` with `[pointer:BLOCK_SIZE]_{address_space}`. + /// + /// # Safety + /// See [`GuestMemory::read`] and [`LinearMemory::swap`]. + #[inline(always)] + pub unsafe fn swap( + &mut self, + addr_space: u32, + ptr: u32, + values: &mut [T; BLOCK_SIZE], + ) where + T: Copy + Debug, + { + debug_assert_eq!(size_of::(), self.memory.cell_size[addr_space as usize]); + // SAFETY: + // - alignment for `[T; BLOCK_SIZE]` is automatic since we multiply by `size_of::()` + self.memory + .get_memory_mut() + .get_unchecked_mut(addr_space as usize) + .swap((ptr as usize) * size_of::(), values); + } + + #[inline(always)] + #[allow(clippy::missing_safety_doc)] + pub unsafe fn get_slice(&self, addr_space: u32, ptr: u32, len: usize) -> &[T] { + self.memory.get_slice((addr_space, ptr), len) + } +} + +// perf[jpw]: since we restrict `timestamp < 2^29`, we could pack `timestamp, log2(block_size)` +// into a single u32 to save some memory, since `block_size` is a power of 2 and its log2 +// is less than 2^3. +#[repr(C)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, derive_new::new)] +pub struct AccessMetadata { + /// The starting pointer of the access + pub start_ptr: u32, + /// The block size of the memory access + pub block_size: u32, + /// The timestamp of the last access. + /// We don't _have_ to store it, but this is probably faster + /// in terms of cache locality + pub timestamp: u32, } -/// A simple data structure to read to/write from memory. -/// -/// Stores a log of memory accesses to reconstruct aspects of memory state for trace generation. -#[derive(Debug)] -pub struct Memory { - pub(super) data: AddressMap, - pub(super) log: Vec>, - timestamp: u32, +/// Online memory that stores additional information for trace generation purposes. +/// In particular, keeps track of timestamp. +#[derive(Getters)] +pub struct TracingMemory { + pub timestamp: u32, + /// The initial block size -- this depends on the type of boundary chip. + initial_block_size: usize, + /// The underlying data memory, with memory cells typed by address space: see [AddressMap]. + // TODO: make generic in GuestMemory + #[getset(get = "pub")] + pub data: GuestMemory, + /// A map of `addr_space -> (ptr / min_block_size[addr_space] -> (timestamp: u32, block_size: + /// u32))` for the timestamp and block size of the latest access. Each + /// `PagedVec` stores metadata in a paged manner for memory efficiency. + pub(super) meta: Vec>, + /// For each `addr_space`, the minimum block size allowed for memory accesses. In other words, + /// all memory accesses in `addr_space` must be aligned to this block size. + pub min_block_size: Vec, + pub access_adapter_inventory: AccessAdapterInventory, } -impl Memory { - pub fn new(mem_config: &MemoryConfig) -> Self { +impl TracingMemory { + // TODO: per-address space memory capacity specification + pub fn new( + mem_config: &MemoryConfig, + range_checker: SharedVariableRangeCheckerChip, + memory_bus: MemoryBus, + initial_block_size: usize, + ) -> Self { + let num_cells = mem_config.addr_space_sizes.clone(); + let num_addr_sp = 1 + (1 << mem_config.addr_space_height); + let mut min_block_size = vec![1; num_addr_sp]; + // TMP: hardcoding for now + min_block_size[1] = 4; + min_block_size[2] = 4; + min_block_size[3] = 4; + let meta = zip_eq(&min_block_size, &num_cells) + .map(|(min_block_size, num_cells)| { + let total_metadata_len = num_cells.div_ceil(*min_block_size as usize); + PagedVec::new(total_metadata_len, PAGE_SIZE) + }) + .collect(); Self { - data: AddressMap::from_mem_config(mem_config), + data: GuestMemory::new(AddressMap::from_mem_config(mem_config)), + meta, + min_block_size, timestamp: INITIAL_TIMESTAMP + 1, - log: Vec::with_capacity(mem_config.access_capacity), + initial_block_size, + access_adapter_inventory: AccessAdapterInventory::new( + range_checker, + memory_bus, + mem_config.clk_max_bits, + mem_config.max_access_adapter_n, + ), } } /// Instantiates a new `Memory` data structure from an image. - pub fn from_image(image: MemoryImage, access_capacity: usize) -> Self { - Self { - data: image, - timestamp: INITIAL_TIMESTAMP + 1, - log: Vec::with_capacity(access_capacity), + pub fn with_image(mut self, image: MemoryImage) -> Self { + for (i, (mem, cell_size)) in izip!(image.get_memory(), &image.cell_size).enumerate() { + let num_cells = mem.size() / cell_size; + + let total_metadata_len = num_cells.div_ceil(self.min_block_size[i] as usize); + self.meta[i] = PagedVec::new(total_metadata_len, PAGE_SIZE); } + self.data = GuestMemory::new(image); + self } - fn last_record_id(&self) -> RecordId { - RecordId(self.log.len() - 1) + #[inline(always)] + fn assert_alignment(&self, block_size: usize, align: usize, addr_space: u32, ptr: u32) { + debug_assert!(block_size.is_power_of_two()); + debug_assert_eq!(block_size % align, 0); + debug_assert_ne!(addr_space, 0); + debug_assert_eq!(align as u32, self.min_block_size[addr_space as usize]); + assert_eq!( + ptr % (align as u32), + 0, + "pointer={ptr} not aligned to {align}" + ); } - /// Writes an array of values to the memory at the specified address space and start index. - /// - /// Returns the `RecordId` for the memory record and the previous data. - pub fn write( + /// Updates the metadata with the given block. + #[inline] + fn set_meta_block( &mut self, - address_space: u32, - pointer: u32, - values: [F; N], - ) -> (RecordId, [F; N]) { - assert!(N.is_power_of_two()); - - let prev_data = self.data.set_range(&(address_space, pointer), &values); + address_space: usize, + pointer: usize, + align: usize, + block_size: usize, + timestamp: u32, + ) { + let ptr = pointer / align; + // SAFETY: address_space is assumed to be valid and within bounds + let meta = unsafe { self.meta.get_unchecked_mut(address_space) }; + for i in 0..(block_size / align) { + meta.set( + ptr + i, + AccessMetadata { + start_ptr: pointer as u32, + block_size: block_size as u32, + timestamp, + }, + ); + } + } - self.log.push(MemoryLogEntry::Write { - address_space, - pointer, - data: values.to_vec(), - }); - self.timestamp += 1; + pub(crate) fn add_split_record(&mut self, header: AccessRecordHeader) { + if header.block_size == header.lowest_block_size { + return; + } + let data_slice = unsafe { + self.data.memory.get_u8_slice( + header.address_space as usize, + (header.pointer * header.type_size) as usize, + (header.block_size * header.type_size) as usize, + ) + }; - (self.last_record_id(), prev_data) + let record_mut = self + .access_adapter_inventory + .alloc_record(AccessLayout::from_record_header(&header)); + *record_mut.header = header; + record_mut.data.copy_from_slice(data_slice); + // we don't mind garbage values in prev_* } - /// Reads an array of values from the memory at the specified address space and start index. - pub fn read(&mut self, address_space: u32, pointer: u32) -> (RecordId, [F; N]) { - assert!(N.is_power_of_two()); + pub(crate) fn add_merge_record( + &mut self, + header: AccessRecordHeader, + data: &[T], + prev_ts: &[u32], + ) { + if header.block_size == header.lowest_block_size { + return; + } + + let data_slice = + unsafe { from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) }; + + let record_mut = self + .access_adapter_inventory + .alloc_record(AccessLayout::from_record_header(&header)); + *record_mut.header = header; + record_mut.header.timestamp_and_mask |= MERGE_AND_NOT_SPLIT_FLAG; + record_mut.data.copy_from_slice(data_slice); + record_mut.timestamps.copy_from_slice(prev_ts); + } - self.log.push(MemoryLogEntry::Read { - address_space, - pointer, - len: N, + fn split_by_meta( + &mut self, + meta: &AccessMetadata, + address_space: usize, + lowest_block_size: usize, + ) { + if meta.block_size == lowest_block_size as u32 { + return; + } + let begin = meta.start_ptr as usize / lowest_block_size; + for i in 0..(meta.block_size as usize / lowest_block_size) { + self.meta[address_space].set( + begin + i, + AccessMetadata { + start_ptr: (meta.start_ptr + (i * lowest_block_size) as u32), + block_size: lowest_block_size as u32, + timestamp: meta.timestamp, + }, + ); + } + self.add_split_record(AccessRecordHeader { + timestamp_and_mask: meta.timestamp, + address_space: address_space as u32, + pointer: meta.start_ptr, + block_size: meta.block_size, + lowest_block_size: lowest_block_size as u32, + type_size: size_of::() as u32, }); + } + + /// Returns the timestamp of the previous access to `[pointer:BLOCK_SIZE]_{address_space}` + /// and the offset of the record in bytes. + /// + /// Caller must ensure alignment (e.g. via `assert_alignment`) prior to calling this function. + fn prev_access_time( + &mut self, + address_space: usize, + pointer: usize, + align: usize, + prev_values: &[T; BLOCK_SIZE], + ) -> u32 { + let num_segs = BLOCK_SIZE / align; - let values = if address_space == 0 { - assert_eq!(N, 1, "cannot batch read from address space 0"); - [F::from_canonical_u32(pointer); N] + let begin = pointer / align; + + let first_meta = self.meta[address_space].get(begin); + let need_to_merge = + first_meta.block_size != BLOCK_SIZE as u32 || first_meta.start_ptr != pointer as u32; + let result = if need_to_merge { + // Then we need to split everything we touched there + // And add a merge record in the end + let mut i = 0; + while i < num_segs { + let meta = self.meta[address_space].get(begin + i); + if meta.block_size == 0 { + i += 1; + continue; + } + let meta = *meta; + self.split_by_meta::(&meta, address_space, align); + i = (meta.start_ptr + meta.block_size) as usize / align - begin; + } + + let prev_ts = (0..num_segs) + .map(|i| { + let meta = self.meta[address_space].get(begin + i); + if meta.block_size > 0 { + meta.timestamp + } else { + // Initialize + if self.initial_block_size >= align { + // We need to split the initial block into chunks + let block_start = (begin + i) & !(self.initial_block_size / align - 1); + self.split_by_meta::( + &AccessMetadata { + start_ptr: (block_start * align) as u32, + block_size: self.initial_block_size as u32, + timestamp: INITIAL_TIMESTAMP, + }, + address_space, + align, + ); + } else { + debug_assert_eq!(self.initial_block_size, 1); + debug_assert!((address_space as u32) < NATIVE_AS); // TODO: normal way + self.add_merge_record::( + AccessRecordHeader { + timestamp_and_mask: INITIAL_TIMESTAMP, + address_space: address_space as u32, + pointer: (pointer + i * align) as u32, + block_size: align as u32, + lowest_block_size: self.initial_block_size as u32, + type_size: 1, + }, + &vec![0; align], // TODO: not vec maybe + &vec![INITIAL_TIMESTAMP; align], // TODO: not vec maybe + ); + } + INITIAL_TIMESTAMP + } + }) + .collect::>(); // TODO(AG): small buffer or small vec or something + + let timestamp = *prev_ts.iter().max().unwrap(); + self.add_merge_record( + AccessRecordHeader { + timestamp_and_mask: timestamp, + address_space: address_space as u32, + pointer: pointer as u32, + block_size: BLOCK_SIZE as u32, + lowest_block_size: align as u32, + type_size: size_of::() as u32, + }, + prev_values, + &prev_ts, + ); + timestamp } else { - self.range_array::(address_space, pointer) + first_meta.timestamp }; + self.set_meta_block(address_space, pointer, align, BLOCK_SIZE, self.timestamp); + result + } + + /// Atomic read operation which increments the timestamp by 1. + /// Returns `(t_prev, [pointer:BLOCK_SIZE]_{address_space})` where `t_prev` is the + /// timestamp of the last memory access. + /// + /// The previous memory access is treated as atomic even if previous accesses were for + /// a smaller block size. This is made possible by internal memory access adapters + /// that split/merge memory blocks. More specifically, the last memory access corresponding + /// to `t_prev` may refer to an atomic access inserted by the memory access adapters. + /// + /// # Assumptions + /// The `BLOCK_SIZE` is a multiple of `ALIGN`, which must equal the minimum block size + /// of `address_space`. + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + /// + /// In addition: + /// - `address_space` must be valid. + #[inline(always)] + pub unsafe fn read( + &mut self, + address_space: u32, + pointer: u32, + ) -> (u32, [T; BLOCK_SIZE]) + where + T: Copy + Debug, + { + self.assert_alignment(BLOCK_SIZE, ALIGN, address_space, pointer); + let values = self.data.read(address_space, pointer); + let t_prev = self.prev_access_time::( + address_space as usize, + pointer as usize, + ALIGN, + &values, + ); + self.timestamp += 1; + + (t_prev, values) + } + + /// Atomic write operation that writes `values` into `[pointer:BLOCK_SIZE]_{address_space}` and + /// then increments the timestamp by 1. Returns `(t_prev, values_prev)` which equal the + /// timestamp and value `[pointer:BLOCK_SIZE]_{address_space}` of the last memory access. + /// + /// The previous memory access is treated as atomic even if previous accesses were for + /// a smaller block size. This is made possible by internal memory access adapters + /// that split/merge memory blocks. More specifically, the last memory access corresponding + /// to `t_prev` may refer to an atomic access inserted by the memory access adapters. + /// + /// # Assumptions + /// The `BLOCK_SIZE` is a multiple of `ALIGN`, which must equal the minimum block size + /// of `address_space`. + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + /// + /// In addition: + /// - `address_space` must be valid. + #[inline(always)] + pub unsafe fn write( + &mut self, + address_space: u32, + pointer: u32, + values: [T; BLOCK_SIZE], + ) -> (u32, [T; BLOCK_SIZE]) + where + T: Copy + Debug, + { + self.assert_alignment(BLOCK_SIZE, ALIGN, address_space, pointer); + let values_prev = self.data.read(address_space, pointer); + let t_prev = self.prev_access_time::( + address_space as usize, + pointer as usize, + ALIGN, + &values_prev, + ); + self.data.write(address_space, pointer, values); + self.timestamp += 1; + + (t_prev, values_prev) + } + + pub fn increment_timestamp(&mut self) { self.timestamp += 1; - (self.last_record_id(), values) } pub fn increment_timestamp_by(&mut self, amount: u32) { self.timestamp += amount; - self.log.push(MemoryLogEntry::IncrementTimestampBy(amount)) } pub fn timestamp(&self) -> u32 { self.timestamp } - #[inline(always)] - pub fn get(&self, address_space: u32, pointer: u32) -> F { - *self.data.get(&(address_space, pointer)).unwrap_or(&F::ZERO) + /// Returns the list of all touched blocks. The list is sorted by address. + pub fn touched_blocks(&self) -> Vec<(Address, AccessMetadata)> { + assert_eq!(self.meta.len(), self.min_block_size.len()); + self.meta + .par_iter() + .zip(self.min_block_size.par_iter()) + .enumerate() + .flat_map(|(addr_space, (page, &align))| { + page.par_iter() + .filter_map(move |(idx, metadata)| { + let ptr = idx as u32 * align; + if ptr == metadata.start_ptr && metadata.block_size != 0 { + Some(((addr_space as u32, ptr), metadata)) + } else { + None + } + }) + .collect::>() + }) + .collect() } - - #[inline(always)] - fn range_array(&self, address_space: u32, pointer: u32) -> [F; N] { - self.data.get_range(&(address_space, pointer)) + pub fn address_space_alignment(&self) -> Vec { + self.min_block_size + .iter() + .map(|&x| log2_strict_usize(x as usize) as u8) + .collect() } } #[cfg(test)] mod tests { + use std::array; + use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; + use openvm_stark_sdk::utils::create_seeded_rng; + use p3_baby_bear::BabyBear; + use rand::Rng; - use super::Memory; - use crate::arch::MemoryConfig; + use crate::arch::{testing::VmChipTestBuilder, MemoryConfig}; - macro_rules! bba { - [$($x:expr),*] => { - [$(BabyBear::from_canonical_u32($x)),*] - } - } + type F = BabyBear; - #[test] - fn test_write_read() { - let mut memory = Memory::new(&MemoryConfig::default()); - let address_space = 1; + fn test_memory_write_by_tester(mut tester: VmChipTestBuilder) { + let mut rng = create_seeded_rng(); - memory.write(address_space, 0, bba![1, 2, 3, 4]); + // The point here is to have a lot of equal + // and intersecting/overlapping blocks, + // by limiting the space of valid pointers. + let max_ptr = 20; + let aligns = [4, 4, 4, 1]; + let value_bounds = [256, 256, 256, (1 << 30)]; + let max_log_block_size = 4; + let its = 1000; + for _ in 0..its { + let addr_sp = rng.gen_range(1..=aligns.len()); + let align: usize = aligns[addr_sp - 1]; + let value_bound: u32 = value_bounds[addr_sp - 1]; + let ptr = rng.gen_range(0..max_ptr / align) * align; + let log_len = rng.gen_range(align.trailing_zeros()..=max_log_block_size); + match log_len { + 0 => tester.write::<1>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + 1 => tester.write::<2>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + 2 => tester.write::<4>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + 3 => tester.write::<8>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + 4 => tester.write::<16>( + addr_sp, + ptr, + array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..value_bound))), + ), + _ => unreachable!(), + } + } - let (_, data) = memory.read::<2>(address_space, 0); - assert_eq!(data, bba![1, 2]); + let tester = tester.build().finalize(); + tester.simple_test().expect("Verification failed"); + } - memory.write(address_space, 2, bba![100]); + #[test] + fn test_memory_write_volatile() { + test_memory_write_by_tester(VmChipTestBuilder::::volatile(MemoryConfig::default())); + } - let (_, data) = memory.read::<4>(address_space, 0); - assert_eq!(data, bba![1, 2, 100, 4]); + #[test] + fn test_memory_write_persistent() { + test_memory_write_by_tester(VmChipTestBuilder::::persistent(MemoryConfig::default())); } } diff --git a/crates/vm/src/system/memory/online/basic.rs b/crates/vm/src/system/memory/online/basic.rs new file mode 100644 index 0000000000..b5cddeb775 --- /dev/null +++ b/crates/vm/src/system/memory/online/basic.rs @@ -0,0 +1,243 @@ +use std::{ + alloc::{alloc_zeroed, dealloc, Layout}, + ptr::NonNull, +}; + +use crate::system::memory::online::{LinearMemory, PAGE_SIZE}; + +pub struct BasicMemory { + ptr: NonNull, + size: usize, + layout: Layout, +} + +impl BasicMemory { + #[inline(always)] + pub fn as_ptr(&self) -> *const u8 { + self.ptr.as_ptr() + } + + #[inline(always)] + pub fn as_mut_ptr(&mut self) -> *mut u8 { + self.ptr.as_ptr() + } +} + +impl Drop for BasicMemory { + fn drop(&mut self) { + if self.size > 0 { + unsafe { + dealloc(self.ptr.as_ptr(), self.layout); + } + } + } +} + +impl Clone for BasicMemory { + fn clone(&self) -> Self { + if self.size == 0 { + // Ensure we maintain the same aligned pointer for zero-size + let aligned_ptr = PAGE_SIZE as *mut u8; + let ptr = unsafe { NonNull::new_unchecked(aligned_ptr) }; + return Self { + ptr, + size: 0, + layout: self.layout, + }; + } + + let layout = self.layout; + let ptr = unsafe { + let new_ptr = alloc_zeroed(layout); + if new_ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + std::ptr::copy_nonoverlapping(self.ptr.as_ptr(), new_ptr, self.size); + NonNull::new_unchecked(new_ptr) + }; + Self { + ptr, + size: self.size, + layout, + } + } +} + +impl std::fmt::Debug for BasicMemory { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BasicMemory") + .field("size", &self.size) + .field("alignment", &self.layout.align()) + .finish() + } +} + +impl LinearMemory for BasicMemory { + fn new(size: usize) -> Self { + if size == 0 { + // For zero-size allocation, use a dangling pointer with proper alignment + // We need to ensure the pointer is aligned to PAGE_SIZE + let aligned_ptr = PAGE_SIZE as *mut u8; + let ptr = unsafe { NonNull::new_unchecked(aligned_ptr) }; + let layout = Layout::from_size_align(0, PAGE_SIZE) + .expect("Failed to create layout with PAGE_SIZE alignment"); + return Self { + ptr, + size: 0, + layout, + }; + } + + // Use PAGE_SIZE alignment for consistency with MmapMemory + // This also ensures good alignment for any type we might store + let layout = Layout::from_size_align(size, PAGE_SIZE) + .expect("Failed to create layout with PAGE_SIZE alignment"); + + let ptr = unsafe { + let raw_ptr = alloc_zeroed(layout); + if raw_ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + NonNull::new_unchecked(raw_ptr) + }; + + Self { ptr, size, layout } + } + + fn size(&self) -> usize { + self.size + } + + fn as_slice(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) } + } + + fn as_mut_slice(&mut self) -> &mut [u8] { + unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) } + } + + #[inline(always)] + unsafe fn read(&self, from: usize) -> BLOCK { + let size = std::mem::size_of::(); + assert!( + from + size <= self.size, + "read from={from} of size={size} out of bounds: memory size={}", + self.size + ); + + let src = self.as_ptr().add(from) as *const BLOCK; + // SAFETY: + // - Bounds check is done via assert above + // - We assume `src` is aligned to `BLOCK` + // - We assume `BLOCK` is "plain old data" so the underlying `src` bytes is valid to read as + // an initialized value of `BLOCK` + core::ptr::read(src) + } + + #[inline(always)] + unsafe fn read_unaligned(&self, from: usize) -> BLOCK { + let size = std::mem::size_of::(); + assert!( + from + size <= self.size, + "read_unaligned from={from} of size={size} out of bounds: memory size={}", + self.size + ); + + let src = self.as_ptr().add(from) as *const BLOCK; + // SAFETY: + // - Bounds check is done via assert above + // - We assume `BLOCK` is "plain old data" so the underlying `src` bytes is valid to read as + // an initialized value of `BLOCK` + core::ptr::read_unaligned(src) + } + + #[inline(always)] + unsafe fn write(&mut self, start: usize, values: BLOCK) { + let size = std::mem::size_of::(); + assert!( + start + size <= self.size, + "write start={start} of size={size} out of bounds: memory size={}", + self.size + ); + + let dst = self.as_mut_ptr().add(start) as *mut BLOCK; + // SAFETY: + // - Bounds check is done via assert above + // - We assume `dst` is aligned to `BLOCK` + core::ptr::write(dst, values); + } + + #[inline(always)] + unsafe fn write_unaligned(&mut self, start: usize, values: BLOCK) { + let size = std::mem::size_of::(); + assert!( + start + size <= self.size, + "write_unaligned start={start} of size={size} out of bounds: memory size={}", + self.size + ); + + // Use slice's copy_from_slice for safe byte-level copy + let src_bytes = std::slice::from_raw_parts(&values as *const BLOCK as *const u8, size); + self.as_mut_slice()[start..start + size].copy_from_slice(src_bytes); + } + + #[inline(always)] + unsafe fn swap(&mut self, start: usize, values: &mut BLOCK) { + let size = std::mem::size_of::(); + assert!( + start + size <= self.size, + "swap start={start} of size={size} out of bounds: memory size={}", + self.size + ); + + // SAFETY: + // - Bounds check is done via assert above + // - We assume `start` is aligned to `BLOCK` + core::ptr::swap( + self.as_mut_ptr().add(start) as *mut BLOCK, + values as *mut BLOCK, + ); + } + + #[inline(always)] + unsafe fn copy_nonoverlapping(&mut self, to: usize, data: &[T]) { + let byte_len = std::mem::size_of_val(data); + assert!( + to + byte_len <= self.size, + "copy_nonoverlapping to={to} of size={byte_len} out of bounds: memory size={}", + self.size + ); + + // Use slice's copy_from_slice for safe byte-level copy + let src_bytes = std::slice::from_raw_parts(data.as_ptr() as *const u8, byte_len); + self.as_mut_slice()[to..to + byte_len].copy_from_slice(src_bytes); + } + + #[inline(always)] + unsafe fn get_aligned_slice(&self, start: usize, len: usize) -> &[T] { + let byte_len = len * std::mem::size_of::(); + assert!( + start + byte_len <= self.size, + "get_aligned_slice start={start} of size={byte_len} out of bounds: memory size={}", + self.size + ); + assert!( + start % std::mem::align_of::() == 0, + "get_aligned_slice: misaligned start" + ); + + let data = self.as_ptr().add(start) as *const T; + // SAFETY: + // - Bounds check is done via assert above + // - Alignment check is done via assert above + // - `T` is "plain old data" (POD), so conversion from underlying bytes is properly + // initialized + // - `self` will not be mutated while borrowed + core::slice::from_raw_parts(data, len) + } +} + +// SAFETY: BasicMemory properly manages its allocation and can be sent between threads +unsafe impl Send for BasicMemory {} +// SAFETY: BasicMemory has no interior mutability and can be shared between threads +unsafe impl Sync for BasicMemory {} diff --git a/crates/vm/src/system/memory/online/memmap.rs b/crates/vm/src/system/memory/online/memmap.rs new file mode 100644 index 0000000000..3b2155906a --- /dev/null +++ b/crates/vm/src/system/memory/online/memmap.rs @@ -0,0 +1,173 @@ +use std::fmt::Debug; + +use memmap2::MmapMut; + +use super::{LinearMemory, PAGE_SIZE}; + +pub const CELL_STRIDE: usize = 1; + +/// Mmap-backed linear memory. OS-memory pages are paged in on-demand and zero-initialized. +#[derive(Debug)] +pub struct MmapMemory { + mmap: MmapMut, +} + +impl Clone for MmapMemory { + fn clone(&self) -> Self { + let mut new_mmap = MmapMut::map_anon(self.mmap.len()).unwrap(); + new_mmap.copy_from_slice(&self.mmap); + Self { mmap: new_mmap } + } +} + +impl MmapMemory { + #[inline(always)] + pub fn as_ptr(&self) -> *const u8 { + self.mmap.as_ptr() + } + + #[inline(always)] + pub fn as_mut_ptr(&mut self) -> *mut u8 { + self.mmap.as_mut_ptr() + } +} + +impl LinearMemory for MmapMemory { + /// Create a new MmapMemory with the given `size` in bytes. + /// We round `size` up to be a multiple of the mmap page size (4kb by default) so that OS-level + /// MMU protection corresponds to out of bounds protection. + fn new(mut size: usize) -> Self { + size = size.div_ceil(PAGE_SIZE) * PAGE_SIZE; + // anonymous mapping means pages are zero-initialized on first use + Self { + mmap: MmapMut::map_anon(size).unwrap(), + } + } + + fn size(&self) -> usize { + self.mmap.len() + } + + fn as_slice(&self) -> &[u8] { + &self.mmap + } + + fn as_mut_slice(&mut self) -> &mut [u8] { + &mut self.mmap + } + + #[inline(always)] + unsafe fn read(&self, from: usize) -> BLOCK { + debug_assert!( + from + size_of::() <= self.size(), + "read from={from} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + let src = self.as_ptr().add(from) as *const BLOCK; + // SAFETY: + // - MMU will segfault if `src` access is out of bounds. + // - We assume `src` is aligned to `BLOCK` + // - We assume `BLOCK` is "plain old data" so the underlying `src` bytes is valid to read as + // an initialized value of `BLOCK` + core::ptr::read(src) + } + + #[inline(always)] + unsafe fn read_unaligned(&self, from: usize) -> BLOCK { + debug_assert!( + from + size_of::() <= self.size(), + "read_unaligned from={from} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + let src = self.as_ptr().add(from) as *const BLOCK; + // SAFETY: + // - MMU will segfault if `src` access is out of bounds. + // - We assume `BLOCK` is "plain old data" so the underlying `src` bytes is valid to read as + // an initialized value of `BLOCK` + core::ptr::read_unaligned(src) + } + + #[inline(always)] + unsafe fn write(&mut self, start: usize, values: BLOCK) { + debug_assert!( + start + size_of::() <= self.size(), + "write start={start} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + let dst = self.as_mut_ptr().add(start) as *mut BLOCK; + // SAFETY: + // - MMU will segfault if `dst` access is out of bounds. + // - We assume `dst` is aligned to `BLOCK` + core::ptr::write(dst, values); + } + + #[inline(always)] + unsafe fn write_unaligned(&mut self, start: usize, values: BLOCK) { + debug_assert!( + start + size_of::() <= self.size(), + "write_unaligned start={start} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + let dst = self.as_mut_ptr().add(start) as *mut BLOCK; + // SAFETY: + // - MMU will segfault if `dst` access is out of bounds. + core::ptr::write_unaligned(dst, values); + } + + #[inline(always)] + unsafe fn swap(&mut self, start: usize, values: &mut BLOCK) { + debug_assert!( + start + size_of::() <= self.size(), + "swap start={start} of size={} out of bounds: memory size={}", + size_of::(), + self.size() + ); + // SAFETY: + // - MMU will segfault if `start` access is out of bounds. + // - We assume `start` is aligned to `BLOCK` + core::ptr::swap( + self.as_mut_ptr().add(start) as *mut BLOCK, + values as *mut BLOCK, + ); + } + + #[inline(always)] + unsafe fn copy_nonoverlapping(&mut self, to: usize, data: &[T]) { + debug_assert!( + to + size_of_val(data) <= self.size(), + "copy_nonoverlapping to={to} of size={} out of bounds: memory size={}", + size_of_val(data), + self.size() + ); + debug_assert_eq!(PAGE_SIZE % align_of::(), 0); + let src = data.as_ptr(); + let dst = self.as_mut_ptr().add(to) as *mut T; + // SAFETY: + // - MMU will segfault if `dst..dst + size_of_val(data)` is out of bounds. + // - Assumes `to` is aligned to `T` and `self.as_mut_ptr()` is aligned to `T`, which implies + // the same for `dst`. + core::ptr::copy_nonoverlapping::(src, dst, data.len()); + } + + #[inline(always)] + unsafe fn get_aligned_slice(&self, start: usize, len: usize) -> &[T] { + debug_assert!( + start + len * size_of::() <= self.size(), + "get_aligned_slice start={start} of size={} out of bounds: memory size={}", + len * size_of::(), + self.size() + ); + let data = self.as_ptr().add(start) as *const T; + // SAFETY: + // - MMU will segfault if `data..data + len * size_of::()` is out of bounds. + // - Assumes `data` is aligned to `T` + // - `T` is "plain old data" (POD), so conversion from underlying bytes is properly + // initialized + // - `self` will not be mutated while borrowed + core::slice::from_raw_parts(data, len) + } +} diff --git a/crates/vm/src/system/memory/online/paged_vec.rs b/crates/vm/src/system/memory/online/paged_vec.rs new file mode 100644 index 0000000000..a60bb8eae4 --- /dev/null +++ b/crates/vm/src/system/memory/online/paged_vec.rs @@ -0,0 +1,115 @@ +use std::fmt::Debug; + +use openvm_stark_backend::p3_maybe_rayon::prelude::*; + +#[derive(Debug, Clone)] +pub struct PagedVec { + pages: Vec>>, + page_size: usize, +} + +unsafe impl Send for PagedVec {} +unsafe impl Sync for PagedVec {} + +impl PagedVec { + #[inline] + /// `total_size` is the capacity of elements of type `T`. + pub fn new(total_size: usize, page_size: usize) -> Self { + let num_pages = total_size.div_ceil(page_size); + Self { + pages: vec![None; num_pages], + page_size, + } + } + + /// Panics if the index is out of bounds. Creates a new page with default values if no page + /// exists. + #[inline] + pub fn get(&mut self, index: usize) -> &T { + let page_idx = index / self.page_size; + let offset = index % self.page_size; + + assert!( + page_idx < self.pages.len(), + "PagedVec::get index out of bounds: {} >= {}", + index, + self.pages.len() * self.page_size + ); + + if self.pages[page_idx].is_none() { + let page = vec![T::default(); self.page_size]; + self.pages[page_idx] = Some(page.into_boxed_slice()); + } + + unsafe { + // SAFETY: + // - We just ensured the page exists and has size `page_size` + // - offset < page_size by construction + self.pages + .get_unchecked(page_idx) + .as_ref() + .unwrap() + .get_unchecked(offset) + } + } + + /// Panics if the index is out of bounds. Creates new page before write when necessary. + #[inline] + pub fn set(&mut self, index: usize, value: T) { + let page_idx = index / self.page_size; + let offset = index % self.page_size; + + assert!( + page_idx < self.pages.len(), + "PagedVec::set index out of bounds: {} >= {}", + index, + self.pages.len() * self.page_size + ); + + if let Some(page) = &mut self.pages[page_idx] { + // SAFETY: + // - If page exists, then it has size `page_size` + unsafe { + *page.get_unchecked_mut(offset) = value; + } + } else { + let mut page = vec![T::default(); self.page_size]; + page[offset] = value; + self.pages[page_idx] = Some(page.into_boxed_slice()); + } + } + + pub fn par_iter(&self) -> impl ParallelIterator + '_ + where + T: Send + Sync, + { + self.pages + .par_iter() + .enumerate() + .filter_map(move |(page_idx, page)| { + page.as_ref().map(move |p| { + p.par_iter() + .enumerate() + .map(move |(offset, &value)| (page_idx * self.page_size + offset, value)) + }) + }) + .flatten() + } + + pub fn iter(&self) -> impl Iterator + '_ + where + T: Send + Sync, + { + self.pages + .iter() + .enumerate() + .filter_map(move |(page_idx, page)| { + page.as_ref().map(move |p| { + p.iter() + .enumerate() + .map(move |(offset, &value)| (page_idx * self.page_size + offset, value)) + }) + }) + .flatten() + } +} diff --git a/crates/vm/src/system/memory/paged_vec.rs b/crates/vm/src/system/memory/paged_vec.rs deleted file mode 100644 index 8a8b030970..0000000000 --- a/crates/vm/src/system/memory/paged_vec.rs +++ /dev/null @@ -1,447 +0,0 @@ -use std::{mem::MaybeUninit, ops::Range, ptr}; - -use serde::{Deserialize, Serialize}; - -use crate::arch::MemoryConfig; - -/// (address_space, pointer) -pub type Address = (u32, u32); -pub const PAGE_SIZE: usize = 1 << 12; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PagedVec { - pub pages: Vec>>, -} - -// ------------------------------------------------------------------ -// Common Helper Functions -// These functions encapsulate the common logic for copying ranges -// across pages, both for read-only and read-write (set) cases. -impl PagedVec { - // Copies a range of length `len` starting at index `start` - // into the memory pointed to by `dst`. If the relevant page is not - // initialized, fills that portion with T::default(). - fn read_range_generic(&self, start: usize, len: usize, dst: *mut T) { - let start_page = start / PAGE_SIZE; - let end_page = (start + len - 1) / PAGE_SIZE; - unsafe { - if start_page == end_page { - let offset = start % PAGE_SIZE; - if let Some(page) = self.pages[start_page].as_ref() { - ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len); - } else { - std::slice::from_raw_parts_mut(dst, len).fill(T::default()); - } - } else { - let offset = start % PAGE_SIZE; - let first_part = PAGE_SIZE - offset; - if let Some(page) = self.pages[start_page].as_ref() { - ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, first_part); - } else { - std::slice::from_raw_parts_mut(dst, first_part).fill(T::default()); - } - let second_part = len - first_part; - if let Some(page) = self.pages[end_page].as_ref() { - ptr::copy_nonoverlapping(page.as_ptr(), dst.add(first_part), second_part); - } else { - std::slice::from_raw_parts_mut(dst.add(first_part), second_part) - .fill(T::default()); - } - } - } - } - - // Updates a range of length `len` starting at index `start` with new values. - // It copies the current values into the memory pointed to by `dst` - // and then writes the new values into the underlying pages, - // allocating pages (with defaults) if necessary. - fn set_range_generic(&mut self, start: usize, len: usize, new: *const T, dst: *mut T) { - let start_page = start / PAGE_SIZE; - let end_page = (start + len - 1) / PAGE_SIZE; - unsafe { - if start_page == end_page { - let offset = start % PAGE_SIZE; - let page = - self.pages[start_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len); - ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), len); - } else { - let offset = start % PAGE_SIZE; - let first_part = PAGE_SIZE - offset; - { - let page = - self.pages[start_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, first_part); - ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), first_part); - } - let second_part = len - first_part; - { - let page = - self.pages[end_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - ptr::copy_nonoverlapping(page.as_ptr(), dst.add(first_part), second_part); - ptr::copy_nonoverlapping(new.add(first_part), page.as_mut_ptr(), second_part); - } - } - } - } -} - -// ------------------------------------------------------------------ -// Implementation for types requiring Default + Clone -impl PagedVec { - pub fn new(num_pages: usize) -> Self { - Self { - pages: vec![None; num_pages], - } - } - - pub fn get(&self, index: usize) -> Option<&T> { - let page_idx = index / PAGE_SIZE; - self.pages[page_idx] - .as_ref() - .map(|page| &page[index % PAGE_SIZE]) - } - - pub fn get_mut(&mut self, index: usize) -> Option<&mut T> { - let page_idx = index / PAGE_SIZE; - self.pages[page_idx] - .as_mut() - .map(|page| &mut page[index % PAGE_SIZE]) - } - - pub fn set(&mut self, index: usize, value: T) -> Option { - let page_idx = index / PAGE_SIZE; - if let Some(page) = self.pages[page_idx].as_mut() { - Some(std::mem::replace(&mut page[index % PAGE_SIZE], value)) - } else { - let page = self.pages[page_idx].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - page[index % PAGE_SIZE] = value; - None - } - } - - #[inline(always)] - pub fn range_vec(&self, range: Range) -> Vec { - let len = range.end - range.start; - // Create a vector for uninitialized values. - let mut result: Vec> = Vec::with_capacity(len); - // SAFETY: We set the length and then initialize every element via read_range_generic. - unsafe { - result.set_len(len); - self.read_range_generic(range.start, len, result.as_mut_ptr() as *mut T); - std::mem::transmute::>, Vec>(result) - } - } - - pub fn set_range(&mut self, range: Range, values: &[T]) -> Vec { - let len = range.end - range.start; - assert_eq!(values.len(), len); - let mut result: Vec> = Vec::with_capacity(len); - // SAFETY: We will write to every element in result via set_range_generic. - unsafe { - result.set_len(len); - self.set_range_generic( - range.start, - len, - values.as_ptr(), - result.as_mut_ptr() as *mut T, - ); - std::mem::transmute::>, Vec>(result) - } - } - - pub fn memory_size(&self) -> usize { - self.pages.len() * PAGE_SIZE - } - - pub fn is_empty(&self) -> bool { - self.pages.iter().all(|page| page.is_none()) - } -} - -// ------------------------------------------------------------------ -// Implementation for types requiring Default + Copy -impl PagedVec { - #[inline(always)] - pub fn range_array(&self, from: usize) -> [T; N] { - // Create an uninitialized array of MaybeUninit - let mut result: [MaybeUninit; N] = unsafe { - // SAFETY: An uninitialized `[MaybeUninit; N]` is valid. - MaybeUninit::uninit().assume_init() - }; - self.read_range_generic(from, N, result.as_mut_ptr() as *mut T); - // SAFETY: All elements have been initialized. - unsafe { ptr::read(&result as *const _ as *const [T; N]) } - } - - #[inline(always)] - pub fn set_range_array(&mut self, from: usize, values: &[T; N]) -> [T; N] { - // Create an uninitialized array for old values. - let mut result: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; - self.set_range_generic(from, N, values.as_ptr(), result.as_mut_ptr() as *mut T); - unsafe { ptr::read(&result as *const _ as *const [T; N]) } - } -} - -impl PagedVec { - pub fn iter(&self) -> PagedVecIter<'_, T, PAGE_SIZE> { - PagedVecIter { - vec: self, - current_page: 0, - current_index_in_page: 0, - } - } -} - -pub struct PagedVecIter<'a, T, const PAGE_SIZE: usize> { - vec: &'a PagedVec, - current_page: usize, - current_index_in_page: usize, -} - -impl Iterator for PagedVecIter<'_, T, PAGE_SIZE> { - type Item = (usize, T); - - fn next(&mut self) -> Option { - while self.current_page < self.vec.pages.len() - && self.vec.pages[self.current_page].is_none() - { - self.current_page += 1; - debug_assert_eq!(self.current_index_in_page, 0); - self.current_index_in_page = 0; - } - if self.current_page >= self.vec.pages.len() { - return None; - } - let global_index = self.current_page * PAGE_SIZE + self.current_index_in_page; - - let page = self.vec.pages[self.current_page].as_ref()?; - let value = page[self.current_index_in_page].clone(); - - self.current_index_in_page += 1; - if self.current_index_in_page == PAGE_SIZE { - self.current_page += 1; - self.current_index_in_page = 0; - } - Some((global_index, value)) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AddressMap { - pub paged_vecs: Vec>, - pub as_offset: u32, -} - -impl Default for AddressMap { - fn default() -> Self { - Self::from_mem_config(&MemoryConfig::default()) - } -} - -impl AddressMap { - pub fn new(as_offset: u32, as_cnt: usize, mem_size: usize) -> Self { - Self { - paged_vecs: vec![PagedVec::new(mem_size.div_ceil(PAGE_SIZE)); as_cnt], - as_offset, - } - } - pub fn from_mem_config(mem_config: &MemoryConfig) -> Self { - Self::new( - mem_config.as_offset, - 1 << mem_config.as_height, - 1 << mem_config.pointer_max_bits, - ) - } - pub fn items(&self) -> impl Iterator + '_ { - self.paged_vecs - .iter() - .enumerate() - .flat_map(move |(as_idx, page)| { - page.iter() - .map(move |(ptr_idx, x)| ((as_idx as u32 + self.as_offset, ptr_idx as u32), x)) - }) - } - pub fn get(&self, address: &Address) -> Option<&T> { - self.paged_vecs[(address.0 - self.as_offset) as usize].get(address.1 as usize) - } - pub fn get_mut(&mut self, address: &Address) -> Option<&mut T> { - self.paged_vecs[(address.0 - self.as_offset) as usize].get_mut(address.1 as usize) - } - pub fn insert(&mut self, address: &Address, data: T) -> Option { - self.paged_vecs[(address.0 - self.as_offset) as usize].set(address.1 as usize, data) - } - pub fn is_empty(&self) -> bool { - self.paged_vecs.iter().all(|page| page.is_empty()) - } - - pub fn from_iter( - as_offset: u32, - as_cnt: usize, - mem_size: usize, - iter: impl IntoIterator, - ) -> Self { - let mut vec = Self::new(as_offset, as_cnt, mem_size); - for (address, data) in iter { - vec.insert(&address, data); - } - vec - } -} - -impl AddressMap { - pub fn get_range(&self, address: &Address) -> [T; N] { - self.paged_vecs[(address.0 - self.as_offset) as usize].range_array(address.1 as usize) - } - pub fn set_range(&mut self, address: &Address, values: &[T; N]) -> [T; N] { - self.paged_vecs[(address.0 - self.as_offset) as usize] - .set_range_array(address.1 as usize, values) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_basic_get_set() { - let mut v = PagedVec::<_, 4>::new(3); - assert_eq!(v.get(0), None); - v.set(0, 42); - assert_eq!(v.get(0), Some(&42)); - } - - #[test] - fn test_cross_page_operations() { - let mut v = PagedVec::<_, 4>::new(3); - v.set(3, 10); // Last element of first page - v.set(4, 20); // First element of second page - assert_eq!(v.get(3), Some(&10)); - assert_eq!(v.get(4), Some(&20)); - } - - #[test] - fn test_page_boundaries() { - let mut v = PagedVec::<_, 4>::new(2); - // Fill first page - v.set(0, 1); - v.set(1, 2); - v.set(2, 3); - v.set(3, 4); - // Fill second page - v.set(4, 5); - v.set(5, 6); - v.set(6, 7); - v.set(7, 8); - - // Verify all values - assert_eq!(v.range_vec(0..8), [1, 2, 3, 4, 5, 6, 7, 8]); - } - - #[test] - fn test_range_cross_page_boundary() { - let mut v = PagedVec::<_, 4>::new(2); - v.set_range(2..8, &[10, 11, 12, 13, 14, 15]); - assert_eq!(v.range_vec(2..8), [10, 11, 12, 13, 14, 15]); - } - - #[test] - fn test_large_indices() { - let mut v = PagedVec::<_, 4>::new(100); - let large_index = 399; - v.set(large_index, 42); - assert_eq!(v.get(large_index), Some(&42)); - } - - #[test] - fn test_range_operations_with_defaults() { - let mut v = PagedVec::<_, 4>::new(3); - v.set(2, 5); - v.set(5, 10); - - // Should include both set values and defaults - assert_eq!(v.range_vec(1..7), [0, 5, 0, 0, 10, 0]); - } - - #[test] - fn test_non_zero_default_type() { - let mut v: PagedVec = PagedVec::new(2); - assert_eq!(v.get(0), None); // bool's default - v.set(0, true); - assert_eq!(v.get(0), Some(&true)); - assert_eq!(v.get(1), Some(&false)); // because we created the page - } - - #[test] - fn test_set_range_overlapping_pages() { - let mut v = PagedVec::<_, 4>::new(3); - let test_data = [1, 2, 3, 4, 5, 6]; - v.set_range(2..8, &test_data); - - // Verify first page - assert_eq!(v.get(2), Some(&1)); - assert_eq!(v.get(3), Some(&2)); - - // Verify second page - assert_eq!(v.get(4), Some(&3)); - assert_eq!(v.get(5), Some(&4)); - assert_eq!(v.get(6), Some(&5)); - assert_eq!(v.get(7), Some(&6)); - } - - #[test] - fn test_overlapping_set_ranges() { - let mut v = PagedVec::<_, 4>::new(3); - - // Initial set_range - v.set_range(0..5, &[1, 2, 3, 4, 5]); - assert_eq!(v.range_vec(0..5), [1, 2, 3, 4, 5]); - - // Overlap from beginning - v.set_range(0..3, &[10, 20, 30]); - assert_eq!(v.range_vec(0..5), [10, 20, 30, 4, 5]); - - // Overlap in middle - v.set_range(2..4, &[42, 43]); - assert_eq!(v.range_vec(0..5), [10, 20, 42, 43, 5]); - - // Overlap at end - v.set_range(4..6, &[91, 92]); - assert_eq!(v.range_vec(0..6), [10, 20, 42, 43, 91, 92]); - } - - #[test] - fn test_overlapping_set_ranges_cross_pages() { - let mut v = PagedVec::<_, 4>::new(3); - - // Fill across first two pages - v.set_range(0..8, &[1, 2, 3, 4, 5, 6, 7, 8]); - - // Overlap end of first page and start of second - v.set_range(2..6, &[21, 22, 23, 24]); - assert_eq!(v.range_vec(0..8), [1, 2, 21, 22, 23, 24, 7, 8]); - - // Overlap multiple pages - v.set_range(1..7, &[31, 32, 33, 34, 35, 36]); - assert_eq!(v.range_vec(0..8), [1, 31, 32, 33, 34, 35, 36, 8]); - } - - #[test] - fn test_iterator() { - let mut v = PagedVec::<_, 4>::new(3); - - v.set_range(4..10, &[1, 2, 3, 4, 5, 6]); - let contents: Vec<_> = v.iter().collect(); - assert_eq!(contents.len(), 8); // two pages - - contents - .iter() - .take(6) - .enumerate() - .for_each(|(i, &(idx, val))| { - assert_eq!((idx, val), (4 + i, 1 + i)); - }); - assert_eq!(contents[6], (10, 0)); - assert_eq!(contents[7], (11, 0)); - } -} diff --git a/crates/vm/src/system/memory/persistent.rs b/crates/vm/src/system/memory/persistent.rs index 55a178be4d..aa68a0c8bb 100644 --- a/crates/vm/src/system/memory/persistent.rs +++ b/crates/vm/src/system/memory/persistent.rs @@ -18,13 +18,14 @@ use openvm_stark_backend::{ AirRef, Chip, ChipUsageGetter, }; use rustc_hash::FxHashSet; +use tracing::instrument; -use super::merkle::SerialReceiver; +use super::{merkle::SerialReceiver, online::INITIAL_TIMESTAMP, TimestampedValues}; use crate::{ - arch::hasher::Hasher, + arch::{hasher::Hasher, ADDR_SPACE_OFFSET}, system::memory::{ dimensions::MemoryDimensions, offline_checker::MemoryBus, MemoryAddress, MemoryImage, - TimestampedEquipartition, INITIAL_TIMESTAMP, + TimestampedEquipartition, }, }; @@ -92,7 +93,7 @@ impl Air for PersistentBoundaryA // direction = -1 => is_final = 1 local.expand_direction.into(), AB::Expr::ZERO, - local.address_space - AB::F::from_canonical_u32(self.memory_dims.as_offset), + local.address_space - AB::F::from_canonical_u32(ADDR_SPACE_OFFSET), local.leaf_label.into(), ]; expand_fields.extend(local.hash.map(Into::into)); @@ -123,18 +124,18 @@ impl Air for PersistentBoundaryA pub struct PersistentBoundaryChip { pub air: PersistentBoundaryAir, - touched_labels: TouchedLabels, + pub touched_labels: TouchedLabels, overridden_height: Option, } #[derive(Debug)] -enum TouchedLabels { +pub enum TouchedLabels { Running(FxHashSet<(u32, u32)>), Final(Vec>), } #[derive(Debug)] -struct FinalTouchedLabel { +pub struct FinalTouchedLabel { address_space: u32, label: u32, init_values: [F; CHUNK], @@ -159,7 +160,15 @@ impl TouchedLabels { _ => panic!("Cannot touch after finalization"), } } - fn len(&self) -> usize { + + pub fn is_empty(&self) -> bool { + match self { + TouchedLabels::Running(touched_labels) => touched_labels.is_empty(), + TouchedLabels::Final(touched_labels) => touched_labels.is_empty(), + } + } + + pub fn len(&self) -> usize { match self { TouchedLabels::Running(touched_labels) => touched_labels.len(), TouchedLabels::Final(touched_labels) => touched_labels.len(), @@ -198,47 +207,40 @@ impl PersistentBoundaryChip { } } + #[instrument(name = "boundary_finalize", skip_all)] pub fn finalize( &mut self, - initial_memory: &MemoryImage, + initial_memory: &MemoryImage, + // Only touched stuff final_memory: &TimestampedEquipartition, hasher: &mut H, ) where H: Hasher + Sync + for<'a> SerialReceiver<&'a [F]>, { - match &mut self.touched_labels { - TouchedLabels::Running(touched_labels) => { - let final_touched_labels: Vec<_> = touched_labels - .par_iter() - .map(|&(address_space, label)| { - let pointer = label * CHUNK as u32; - let init_values = array::from_fn(|i| { - *initial_memory - .get(&(address_space, pointer + i as u32)) - .unwrap_or(&F::ZERO) - }); - let initial_hash = hasher.hash(&init_values); - let timestamped_values = final_memory.get(&(address_space, label)).unwrap(); - let final_hash = hasher.hash(×tamped_values.values); - FinalTouchedLabel { - address_space, - label, - init_values, - final_values: timestamped_values.values, - init_hash: initial_hash, - final_hash, - final_timestamp: timestamped_values.timestamp, - } - }) - .collect(); - for l in &final_touched_labels { - hasher.receive(&l.init_values); - hasher.receive(&l.final_values); + let final_touched_labels: Vec<_> = final_memory + .par_iter() + .map(|&((addr_space, ptr), ts_values)| { + let TimestampedValues { timestamp, values } = ts_values; + let init_values = + array::from_fn(|i| initial_memory.get_f::(addr_space, ptr + i as u32)); + let initial_hash = hasher.hash(&init_values); + let final_hash = hasher.hash(&values); + FinalTouchedLabel { + address_space: addr_space, + label: ptr / CHUNK as u32, + init_values, + final_values: values, + init_hash: initial_hash, + final_hash, + final_timestamp: timestamp, } - self.touched_labels = TouchedLabels::Final(final_touched_labels); - } - _ => panic!("Cannot finalize after finalization"), + }) + .collect(); + for l in &final_touched_labels { + hasher.receive(&l.init_values); + hasher.receive(&l.final_values); } + self.touched_labels = TouchedLabels::Final(final_touched_labels); } } diff --git a/crates/vm/src/system/memory/tests.rs b/crates/vm/src/system/memory/tests.rs index 9ebb9306aa..b950e7862e 100644 --- a/crates/vm/src/system/memory/tests.rs +++ b/crates/vm/src/system/memory/tests.rs @@ -292,34 +292,35 @@ fn make_random_accesses( ) -> Vec { (0..1024) .map(|_| { - let address_space = F::from_canonical_u32(*[1, 2].choose(&mut rng).unwrap()); + let address_space = F::from_canonical_u32(*[4, 5].choose(&mut rng).unwrap()); match rng.gen_range(0..5) { 0 => { let pointer = F::from_canonical_usize(gen_pointer(rng, 1)); let data = F::from_canonical_u32(rng.gen_range(0..1 << 30)); - let (record_id, _) = memory_controller.write(address_space, pointer, [data]); + let (record_id, _) = memory_controller.write(address_space, pointer, &[data]); record_id } 1 => { let pointer = F::from_canonical_usize(gen_pointer(rng, 1)); - let (record_id, _) = memory_controller.read::<1>(address_space, pointer); + let (record_id, _) = memory_controller.read::(address_space, pointer); record_id } 2 => { let pointer = F::from_canonical_usize(gen_pointer(rng, 4)); - let (record_id, _) = memory_controller.read::<4>(address_space, pointer); + let (record_id, _) = memory_controller.read::(address_space, pointer); record_id } 3 => { let pointer = F::from_canonical_usize(gen_pointer(rng, 4)); let data = array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..1 << 30))); - let (record_id, _) = memory_controller.write::<4>(address_space, pointer, data); + let (record_id, _) = + memory_controller.write::(address_space, pointer, &data); record_id } 4 => { let pointer = F::from_canonical_usize(gen_pointer(rng, MAX)); - let (record_id, _) = memory_controller.read::(address_space, pointer); + let (record_id, _) = memory_controller.read::(address_space, pointer); record_id } _ => unreachable!(), diff --git a/crates/vm/src/system/memory/tree/mod.rs b/crates/vm/src/system/memory/tree/mod.rs deleted file mode 100644 index fcdb86d8ee..0000000000 --- a/crates/vm/src/system/memory/tree/mod.rs +++ /dev/null @@ -1,177 +0,0 @@ -pub mod public_values; - -use std::{ops::Range, sync::Arc}; - -use openvm_stark_backend::{p3_field::PrimeField32, p3_maybe_rayon::prelude::*}; -use MemoryNode::*; - -use super::controller::dimensions::MemoryDimensions; -use crate::{ - arch::hasher::{Hasher, HasherChip}, - system::memory::MemoryImage, -}; - -#[derive(Clone, Debug, PartialEq)] -pub enum MemoryNode { - Leaf { - values: [F; CHUNK], - }, - NonLeaf { - hash: [F; CHUNK], - left: Arc>, - right: Arc>, - }, -} - -impl MemoryNode { - pub fn hash(&self) -> [F; CHUNK] { - match self { - Leaf { values: hash } => *hash, - NonLeaf { hash, .. } => *hash, - } - } - - pub fn new_leaf(values: [F; CHUNK]) -> Self { - Leaf { values } - } - - pub fn new_nonleaf( - left: Arc>, - right: Arc>, - hasher: &mut impl HasherChip, - ) -> Self { - NonLeaf { - hash: hasher.compress_and_record(&left.hash(), &right.hash()), - left, - right, - } - } - - /// Returns a tree of height `height` with all leaves set to `leaf_value`. - pub fn construct_uniform( - height: usize, - leaf_value: [F; CHUNK], - hasher: &impl Hasher, - ) -> MemoryNode { - if height == 0 { - Self::new_leaf(leaf_value) - } else { - let child = Arc::new(Self::construct_uniform(height - 1, leaf_value, hasher)); - NonLeaf { - hash: hasher.compress(&child.hash(), &child.hash()), - left: child.clone(), - right: child, - } - } - } - - fn from_memory( - memory: &[(u64, F)], - lookup_range: Range, - length: u64, - from: u64, - hasher: &(impl Hasher + Sync), - zero_leaf: &MemoryNode, - ) -> MemoryNode { - if length == CHUNK as u64 { - if lookup_range.is_empty() { - zero_leaf.clone() - } else { - debug_assert_eq!(memory[lookup_range.start].0, from); - let mut values = [F::ZERO; CHUNK]; - for (index, value) in memory[lookup_range].iter() { - values[(index % CHUNK as u64) as usize] = *value; - } - MemoryNode::new_leaf(hasher.hash(&values)) - } - } else if lookup_range.is_empty() { - let leaf_value = hasher.hash(&[F::ZERO; CHUNK]); - MemoryNode::construct_uniform( - (length / CHUNK as u64).trailing_zeros() as usize, - leaf_value, - hasher, - ) - } else { - let midpoint = from + length / 2; - let mid = { - let mut left = lookup_range.start; - let mut right = lookup_range.end; - if memory[left].0 >= midpoint { - left - } else { - while left + 1 < right { - let mid = left + (right - left) / 2; - if memory[mid].0 < midpoint { - left = mid; - } else { - right = mid; - } - } - right - } - }; - let (left, right) = join( - || { - Self::from_memory( - memory, - lookup_range.start..mid, - length >> 1, - from, - hasher, - zero_leaf, - ) - }, - || { - Self::from_memory( - memory, - mid..lookup_range.end, - length >> 1, - midpoint, - hasher, - zero_leaf, - ) - }, - ); - NonLeaf { - hash: hasher.compress(&left.hash(), &right.hash()), - left: Arc::new(left), - right: Arc::new(right), - } - } - } - - pub fn tree_from_memory( - memory_dimensions: MemoryDimensions, - memory: &MemoryImage, - hasher: &(impl Hasher + Sync), - ) -> MemoryNode { - // Construct a Vec that includes the address space in the label calculation, - // representing the entire memory tree. - let memory_items = memory - .items() - .filter(|((_, ptr), _)| *ptr as usize / CHUNK < (1 << memory_dimensions.address_height)) - .map(|((address_space, pointer), value)| { - ( - memory_dimensions.label_to_index((address_space, pointer / CHUNK as u32)) - * CHUNK as u64 - + (pointer % CHUNK as u32) as u64, - value, - ) - }) - .collect::>(); - debug_assert!(memory_items.is_sorted_by_key(|(addr, _)| addr)); - debug_assert!( - memory_items.last().map_or(0, |(addr, _)| *addr) - < ((CHUNK as u64) << memory_dimensions.overall_height()) - ); - let zero_leaf = MemoryNode::new_leaf(hasher.hash(&[F::ZERO; CHUNK])); - Self::from_memory( - &memory_items, - 0..memory_items.len(), - (CHUNK as u64) << memory_dimensions.overall_height(), - 0, - hasher, - &zero_leaf, - ) - } -} diff --git a/crates/vm/src/system/memory/volatile/mod.rs b/crates/vm/src/system/memory/volatile/mod.rs index e01162c789..8c7d976599 100644 --- a/crates/vm/src/system/memory/volatile/mod.rs +++ b/crates/vm/src/system/memory/volatile/mod.rs @@ -26,6 +26,7 @@ use openvm_stark_backend::{ AirRef, Chip, ChipUsageGetter, }; use static_assertions::const_assert; +use tracing::instrument; use super::TimestampedEquipartition; use crate::system::memory::{ @@ -183,7 +184,7 @@ pub struct VolatileBoundaryChip { pub air: VolatileBoundaryAir, range_checker: SharedVariableRangeCheckerChip, overridden_height: Option, - final_memory: Option>, + pub final_memory: Option>, addr_space_max_bits: usize, pointer_max_bits: usize, } @@ -218,6 +219,7 @@ impl VolatileBoundaryChip { } /// Volatile memory requires the starting and final memory to be in equipartition with block /// size `1`. When block size is `1`, then the `label` is the same as the address pointer. + #[instrument(name = "boundary_finalize", skip_all)] pub fn finalize(&mut self, final_memory: TimestampedEquipartition) { self.final_memory = Some(final_memory); } diff --git a/crates/vm/src/system/memory/volatile/tests.rs b/crates/vm/src/system/memory/volatile/tests.rs index 29917d219d..8e941179f7 100644 --- a/crates/vm/src/system/memory/volatile/tests.rs +++ b/crates/vm/src/system/memory/volatile/tests.rs @@ -55,14 +55,15 @@ fn boundary_air_test() { let final_data = Val::from_canonical_u32(rng.gen_range(0..MAX_VAL)); let final_clk = rng.gen_range(1..MAX_VAL) as u32; - final_memory.insert( + final_memory.push(( (addr_space, pointer), TimestampedValues { values: [final_data], timestamp: final_clk, }, - ); + )); } + final_memory.sort_by_key(|(key, _)| *key); let diff_height = num_addresses.next_power_of_two() - num_addresses; @@ -90,7 +91,10 @@ fn boundary_air_test() { distinct_addresses .iter() .flat_map(|(addr_space, pointer)| { - let timestamped_value = final_memory.get(&(*addr_space, *pointer)).unwrap(); + let timestamped_value = final_memory[final_memory + .binary_search_by(|(key, _)| key.cmp(&(*addr_space, *pointer))) + .unwrap()] + .1; vec![ Val::ONE, diff --git a/crates/vm/src/system/mod.rs b/crates/vm/src/system/mod.rs index a1038ac86a..7164846d5b 100644 --- a/crates/vm/src/system/mod.rs +++ b/crates/vm/src/system/mod.rs @@ -1,5 +1,6 @@ pub mod connector; pub mod memory; +// Necessary for the PublicValuesChip pub mod native_adapter; /// Chip to handle phantom instructions. /// The Air will always constrain a NOP which advances pc by DEFAULT_PC_STEP. diff --git a/crates/vm/src/system/native_adapter/mod.rs b/crates/vm/src/system/native_adapter/mod.rs index 95c2c7c4a4..c2ec200934 100644 --- a/crates/vm/src/system/native_adapter/mod.rs +++ b/crates/vm/src/system/native_adapter/mod.rs @@ -1,3 +1,5 @@ +pub mod util; + use std::{ borrow::{Borrow, BorrowMut}, marker::PhantomData, @@ -5,86 +7,31 @@ use std::{ use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + AdapterAirContext, BasicAdapterInterface, ExecutionBridge, ExecutionState, + MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, + MemoryAddress, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_IMM_AS, NATIVE_AS, +}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; - -use crate::system::memory::{OfflineMemory, RecordId}; - -/// R reads(R<=2), W writes(W<=1). -/// Operands: b for the first read, c for the second read, a for the first write. -/// If an operand is not used, its address space and pointer should be all 0. -#[derive(Debug)] -pub struct NativeAdapterChip { - pub air: NativeAdapterAir, - _phantom: PhantomData, -} - -impl NativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: NativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _phantom: PhantomData, - } - } -} +use util::{tracing_read_or_imm_native, tracing_write_native}; -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeReadRecord { - #[serde(with = "BigArray")] - pub reads: [(RecordId, [F; 1]); R], -} - -impl NativeReadRecord { - pub fn b(&self) -> &[F; 1] { - &self.reads[0].1 - } - - pub fn c(&self) -> &[F; 1] { - &self.reads[1].1 - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeWriteRecord { - pub from_state: ExecutionState, - #[serde(with = "BigArray")] - pub writes: [(RecordId, [F; 1]); W], -} - -impl NativeWriteRecord { - pub fn a(&self) -> &[F; 1] { - &self.writes[0].1 - } -} +use super::memory::{online::TracingMemory, MemoryAuxColsFactory}; +use crate::{ + arch::{get_record_from_slice, AdapterTraceFiller, AdapterTraceStep}, + system::memory::offline_checker::{MemoryReadAuxRecord, MemoryWriteAuxRecord}, +}; #[repr(C)] #[derive(AlignedBorrow)] @@ -205,101 +152,150 @@ impl VmAdapterAir } } -impl VmAdapterChip - for NativeAdapterChip -{ - type ReadRecord = NativeReadRecord; - type WriteRecord = NativeWriteRecord; - type Air = NativeAdapterAir; - type Interface = BasicAdapterInterface, R, W, 1, 1>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - assert!(R <= 2); - let Instruction { b, c, e, f, .. } = *instruction; + // These are either a pointer to native memory or an immediate value + pub read_ptr_or_imm: [F; R], + // Will set prev_timestamp to `u32::MAX` if the read is from RV32_IMM_AS + pub reads_aux: [MemoryReadAuxRecord; R], + pub write_ptr: [F; W], + pub writes_aux: [MemoryWriteAuxRecord; W], +} - let mut reads = Vec::with_capacity(R); - if R >= 1 { - reads.push(memory.read::<1>(e, b)); - } - if R >= 2 { - reads.push(memory.read::<1>(f, c)); - } - let i_reads: [_; R] = std::array::from_fn(|i| reads[i].1); +/// R reads(R<=2), W writes(W<=1). +/// Operands: b for the first read, c for the second read, a for the first write. +/// If an operand is not used, its address space and pointer should be all 0. +#[derive(Debug, derive_new::new)] +pub struct NativeAdapterStep { + _phantom: PhantomData, +} + +impl AdapterTraceStep for NativeAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[F; 1]; R]; + type WriteData = [[F; 1]; W]; + type RecordMut<'a> = &'a mut NativeAdapterRecord; - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - assert!(W <= 1); - let Instruction { a, d, .. } = *instruction; - let mut writes = Vec::with_capacity(W); - if W >= 1 { - let (record_id, _) = memory.write(d, a, output.writes[0]); - writes.push((record_id, output.writes[0])); - } + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + debug_assert!(R <= 2); + let &Instruction { b, c, e, f, .. } = instruction; - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: writes.try_into().unwrap(), - }, - )) + let mut reads = [[F::ZERO; 1]; R]; + record + .read_ptr_or_imm + .iter_mut() + .enumerate() + .zip(record.reads_aux.iter_mut()) + .for_each(|((i, ptr_or_imm), read_aux)| { + *ptr_or_imm = if i == 0 { b } else { c }; + let addr_space = if i == 0 { e } else { f }; + reads[i][0] = tracing_read_or_imm_native( + memory, + addr_space.as_canonical_u32(), + *ptr_or_imm, + &mut read_aux.prev_timestamp, + ); + }); + reads } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let row_slice: &mut NativeAdapterCols<_, R, W> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); - - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); + let &Instruction { a, d, .. } = instruction; + debug_assert!(W <= 1); + debug_assert_eq!(d.as_canonical_u32(), NATIVE_AS); - for (i, read) in read_record.reads.iter().enumerate() { - let (id, _) = read; - let record = memory.record_by_id(*id); - aux_cols_factory - .generate_read_or_immediate_aux(record, &mut row_slice.reads_aux[i].read_aux); - row_slice.reads_aux[i].address = - MemoryAddress::new(record.address_space, record.pointer); + if W >= 1 { + record.write_ptr[0] = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + data[0], + &mut record.writes_aux[0].prev_timestamp, + &mut record.writes_aux[0].prev_data, + ); } + } +} - for (i, write) in write_record.writes.iter().enumerate() { - let (id, _) = write; - let record = memory.record_by_id(*id); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i].write_aux); - row_slice.writes_aux[i].address = - MemoryAddress::new(record.address_space, record.pointer); +impl AdapterTraceFiller + for NativeAdapterStep +{ + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &NativeAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut NativeAdapterCols<_, R, W> = adapter_row.borrow_mut(); + // Writing in reverse order to avoid overwriting the `record` + if W >= 1 { + adapter_row.writes_aux[0] + .write_aux + .set_prev_data(record.writes_aux[0].prev_data); + mem_helper.fill( + record.writes_aux[0].prev_timestamp, + record.from_timestamp + R as u32, + adapter_row.writes_aux[0].write_aux.as_mut(), + ); + adapter_row.writes_aux[0].address.pointer = record.write_ptr[0]; + adapter_row.writes_aux[0].address.address_space = F::from_canonical_u32(NATIVE_AS); } - } - fn air(&self) -> &Self::Air { - &self.air + adapter_row + .reads_aux + .iter_mut() + .enumerate() + .zip(record.reads_aux.iter().zip(record.read_ptr_or_imm.iter())) + .rev() + .for_each(|((i, read_cols), (read_record, ptr_or_imm))| { + if read_record.prev_timestamp == u32::MAX { + read_cols.read_aux.is_zero_aux = F::ZERO; + read_cols.read_aux.is_immediate = F::ONE; + mem_helper.fill( + 0, + record.from_timestamp + i as u32, + read_cols.read_aux.as_mut(), + ); + read_cols.address.pointer = *ptr_or_imm; + read_cols.address.address_space = F::from_canonical_u32(RV32_IMM_AS); + } else { + read_cols.read_aux.is_zero_aux = F::from_canonical_u32(NATIVE_AS).inverse(); + read_cols.read_aux.is_immediate = F::ZERO; + mem_helper.fill( + read_record.prev_timestamp, + record.from_timestamp + i as u32, + read_cols.read_aux.as_mut(), + ); + read_cols.address.pointer = *ptr_or_imm; + read_cols.address.address_space = F::from_canonical_u32(NATIVE_AS); + } + }); + + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/crates/vm/src/system/native_adapter/util.rs b/crates/vm/src/system/native_adapter/util.rs new file mode 100644 index 0000000000..4d3b3d6562 --- /dev/null +++ b/crates/vm/src/system/native_adapter/util.rs @@ -0,0 +1,196 @@ +use openvm_circuit::system::memory::online::TracingMemory; +use openvm_instructions::{riscv::RV32_IMM_AS, NATIVE_AS}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + arch::{execution_mode::E1ExecutionCtx, VmStateMut}, + system::memory::{offline_checker::MemoryWriteAuxCols, online::GuestMemory}, +}; + +#[inline(always)] +pub fn memory_read_native(memory: &GuestMemory, ptr: u32) -> [F; N] +where + F: PrimeField32, +{ + // SAFETY: + // - address space `NATIVE_AS` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.read::(NATIVE_AS, ptr) } +} + +#[inline(always)] +pub fn memory_read_or_imm_native(memory: &GuestMemory, addr_space: u32, ptr_or_imm: F) -> F +where + F: PrimeField32, +{ + debug_assert!(addr_space == RV32_IMM_AS || addr_space == NATIVE_AS); + + if addr_space == NATIVE_AS { + let [result]: [F; 1] = memory_read_native(memory, ptr_or_imm.as_canonical_u32()); + result + } else { + ptr_or_imm + } +} + +#[inline(always)] +pub fn memory_write_native(memory: &mut GuestMemory, ptr: u32, data: [F; N]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `NATIVE_AS` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.write::(NATIVE_AS, ptr, data) } +} + +#[inline(always)] +pub fn memory_read_native_from_state( + state: &mut VmStateMut, + ptr: u32, +) -> [F; N] +where + F: PrimeField32, + Ctx: E1ExecutionCtx, +{ + state.ctx.on_memory_operation(NATIVE_AS, ptr, N as u32); + + memory_read_native(state.memory, ptr) +} + +#[inline(always)] +pub fn memory_read_or_imm_native_from_state( + state: &mut VmStateMut, + addr_space: u32, + ptr_or_imm: F, +) -> F +where + F: PrimeField32, + Ctx: E1ExecutionCtx, +{ + debug_assert!(addr_space == RV32_IMM_AS || addr_space == NATIVE_AS); + + if addr_space == NATIVE_AS { + let [result]: [F; 1] = memory_read_native_from_state(state, ptr_or_imm.as_canonical_u32()); + result + } else { + ptr_or_imm + } +} + +#[inline(always)] +pub fn memory_write_native_from_state( + state: &mut VmStateMut, + ptr: u32, + data: [F; N], +) where + F: PrimeField32, + Ctx: E1ExecutionCtx, +{ + state.ctx.on_memory_operation(NATIVE_AS, ptr, N as u32); + + memory_write_native(state.memory, ptr, data) +} + +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:BLOCK_SIZE]_4)` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +pub fn timed_read_native( + memory: &mut TracingMemory, + ptr: u32, +) -> (u32, [F; BLOCK_SIZE]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.read::(NATIVE_AS, ptr) } +} + +#[inline(always)] +pub fn timed_write_native( + memory: &mut TracingMemory, + ptr: u32, + vals: [F; BLOCK_SIZE], +) -> (u32, [F; BLOCK_SIZE]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.write::(NATIVE_AS, ptr, vals) } +} + +/// Reads register value at `ptr` from memory and records the previous timestamp. +#[inline(always)] +pub fn tracing_read_native( + memory: &mut TracingMemory, + ptr: u32, + prev_timestamp: &mut u32, +) -> [F; BLOCK_SIZE] +where + F: PrimeField32, +{ + let (t_prev, data) = timed_read_native(memory, ptr); + *prev_timestamp = t_prev; + data +} + +/// Writes `ptr, vals` into memory and records the previous timestamp and data. +#[inline(always)] +pub fn tracing_write_native( + memory: &mut TracingMemory, + ptr: u32, + vals: [F; BLOCK_SIZE], + prev_timestamp: &mut u32, + prev_data: &mut [F; BLOCK_SIZE], +) where + F: PrimeField32, +{ + let (t_prev, data_prev) = timed_write_native(memory, ptr, vals); + *prev_timestamp = t_prev; + *prev_data = data_prev; +} + +/// Writes `ptr, vals` into memory and records the previous timestamp and data. +#[inline(always)] +pub fn tracing_write_native_inplace( + memory: &mut TracingMemory, + ptr: u32, + vals: [F; BLOCK_SIZE], + cols: &mut MemoryWriteAuxCols, +) where + F: PrimeField32, +{ + let (t_prev, data_prev) = timed_write_native(memory, ptr, vals); + cols.base.set_prev(F::from_canonical_u32(t_prev)); + cols.prev_data = data_prev; +} + +/// Reads value at `_ptr` from memory and records the previous timestamp. +/// If the read is an immediate, the previous timestamp will be set to `u32::MAX`. +#[inline(always)] +pub fn tracing_read_or_imm_native( + memory: &mut TracingMemory, + addr_space: u32, + ptr_or_imm: F, + prev_timestamp: &mut u32, +) -> F +where + F: PrimeField32, +{ + debug_assert!( + addr_space == RV32_IMM_AS || addr_space == NATIVE_AS, + "addr_space={} is not valid", + addr_space + ); + + if addr_space == RV32_IMM_AS { + *prev_timestamp = u32::MAX; + memory.increment_timestamp(); + ptr_or_imm + } else { + let data: [F; 1] = + tracing_read_native(memory, ptr_or_imm.as_canonical_u32(), prev_timestamp); + data[0] + } +} diff --git a/crates/vm/src/system/phantom/execution.rs b/crates/vm/src/system/phantom/execution.rs new file mode 100644 index 0000000000..d8dd1aa067 --- /dev/null +++ b/crates/vm/src/system/phantom/execution.rs @@ -0,0 +1,193 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, PhantomDiscriminant, SysPhantom, +}; +use openvm_stark_backend::p3_field::PrimeField32; +use rand::rngs::StdRng; + +use crate::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, ExecutionError, InsExecutorE1, InsExecutorE2, + PhantomSubExecutor, Streams, VmSegmentState, + }, + system::{memory::online::GuestMemory, phantom::PhantomChip}, +}; + +#[derive(Clone, AlignedBytesBorrow)] +#[repr(C)] +pub(super) struct PhantomOperands { + pub(super) a: u32, + pub(super) b: u32, + pub(super) c: u32, +} + +#[derive(Clone, AlignedBytesBorrow)] +#[repr(C)] +struct PhantomPreCompute { + operands: PhantomOperands, + sub_executor: *const Box>, +} + +impl InsExecutorE1 for PhantomChip +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::>() + } + #[inline(always)] + fn pre_compute_e1( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> crate::arch::Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut PhantomPreCompute = data.borrow_mut(); + self.pre_compute_impl(inst, data); + Ok(execute_e1_impl) + } + + fn set_trace_height(&mut self, _height: usize) {} +} + +pub(super) struct PhantomStateMut<'a, F> { + pub(super) pc: &'a mut u32, + pub(super) memory: &'a mut GuestMemory, + pub(super) streams: &'a mut Streams, + pub(super) rng: &'a mut StdRng, +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &PhantomPreCompute, + vm_state: &mut VmSegmentState, +) { + let sub_executor = &*pre_compute.sub_executor; + if let Err(e) = execute_impl( + PhantomStateMut { + pc: &mut vm_state.pc, + memory: &mut vm_state.memory, + streams: &mut vm_state.streams, + rng: &mut vm_state.rng, + }, + &pre_compute.operands, + sub_executor.as_ref(), + ) { + vm_state.exit_code = Err(e); + return; + } + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &PhantomPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl(&pre_compute.data, vm_state); +} + +#[inline(always)] +pub(super) fn execute_impl( + state: PhantomStateMut, + operands: &PhantomOperands, + sub_executor: &dyn PhantomSubExecutor, +) -> Result<(), ExecutionError> +where + F: PrimeField32, +{ + let &PhantomOperands { a, b, c } = operands; + + let discriminant = PhantomDiscriminant(c as u16); + // If not a system phantom sub-instruction (which is handled in + // ExecutionSegment), look for a phantom sub-executor to handle it. + if let Some(discr) = SysPhantom::from_repr(discriminant.0) { + if discr == SysPhantom::DebugPanic { + return Err(ExecutionError::Fail { pc: *state.pc }); + } + } + sub_executor + .phantom_execute( + state.memory, + state.streams, + state.rng, + discriminant, + a, + b, + (c >> 16) as u16, + ) + .map_err(|e| ExecutionError::Phantom { + pc: *state.pc, + discriminant, + inner: e, + })?; + + Ok(()) +} + +impl PhantomChip +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_impl(&self, inst: &Instruction, data: &mut PhantomPreCompute) { + let c = inst.c.as_canonical_u32(); + *data = PhantomPreCompute { + operands: PhantomOperands { + a: inst.a.as_canonical_u32(), + b: inst.b.as_canonical_u32(), + c, + }, + sub_executor: self + .phantom_executors + .get(&PhantomDiscriminant(c as u16)) + .unwrap(), + }; + } +} + +impl InsExecutorE2 for PhantomChip +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> crate::arch::Result> + where + Ctx: E2ExecutionCtx, + { + let e2_data: &mut E2PreCompute> = data.borrow_mut(); + e2_data.chip_idx = chip_idx as u32; + self.pre_compute_impl(inst, &mut e2_data.data); + Ok(execute_e2_impl) + } +} diff --git a/crates/vm/src/system/phantom/mod.rs b/crates/vm/src/system/phantom/mod.rs index 28977fe2cd..32708cce0d 100644 --- a/crates/vm/src/system/phantom/mod.rs +++ b/crates/vm/src/system/phantom/mod.rs @@ -1,6 +1,6 @@ use std::{ borrow::{Borrow, BorrowMut}, - sync::{Arc, Mutex, OnceLock}, + sync::Arc, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -19,6 +19,7 @@ use openvm_stark_backend::{ rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, AirRef, Chip, ChipUsageGetter, }; +use rand::rngs::StdRng; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use serde_big_array::BigArray; @@ -29,9 +30,14 @@ use crate::{ ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, PcIncOrSet, PhantomSubExecutor, Streams, }, - system::program::ProgramBus, + system::{ + memory::online::GuestMemory, + phantom::execution::{execute_impl, PhantomOperands, PhantomStateMut}, + program::ProgramBus, + }, }; +mod execution; #[cfg(test)] mod tests; @@ -91,7 +97,6 @@ impl Air for PhantomAir { pub struct PhantomChip { pub air: PhantomAir, pub rows: Vec>, - streams: OnceLock>>>, phantom_executors: FxHashMap>>, } @@ -103,17 +108,10 @@ impl PhantomChip { phantom_opcode: VmOpcode::from_usize(offset + SystemOpcode::PHANTOM.local_usize()), }, rows: vec![], - streams: OnceLock::new(), phantom_executors: FxHashMap::default(), } } - pub fn set_streams(&mut self, streams: Arc>>) { - if self.streams.set(streams).is_err() { - panic!("Streams should only be set once"); - } - } - pub(crate) fn add_sub_executor + 'static>( &mut self, sub_executor: P, @@ -128,55 +126,47 @@ impl InstructionExecutor for PhantomChip { fn execute( &mut self, memory: &mut MemoryController, + streams: &mut Streams, + rng: &mut StdRng, instruction: &Instruction, from_state: ExecutionState, ) -> Result, ExecutionError> { - let &Instruction { - opcode, a, b, c, .. - } = instruction; - assert_eq!(opcode, self.air.phantom_opcode); - - let c_u32 = c.as_canonical_u32(); - let discriminant = PhantomDiscriminant(c_u32 as u16); - // If not a system phantom sub-instruction (which is handled in - // ExecutionSegment), look for a phantom sub-executor to handle it. - if SysPhantom::from_repr(discriminant.0).is_none() { - let sub_executor = self - .phantom_executors - .get_mut(&discriminant) - .ok_or_else(|| ExecutionError::PhantomNotFound { - pc: from_state.pc, - discriminant, - })?; - let mut streams = self.streams.get().unwrap().lock().unwrap(); - sub_executor - .as_mut() - .phantom_execute( - memory, - &mut streams, - discriminant, - a, - b, - (c_u32 >> 16) as u16, - ) - .map_err(|e| ExecutionError::Phantom { - pc: from_state.pc, - discriminant, - inner: e, - })?; - } - + let mut pc = from_state.pc; self.rows.push(PhantomCols { - pc: F::from_canonical_u32(from_state.pc), - operands: [a, b, c], - timestamp: F::from_canonical_u32(from_state.timestamp), + pc: F::from_canonical_u32(pc), + operands: [instruction.a, instruction.b, instruction.c], + timestamp: F::from_canonical_u32(memory.memory.timestamp), is_valid: F::ONE, }); + + let c_u32 = instruction.c.as_canonical_u32() as u16; + if SysPhantom::from_repr(c_u32).is_none() { + let sub_executor = self + .phantom_executors + .get(&PhantomDiscriminant(c_u32)) + .unwrap(); + execute_impl( + PhantomStateMut { + pc: &mut pc, + memory: &mut memory.memory.data, + streams, + rng, + }, + &PhantomOperands { + a: instruction.a.as_canonical_u32(), + b: instruction.b.as_canonical_u32(), + c: instruction.c.as_canonical_u32(), + }, + sub_executor.as_ref(), + )?; + } + pc += DEFAULT_PC_STEP; memory.increment_timestamp(); - Ok(ExecutionState::new( - from_state.pc + DEFAULT_PC_STEP, - from_state.timestamp + 1, - )) + + Ok(ExecutionState { + pc, + timestamp: memory.memory.timestamp, + }) } fn get_opcode_name(&self, _: usize) -> String { @@ -222,3 +212,57 @@ where AirProofInput::simple(trace, vec![]) } } + +pub struct NopPhantomExecutor; +pub struct CycleStartPhantomExecutor; +pub struct CycleEndPhantomExecutor; + +impl PhantomSubExecutor for NopPhantomExecutor { + #[inline(always)] + fn phantom_execute( + &self, + _memory: &GuestMemory, + _streams: &mut Streams, + _rng: &mut StdRng, + _discriminant: PhantomDiscriminant, + _a: u32, + _b: u32, + _c_upper: u16, + ) -> eyre::Result<()> { + Ok(()) + } +} + +impl PhantomSubExecutor for CycleStartPhantomExecutor { + #[inline(always)] + fn phantom_execute( + &self, + _memory: &GuestMemory, + _streams: &mut Streams, + _rng: &mut StdRng, + _discriminant: PhantomDiscriminant, + _a: u32, + _b: u32, + _c_upper: u16, + ) -> eyre::Result<()> { + // TODO: implement cycle tracker for E1/E2 + Ok(()) + } +} + +impl PhantomSubExecutor for CycleEndPhantomExecutor { + #[inline(always)] + fn phantom_execute( + &self, + _memory: &GuestMemory, + _streams: &mut Streams, + _rng: &mut StdRng, + _discriminant: PhantomDiscriminant, + _a: u32, + _b: u32, + _c_upper: u16, + ) -> eyre::Result<()> { + // TODO: implement cycle tracker for E1/E2 + Ok(()) + } +} diff --git a/crates/vm/src/system/phantom/tests.rs b/crates/vm/src/system/phantom/tests.rs index 7a0b068d36..e2aef1119e 100644 --- a/crates/vm/src/system/phantom/tests.rs +++ b/crates/vm/src/system/phantom/tests.rs @@ -1,5 +1,3 @@ -use std::sync::{Arc, Mutex}; - use openvm_instructions::{instruction::Instruction, SystemOpcode}; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -16,7 +14,6 @@ fn test_nops_and_terminate() { tester.program_bus(), SystemOpcode::CLASS_OFFSET, ); - chip.set_streams(Arc::new(Mutex::new(Default::default()))); let nop = Instruction::from_isize(SystemOpcode::PHANTOM.global_opcode(), 0, 0, 0, 0, 0); let mut state: ExecutionState = ExecutionState::new(F::ZERO, F::ONE); diff --git a/crates/vm/src/system/poseidon2/chip.rs b/crates/vm/src/system/poseidon2/chip.rs index e0059f1ce1..a4fac35cd1 100644 --- a/crates/vm/src/system/poseidon2/chip.rs +++ b/crates/vm/src/system/poseidon2/chip.rs @@ -1,14 +1,18 @@ use std::{ array, - sync::{atomic::AtomicU32, Arc}, + sync::{ + atomic::{AtomicBool, AtomicU32}, + Arc, + }, }; +use dashmap::DashMap; use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip}; use openvm_stark_backend::{ interaction::{BusIndex, LookupBus}, p3_field::PrimeField32, }; -use rustc_hash::FxHashMap; +use rustc_hash::FxBuildHasher; use super::{ air::Poseidon2PeripheryAir, PERIPHERY_POSEIDON2_CHUNK_SIZE, PERIPHERY_POSEIDON2_WIDTH, @@ -19,7 +23,8 @@ use crate::arch::hasher::{Hasher, HasherChip}; pub struct Poseidon2PeripheryBaseChip { pub air: Arc>, pub subchip: Poseidon2SubChip, - pub records: FxHashMap<[F; PERIPHERY_POSEIDON2_WIDTH], AtomicU32>, + pub records: DashMap<[F; PERIPHERY_POSEIDON2_WIDTH], AtomicU32, FxBuildHasher>, + pub nonempty: AtomicBool, } impl Poseidon2PeripheryBaseChip { @@ -31,7 +36,8 @@ impl Poseidon2PeripheryBaseChip HasherChip [F; PERIPHERY_POSEIDON2_CHUNK_SIZE] { @@ -73,6 +79,8 @@ impl HasherChip Hasher for Poseidon2Per impl HasherChip for Poseidon2PeripheryChip { fn compress_and_record( - &mut self, + &self, lhs: &[F; PERIPHERY_POSEIDON2_CHUNK_SIZE], rhs: &[F; PERIPHERY_POSEIDON2_CHUNK_SIZE], ) -> [F; PERIPHERY_POSEIDON2_CHUNK_SIZE] { diff --git a/crates/vm/src/system/poseidon2/tests.rs b/crates/vm/src/system/poseidon2/tests.rs index 095c8acba4..2f620847e4 100644 --- a/crates/vm/src/system/poseidon2/tests.rs +++ b/crates/vm/src/system/poseidon2/tests.rs @@ -32,7 +32,7 @@ fn poseidon2_periphery_direct_test() { ) }); - let mut chip = Poseidon2PeripheryChip::::new( + let chip = Poseidon2PeripheryChip::::new( Poseidon2Config::default(), POSEIDON2_DIRECT_BUS, 3, @@ -86,7 +86,7 @@ fn poseidon2_periphery_duplicate_hashes_test() { }); let counts: [u32; NUM_OPS] = std::array::from_fn(|_| rng.next_u32() % 20); - let mut chip = Poseidon2PeripheryChip::::new( + let chip = Poseidon2PeripheryChip::::new( Poseidon2Config::default(), POSEIDON2_DIRECT_BUS, 3, diff --git a/crates/vm/src/system/poseidon2/trace.rs b/crates/vm/src/system/poseidon2/trace.rs index 2b6f3e6b0b..4a929b8d06 100644 --- a/crates/vm/src/system/poseidon2/trace.rs +++ b/crates/vm/src/system/poseidon2/trace.rs @@ -8,7 +8,6 @@ use openvm_stark_backend::{ p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, prover::types::AirProofInput, - rap::get_air_name, AirRef, Chip, ChipUsageGetter, }; @@ -29,9 +28,11 @@ where let mut inputs = Vec::with_capacity(height); let mut multiplicities = Vec::with_capacity(height); - let (actual_inputs, actual_multiplicities): (Vec<_>, Vec<_>) = self - .records - .into_par_iter() + #[cfg(feature = "parallel")] + let records_iter = self.records.into_par_iter(); + #[cfg(not(feature = "parallel"))] + let records_iter = self.records.into_iter(); + let (actual_inputs, actual_multiplicities): (Vec<_>, Vec<_>) = records_iter .map(|(input, mult)| (input, mult.load(std::sync::atomic::Ordering::Relaxed))) .unzip(); inputs.extend(actual_inputs); @@ -63,11 +64,16 @@ impl ChipUsageGetter for Poseidon2PeripheryBaseChip { fn air_name(&self) -> String { - get_air_name(&self.air) + format!("Poseidon2PeripheryAir", SBOX_REGISTERS) } fn current_trace_height(&self) -> usize { - self.records.len() + if self.nonempty.load(std::sync::atomic::Ordering::Relaxed) { + // Not to call `DashMap::len` too often + self.records.len() + } else { + 0 + } } fn trace_width(&self) -> usize { diff --git a/crates/vm/src/system/program/trace.rs b/crates/vm/src/system/program/trace.rs index d9e2abd956..c168a478f0 100644 --- a/crates/vm/src/system/program/trace.rs +++ b/crates/vm/src/system/program/trace.rs @@ -23,7 +23,7 @@ use crate::{ hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, MemoryConfig, }, - system::memory::{tree::MemoryNode, AddressMap, CHUNK}, + system::memory::{merkle::MerkleTree, AddressMap, CHUNK}, }; #[derive(Serialize, Deserialize, Derivative)] @@ -82,17 +82,12 @@ where let memory_dimensions = memory_config.memory_dimensions(); let app_program_commit: &[Val; CHUNK] = self.committed_program.commitment.as_ref(); let mem_config = memory_config; - let init_memory_commit = MemoryNode::tree_from_memory( - memory_dimensions, - &AddressMap::from_iter( - mem_config.as_offset, - 1 << mem_config.as_height, - 1 << mem_config.pointer_max_bits, - self.exe.init_memory.clone(), - ), - &hasher, - ) - .hash(); + let memory_image = AddressMap::from_sparse( + mem_config.addr_space_sizes.clone(), + self.exe.init_memory.clone(), + ); + let init_memory_commit = + MerkleTree::from_memory(&memory_image, &memory_dimensions, &hasher).root(); Com::::from(compute_exe_commit( &hasher, app_program_commit, diff --git a/crates/vm/src/system/public_values/core.rs b/crates/vm/src/system/public_values/core.rs index de189f101b..6a86344e17 100644 --- a/crates/vm/src/system/public_values/core.rs +++ b/crates/vm/src/system/public_values/core.rs @@ -1,8 +1,16 @@ -use std::sync::Mutex; +use std::{ + borrow::{Borrow, BorrowMut}, + sync::Mutex, +}; -use openvm_circuit_primitives::{encoder::Encoder, SubAir}; +use openvm_circuit_primitives::{encoder::Encoder, AlignedBytesBorrow, SubAir}; use openvm_instructions::{ - instruction::Instruction, LocalOpcode, PublishOpcode, PublishOpcode::PUBLISH, + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::RV32_IMM_AS, + LocalOpcode, + PublishOpcode::{self, PUBLISH}, + NATIVE_AS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -10,17 +18,23 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; use crate::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, MinimalInstruction, - Result, VmAdapterInterface, VmCoreAir, VmCoreChip, + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + MinimalInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::{ + memory::{online::TracingMemory, MemoryAuxColsFactory}, + public_values::columns::PublicValuesCoreColsView, }, - system::public_values::columns::PublicValuesCoreColsView, + utils::{transmute_field_to_u32, transmute_u32_to_field}, }; + pub(crate) type AdapterInterface = BasicAdapterInterface, 2, 0, 1, 1>; -pub(crate) type AdapterInterfaceReads = as VmAdapterInterface>::Reads; #[derive(Clone, Debug)] pub struct PublicValuesCoreAir { @@ -99,27 +113,32 @@ impl VmCoreAir { - value: F, - index: F, + pub value: F, + pub index: F, } /// ATTENTION: If a specific public value is not provided, a default 0 will be used when generating /// the proof but in the perspective of constraints, it could be any value. -pub struct PublicValuesCoreChip { - air: PublicValuesCoreAir, +pub struct PublicValuesCoreStep { + adapter: A, + encoder: Encoder, // Mutex is to make the struct Sync. But it actually won't be accessed by multiple threads. - custom_pvs: Mutex>>, + pub(crate) custom_pvs: Mutex>>, } -impl PublicValuesCoreChip { +impl PublicValuesCoreStep +where + F: PrimeField32, +{ /// **Note:** `max_degree` is the maximum degree of the constraint polynomials to represent the /// flags. If you want the overall AIR's constraint degree to be `<= max_constraint_degree`, /// then typically you should set `max_degree` to `max_constraint_degree - 1`. - pub fn new(num_custom_pvs: usize, max_degree: u32) -> Self { + pub fn new(adapter: A, num_custom_pvs: usize, max_degree: u32) -> Self { Self { - air: PublicValuesCoreAir::new(num_custom_pvs, max_degree), + adapter, + encoder: Encoder::new(num_custom_pvs, max_degree, true), custom_pvs: Mutex::new(vec![None; num_custom_pvs]), } } @@ -128,56 +147,53 @@ impl PublicValuesCoreChip { } } -impl VmCoreChip> for PublicValuesCoreChip { - type Record = PublicValuesRecord; - type Air = PublicValuesCoreAir; +impl TraceStep for PublicValuesCoreStep +where + F: PrimeField32, + A: 'static + AdapterTraceStep, +{ + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut PublicValuesRecord); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - _instruction: &Instruction, - _from_pc: u32, - reads: AdapterInterfaceReads, - ) -> Result<(AdapterRuntimeContext>, Self::Record)> { - let [[value], [index]] = reads; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + PublishOpcode::from_usize(opcode - PublishOpcode::CLASS_OFFSET) + ) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + [[core_record.value], [core_record.index]] = + self.adapter + .read(state.memory, instruction, &mut adapter_record); { - let idx: usize = index.as_canonical_u32() as usize; + let idx: usize = core_record.index.as_canonical_u32() as usize; let mut custom_pvs = self.custom_pvs.lock().unwrap(); if custom_pvs[idx].is_none() { - custom_pvs[idx] = Some(value); + custom_pvs[idx] = Some(core_record.value); } else { // Not a hard constraint violation when publishing the same value twice but the // program should avoid that. panic!("Custom public value {} already set", idx); } } - let output = AdapterRuntimeContext { - to_pc: None, - writes: [], - }; - let record = Self::Record { value, index }; - Ok((output, record)) - } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - PublishOpcode::from_usize(opcode - PublishOpcode::CLASS_OFFSET) - ) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let mut cols = PublicValuesCoreColsView::<_, &mut F>::borrow_mut(row_slice); - debug_assert_eq!(cols.width(), BaseAir::::width(&self.air)); - *cols.is_valid = F::ONE; - *cols.value = record.value; - *cols.index = record.index; - let idx: usize = record.index.as_canonical_u32() as usize; - let pt = self.air.encoder.get_flag_pt(idx); - for (i, var) in cols.custom_pv_vars.iter_mut().enumerate() { - **var = F::from_canonical_u32(pt[i]); - } + Ok(()) } fn generate_public_values(&self) -> Vec { @@ -186,8 +202,199 @@ impl VmCoreChip> for PublicValuesCoreChi .map(|x| x.unwrap_or(F::ZERO)) .collect() } +} + +impl TraceFiller for PublicValuesCoreStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &PublicValuesRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let cols = PublicValuesCoreColsView::<_, &mut F>::borrow_mut(core_row); + + let idx: usize = record.index.as_canonical_u32() as usize; + let pt = self.encoder.get_flag_pt(idx); + + cols.custom_pv_vars + .into_iter() + .zip(pt.iter()) + .for_each(|(var, &val)| { + *var = F::from_canonical_u32(val); + }); + + *cols.index = record.index; + *cols.value = record.value; + *cols.is_valid = F::ONE; + } +} + +#[derive(AlignedBytesBorrow)] +#[repr(C)] +struct PublicValuesPreCompute { + b_or_imm: u32, + c_or_imm: u32, + pvs: *const Mutex>>, +} + +impl StepExecutorE1 for PublicValuesCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut PublicValuesPreCompute = data.borrow_mut(); + let (b_is_imm, c_is_imm) = self.pre_compute_impl(inst, data); + + let fn_ptr = match (b_is_imm, c_is_imm) { + (true, true) => execute_e1_impl::<_, _, true, true>, + (true, false) => execute_e1_impl::<_, _, true, false>, + (false, true) => execute_e1_impl::<_, _, false, true>, + (false, false) => execute_e1_impl::<_, _, false, false>, + }; + Ok(fn_ptr) + } +} + +impl StepExecutorE2 for PublicValuesCoreStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute> = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (b_is_imm, c_is_imm) = self.pre_compute_impl(inst, &mut data.data); + + let fn_ptr = match (b_is_imm, c_is_imm) { + (true, true) => execute_e2_impl::<_, _, true, true>, + (true, false) => execute_e2_impl::<_, _, true, false>, + (false, true) => execute_e2_impl::<_, _, false, true>, + (false, false) => execute_e2_impl::<_, _, false, false>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + state: &mut VmSegmentState, +) where + CTX: E1ExecutionCtx, +{ + let pre_compute: &PublicValuesPreCompute = pre_compute.borrow(); + execute_e12_impl::<_, _, B_IS_IMM, C_IS_IMM>(pre_compute, state); +} + +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + state: &mut VmSegmentState, +) where + CTX: E2ExecutionCtx, +{ + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + state.ctx.on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::<_, _, B_IS_IMM, C_IS_IMM>(&pre_compute.data, state); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &PublicValuesPreCompute, + state: &mut VmSegmentState, +) where + CTX: E1ExecutionCtx, +{ + let value = if B_IS_IMM { + transmute_u32_to_field(&pre_compute.b_or_imm) + } else { + state.vm_read::(NATIVE_AS, pre_compute.b_or_imm)[0] + }; + let index = if C_IS_IMM { + transmute_u32_to_field(&pre_compute.c_or_imm) + } else { + state.vm_read::(NATIVE_AS, pre_compute.c_or_imm)[0] + }; + + let idx: usize = index.as_canonical_u32() as usize; + { + let custom_pvs = unsafe { &*pre_compute.pvs }; + let mut custom_pvs = custom_pvs.lock().unwrap(); + + if custom_pvs[idx].is_none() { + custom_pvs[idx] = Some(value); + } else { + // Not a hard constraint violation when publishing the same value twice but the + // program should avoid that. + panic!("Custom public value {} already set", idx); + } + } + state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + state.instret += 1; +} + +impl PublicValuesCoreStep +where + F: PrimeField32, +{ + fn pre_compute_impl( + &self, + inst: &Instruction, + data: &mut PublicValuesPreCompute, + ) -> (bool, bool) { + let &Instruction { b, c, e, f, .. } = inst; + + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); + + let b_is_imm = e == RV32_IMM_AS; + let c_is_imm = f == RV32_IMM_AS; + + let b_or_imm = if b_is_imm { + transmute_field_to_u32(&b) + } else { + b.as_canonical_u32() + }; + let c_or_imm = if c_is_imm { + transmute_field_to_u32(&c) + } else { + c.as_canonical_u32() + }; + + *data = PublicValuesPreCompute { + b_or_imm, + c_or_imm, + pvs: &self.custom_pvs, + }; - fn air(&self) -> &Self::Air { - &self.air + (b_is_imm, c_is_imm) } } diff --git a/crates/vm/src/system/public_values/mod.rs b/crates/vm/src/system/public_values/mod.rs index 918606497b..1712fd9c62 100644 --- a/crates/vm/src/system/public_values/mod.rs +++ b/crates/vm/src/system/public_values/mod.rs @@ -1,8 +1,10 @@ +use core::PublicValuesCoreStep; + use crate::{ - arch::{VmAirWrapper, VmChipWrapper}, + arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}, system::{ - native_adapter::{NativeAdapterAir, NativeAdapterChip}, - public_values::core::{PublicValuesCoreAir, PublicValuesCoreChip}, + native_adapter::{NativeAdapterAir, NativeAdapterStep}, + public_values::core::PublicValuesCoreAir, }, }; @@ -14,5 +16,6 @@ pub mod core; mod tests; pub type PublicValuesAir = VmAirWrapper, PublicValuesCoreAir>; +pub type PublicValuesStepWithAdapter = PublicValuesCoreStep, F>; pub type PublicValuesChip = - VmChipWrapper, PublicValuesCoreChip>; + NewVmChipWrapper, MatrixRecordArena>; diff --git a/crates/vm/src/utils/mod.rs b/crates/vm/src/utils/mod.rs index 7b4823c53a..a9c8486852 100644 --- a/crates/vm/src/utils/mod.rs +++ b/crates/vm/src/utils/mod.rs @@ -1,10 +1,47 @@ #[cfg(any(test, feature = "test-utils"))] mod stark_utils; #[cfg(any(test, feature = "test-utils"))] -mod test_utils; +pub mod test_utils; pub use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_stark_backend::p3_field::PrimeField32; #[cfg(any(test, feature = "test-utils"))] pub use stark_utils::*; #[cfg(any(test, feature = "test-utils"))] pub use test_utils::*; + +#[inline(always)] +pub fn transmute_field_to_u32(field: &F) -> u32 { + debug_assert_eq!( + std::mem::size_of::(), + std::mem::size_of::(), + "Field type F must have the same size as u32" + ); + debug_assert_eq!( + std::mem::align_of::(), + std::mem::align_of::(), + "Field type F must have the same alignment as u32" + ); + // SAFETY: This assumes that F has the same memory layout as u32. + // This is only safe for field types that are guaranteed to be represented + // as a single u32 internally + unsafe { *(field as *const F as *const u32) } +} + +#[inline(always)] +pub fn transmute_u32_to_field(value: &u32) -> F { + debug_assert_eq!( + std::mem::size_of::(), + std::mem::size_of::(), + "Field type F must have the same size as u32" + ); + debug_assert_eq!( + std::mem::align_of::(), + std::mem::align_of::(), + "Field type F must have the same alignment as u32" + ); + // SAFETY: This assumes that F has the same memory layout as u32. + // This is only safe for field types that are guaranteed to be represented + // as a single u32 internally + unsafe { *(value as *const u32 as *const F) } +} diff --git a/crates/vm/src/utils/stark_utils.rs b/crates/vm/src/utils/stark_utils.rs index d940be5c75..ef6410f9d6 100644 --- a/crates/vm/src/utils/stark_utils.rs +++ b/crates/vm/src/utils/stark_utils.rs @@ -16,15 +16,25 @@ use openvm_stark_sdk::{ utils::ProofInputForTest, }; -use crate::arch::{ - vm::{VirtualMachine, VmExecutor}, - Streams, VmConfig, VmMemoryState, +#[cfg(feature = "bench-metrics")] +use crate::arch::vm::VmExecutor; +use crate::{ + arch::{ + execution_mode::{ + e1::E1Ctx, + metered::{ctx::DEFAULT_PAGE_BITS, MeteredCtx}, + }, + interpreter::InterpretedInstance, + vm::VirtualMachine, + InsExecutorE1, Streams, VmConfig, + }, + system::memory::MemoryImage, }; pub fn air_test(config: VC, exe: impl Into>) where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1, VC::Periphery: Chip, { air_test_with_min_segments(config, exe, Streams::default(), 1); @@ -36,10 +46,10 @@ pub fn air_test_with_min_segments( exe: impl Into>, input: impl Into>, min_segments: usize, -) -> Option> +) -> Option where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1, VC::Periphery: Chip, { air_test_impl(config, exe, input, min_segments, true) @@ -53,23 +63,48 @@ pub fn air_test_impl( input: impl Into>, min_segments: usize, debug: bool, -) -> Option> +) -> Option where VC: VmConfig, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1, VC::Periphery: Chip, { setup_tracing(); + let exe = exe.into(); + let input = input.into(); + { + let executor = InterpretedInstance::::new(config.clone(), exe.clone()); + executor + .execute(E1Ctx::new(None), input.clone()) + .expect("Failed to execute"); + } let mut log_blowup = 1; while config.system().max_constraint_degree > (1 << log_blowup) + 1 { log_blowup += 1; } let engine = BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(log_blowup)); - let vm = VirtualMachine::new(engine, config); + let vm = VirtualMachine::new(engine, config.clone()); let pk = vm.keygen(); - let mut result = vm.execute_and_generate(exe, input).unwrap(); + let vk = pk.get_vk(); + let chip_complex = vm.config().create_chip_complex().unwrap(); + { + let executor = InterpretedInstance::::new(config.clone(), exe.clone()); + let ctx = MeteredCtx::::new(&chip_complex, vk.num_interactions()) + .with_max_trace_height(config.system().segmentation_strategy.max_trace_height() as u32) + .with_max_cells(config.system().segmentation_strategy.max_cells()); + let final_state = executor + .execute_e2(ctx, input.clone()) + .expect("Failed to execute"); + assert!(final_state.ctx.segments().len() >= min_segments); + } + + let segments = vm + .executor + .execute_metered(exe.clone(), input.clone(), &vk.num_interactions()) + .unwrap(); + let mut result = vm.execute_and_generate(exe, input, &segments).unwrap(); let final_memory = Option::take(&mut result.final_memory); - let global_airs = vm.config().create_chip_complex().unwrap().airs(); + let global_airs = chip_complex.airs(); if debug { for proof_input in &result.per_segment { let (airs, pks, air_proof_inputs): (Vec<_>, Vec<_>, Vec<_>) = @@ -95,37 +130,51 @@ where /// do any proving. Output is the payload of everything the prover needs. /// /// The output AIRs and traces are sorted by height in descending order. -pub fn gen_vm_program_test_proof_input( +pub fn gen_vm_program_test_proof_input( program: Program>, input_stream: impl Into>> + Clone, #[allow(unused_mut)] mut config: VC, ) -> ProofInputForTest where + E: StarkFriEngine, + SC: StarkGenericConfig, Val: PrimeField32, VC: VmConfig> + Clone, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { + let program_exe = VmExe::new(program); + let input = input_stream.into(); + + let airs = config.create_chip_complex().unwrap().airs(); + let engine = E::new(FriParameters::new_for_testing(1)); + let vm = VirtualMachine::new(engine, config.clone()); + + let pk = vm.keygen(); + let vk = pk.get_vk(); + let segments = vm + .executor + .execute_metered(program_exe.clone(), input.clone(), &vk.num_interactions()) + .unwrap(); + cfg_if::cfg_if! { if #[cfg(feature = "bench-metrics")] { // Run once with metrics collection enabled, which can improve runtime performance config.system_mut().profiling = true; { let executor = VmExecutor::, VC>::new(config.clone()); - executor.execute(program.clone(), input_stream.clone()).unwrap(); + executor.execute(program_exe.clone(), input.clone(), &segments).unwrap(); } // Run again with metrics collection disabled and measure trace generation time config.system_mut().profiling = false; let start = std::time::Instant::now(); } } - - let airs = config.create_chip_complex().unwrap().airs(); - let executor = VmExecutor::, VC>::new(config); - - let mut result = executor - .execute_and_generate(program, input_stream) + let mut result = vm + .executor + .execute_and_generate(program_exe, input, &segments) .unwrap(); + assert_eq!( result.per_segment.len(), 1, @@ -159,11 +208,12 @@ pub fn execute_and_prove_program, where Val: PrimeField32, VC: VmConfig> + Clone, - VC::Executor: Chip, + VC::Executor: Chip + InsExecutorE1>, VC::Periphery: Chip, { let span = tracing::info_span!("execute_and_prove_program").entered(); - let test_proof_input = gen_vm_program_test_proof_input(program, input_stream, config); + let test_proof_input = + gen_vm_program_test_proof_input::<_, _, E>(program, input_stream, config); let vparams = test_proof_input.run_test(engine)?; span.exit(); Ok(vparams) diff --git a/crates/vm/src/utils/test_utils.rs b/crates/vm/src/utils/test_utils.rs index 9449aff5b8..f924536888 100644 --- a/crates/vm/src/utils/test_utils.rs +++ b/crates/vm/src/utils/test_utils.rs @@ -1,5 +1,7 @@ use std::array; +use openvm_circuit::arch::{MemoryConfig, SystemConfig}; +use openvm_instructions::NATIVE_AS; use openvm_stark_backend::p3_field::PrimeField32; use rand::{rngs::StdRng, Rng}; @@ -31,3 +33,25 @@ pub fn u32_sign_extend(num: u32) -> u32 { num } } + +pub fn test_system_config() -> SystemConfig { + SystemConfig::new( + 3, + MemoryConfig::new(2, vec![0, 4096, 1 << 22, 4096, 1 << 25], 29, 29, 17, 32), + 32, + ) +} + +// Testing config when native address space is not needed +pub fn test_system_config_with_continuations() -> SystemConfig { + let mut config = test_system_config(); + config.memory_config.addr_space_sizes[NATIVE_AS as usize] = 0; + config.with_continuations() +} + +/// Generate a random message of a given length in bytes +pub fn get_random_message(rng: &mut StdRng, len: usize) -> Vec { + let mut random_message: Vec = vec![0u8; len]; + rng.fill(&mut random_message[..]); + random_message +} diff --git a/crates/vm/tests/integration_test.rs b/crates/vm/tests/integration_test.rs index 168d756111..03ca6388ac 100644 --- a/crates/vm/tests/integration_test.rs +++ b/crates/vm/tests/integration_test.rs @@ -1,21 +1,28 @@ use std::{ collections::{BTreeMap, VecDeque}, iter::zip, + mem::transmute, sync::Arc, }; use openvm_circuit::{ arch::{ + create_and_initialize_chip_complex, + execution_mode::{ + e1::E1Ctx, + tracegen::{TracegenCtx, TracegenExecutionControl}, + }, hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, - ChipId, ExecutionSegment, MemoryConfig, SingleSegmentVmExecutor, SystemConfig, - SystemTraceHeights, VirtualMachine, VmComplexTraceHeights, VmConfig, - VmInventoryTraceHeights, + interpreter::InterpretedInstance, + ChipId, DefaultSegmentationStrategy, SingleSegmentVmExecutor, SystemTraceHeights, + VirtualMachine, VmComplexTraceHeights, VmConfig, VmInventoryTraceHeights, + VmSegmentExecutor, VmSegmentState, }, system::{ memory::{MemoryTraceHeights, VolatileMemoryTraceHeights, CHUNK}, program::trace::VmCommittedExe, }, - utils::{air_test, air_test_with_min_segments}, + utils::{air_test, air_test_with_min_segments, test_system_config}, }; use openvm_instructions::{ exe::VmExe, @@ -26,10 +33,18 @@ use openvm_instructions::{ SysPhantom, SystemOpcode::*, }; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{ + test_native_config, test_native_continuations_config, test_rv32_with_kernels_config, + NativeConfig, +}; use openvm_native_compiler::{ - FieldArithmeticOpcode::*, FieldExtensionOpcode::*, NativeBranchEqualOpcode, NativeJalOpcode::*, - NativeLoadStoreOpcode::*, NativePhantom, + CastfOpcode, + FieldArithmeticOpcode::*, + FieldExtensionOpcode::*, + FriOpcode, NativeBranchEqualOpcode, + NativeJalOpcode::{self, *}, + NativeLoadStoreOpcode::*, + NativePhantom, NativeRangeCheckOpcode, Poseidon2Opcode, }; use openvm_rv32im_transpiler::BranchEqualOpcode::*; use openvm_stark_backend::{ @@ -44,7 +59,7 @@ use openvm_stark_sdk::{ engine::StarkFriEngine, p3_baby_bear::BabyBear, }; -use rand::Rng; +use rand::{rngs::StdRng, Rng, SeedableRng}; use test_log::test; pub fn gen_pointer(rng: &mut R, len: usize) -> usize @@ -55,19 +70,6 @@ where rng.gen_range(0..MAX_MEMORY - len) / len * len } -fn test_native_config() -> NativeConfig { - NativeConfig { - system: SystemConfig::new(3, MemoryConfig::new(2, 1, 16, 29, 15, 32, 1024), 0), - native: Default::default(), - } -} - -fn test_native_continuations_config() -> NativeConfig { - let mut config = test_native_config(); - config.system = config.system.with_continuations(); - config -} - #[test] fn test_vm_1() { let n = 6; @@ -126,9 +128,18 @@ fn test_vm_override_executor_height() { // Test getting heights. let vm_config = NativeConfig::aggregation(8, 3); + let vm = VirtualMachine::new(e, vm_config.clone()); + let pk = vm.keygen(); + let vk = pk.get_vk(); + let executor = SingleSegmentVmExecutor::new(vm_config.clone()); + + let max_trace_heights = executor + .execute_metered(committed_exe.exe.clone(), vec![], &vk.num_interactions()) + .unwrap(); + let res = executor - .execute_and_compute_heights(committed_exe.exe.clone(), vec![]) + .execute_and_compute_heights(committed_exe.exe.clone(), vec![], &max_trace_heights) .unwrap(); // Memory trace heights are not computed during execution. assert_eq!( @@ -192,7 +203,7 @@ fn test_vm_override_executor_height() { Some(overridden_heights), ); let proof_input = executor - .execute_and_generate(committed_exe, vec![]) + .execute_and_generate(committed_exe, vec![], &max_trace_heights) .unwrap(); let air_heights: Vec<_> = proof_input .per_air @@ -235,8 +246,16 @@ fn test_vm_1_optional_air() { ]; let program = Program::from_instructions(&instructions); + + let pk = vm.keygen(); + let vk = pk.get_vk(); + let segments = vm + .executor + .execute_metered(program.clone(), vec![], &vk.num_interactions()) + .unwrap(); + let result = vm - .execute_and_generate(program, vec![]) + .execute_and_generate(program, vec![], &segments) .expect("Failed to execute VM"); assert_eq!(result.per_segment.len(), 1); let proof_input = result.per_segment.last().unwrap(); @@ -255,11 +274,12 @@ fn test_vm_1_optional_air() { fn test_vm_public_values() { setup_tracing(); let num_public_values = 100; - let config = SystemConfig::default().with_public_values(num_public_values); + let config = test_system_config().with_public_values(num_public_values); let engine = BabyBearPoseidon2Engine::new(standard_fri_params_with_100_bits_conjectured_security(3)); let vm = VirtualMachine::new(engine, config.clone()); let pk = vm.keygen(); + let vk = pk.get_vk(); { let instructions = vec![ @@ -273,8 +293,13 @@ fn test_vm_public_values() { vm.engine.config.pcs(), )); let single_vm = SingleSegmentVmExecutor::new(config); + + let max_trace_heights = single_vm + .execute_metered(program.clone().into(), vec![], &vk.num_interactions()) + .unwrap(); + let exe_result = single_vm - .execute_and_compute_heights(program, vec![]) + .execute_and_compute_heights(program, vec![], &max_trace_heights) .unwrap(); assert_eq!( exe_result.public_values, @@ -285,7 +310,7 @@ fn test_vm_public_values() { .concat(), ); let proof_input = single_vm - .execute_and_generate(committed_exe, vec![]) + .execute_and_generate(committed_exe, vec![], &max_trace_heights) .unwrap(); vm.engine .prove_then_verify(&pk, proof_input) @@ -316,9 +341,8 @@ fn test_vm_initial_memory() { Instruction::::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), ]); - let init_memory: BTreeMap<_, _> = [((4, 7), BabyBear::from_canonical_u32(101))] - .into_iter() - .collect(); + let raw = unsafe { transmute::(BabyBear::from_canonical_u32(101)) }; + let init_memory = BTreeMap::from_iter((0..4).map(|i| ((4u32, 7u32 * 4 + i), raw[i as usize]))); let config = test_native_continuations_config(); let exe = VmExe { @@ -335,13 +359,14 @@ fn test_vm_1_persistent() { let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); let config = test_native_continuations_config(); let ptr_max_bits = config.system.memory_config.pointer_max_bits; - let as_height = config.system.memory_config.as_height; + let addr_space_height = config.system.memory_config.addr_space_height; let airs = VmConfig::::create_chip_complex(&config) .unwrap() .airs::(); let vm = VirtualMachine::new(engine, config); let pk = vm.keygen(); + let vk = pk.get_vk(); let n = 6; let instructions = vec![ @@ -360,7 +385,14 @@ fn test_vm_1_persistent() { let program = Program::from_instructions(&instructions); - let result = vm.execute_and_generate(program.clone(), vec![]).unwrap(); + let segments = vm + .executor + .execute_metered(program.clone(), vec![], &vk.num_interactions()) + .unwrap(); + + let result = vm + .execute_and_generate(program.clone(), vec![], &segments) + .unwrap(); { let proof_input = result.per_segment.into_iter().next().unwrap(); @@ -374,20 +406,20 @@ fn test_vm_1_persistent() { ); let mut digest = [BabyBear::ZERO; CHUNK]; let compression = vm_poseidon2_hasher(); - for _ in 0..ptr_max_bits + as_height - 2 { + for _ in 0..ptr_max_bits + addr_space_height - 2 { digest = compression.compress(&digest, &digest); } assert_eq!( merkle_air_proof_input.raw.public_values[..8], // The value when you start with zeros and repeatedly hash the value with itself - // ptr_max_bits + as_height - 2 times. - // The height of the tree is ptr_max_bits + as_height - log2(8). The leaf also must be - // hashed once with padding for security. + // ptr_max_bits + addr_space_height - 2 times. + // The height of the tree is ptr_max_bits + addr_space_height - log2(8). The leaf also + // must be hashed once with padding for security. digest ); } - let result_for_proof = vm.execute_and_generate(program, vec![]).unwrap(); + let result_for_proof = vm.execute_and_generate(program, vec![], &segments).unwrap(); let proofs = vm.prove(&pk, result_for_proof); vm.verify(&pk.get_vk(), proofs) .expect("Verification failed"); @@ -656,7 +688,7 @@ fn test_vm_hint() { Instruction::from_isize(LOADW.global_opcode(), 38, 0, 32, 4, 4), Instruction::large_from_isize(ADD.global_opcode(), 44, 20, 0, 4, 4, 0, 0), Instruction::from_isize(MUL.global_opcode(), 24, 38, 1, 4, 4), - Instruction::large_from_isize(ADD.global_opcode(), 20, 20, 24, 4, 4, 1, 0), + Instruction::large_from_isize(ADD.global_opcode(), 20, 20, 24, 4, 4, 4, 0), Instruction::large_from_isize(ADD.global_opcode(), 50, 16, 0, 4, 4, 0, 0), Instruction::from_isize( JAL.global_opcode(), @@ -694,7 +726,7 @@ fn test_vm_hint() { type F = BabyBear; let input_stream: Vec> = vec![vec![F::TWO]]; - let config = NativeConfig::new(SystemConfig::default(), Default::default()); + let config = test_native_config(); air_test_with_min_segments(config, program, input_stream, 1); } @@ -712,17 +744,40 @@ fn test_hint_load_1() { ]; let program = Program::from_instructions(&instructions); + let input = vec![vec![F::ONE, F::TWO]]; + let rng = StdRng::seed_from_u64(0); - let mut segment = ExecutionSegment::new( + let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); + let vm = VirtualMachine::new(engine, test_native_config()); + let pk = vm.keygen(); + let vk = pk.get_vk(); + let mut segments = vm + .executor + .execute_metered(program.clone(), input.clone(), &vk.num_interactions()) + .unwrap(); + assert_eq!(segments.len(), 1); + let segment = segments.pop().unwrap(); + + let chip_complex = create_and_initialize_chip_complex( &test_native_config(), program, - vec![vec![F::ONE, F::TWO]].into(), None, + Some(&segment.trace_heights), + ) + .unwrap(); + + let mut executor = VmSegmentExecutor::::new( + chip_complex, vec![], Default::default(), + TracegenExecutionControl, ); - segment.execute_from_pc(0).unwrap(); - let streams = segment.chip_complex.take_streams(); + + let ctx = TracegenCtx::new(Some(segment.num_insns)); + let mut exec_state = VmSegmentState::new(0, 0, None, input.into(), rng, ctx); + executor.execute_from_state(&mut exec_state).unwrap(); + + let streams = exec_state.streams; assert!(streams.input_stream.is_empty()); assert_eq!(streams.hint_stream, VecDeque::from(vec![F::ZERO])); assert_eq!(streams.hint_space, vec![vec![F::ONE, F::TWO]]); @@ -749,24 +804,49 @@ fn test_hint_load_2() { ]; let program = Program::from_instructions(&instructions); + let input = vec![vec![F::ONE, F::TWO], vec![F::TWO, F::ONE]]; + let rng = StdRng::seed_from_u64(0); + + let engine = BabyBearPoseidon2Engine::new(FriParameters::standard_fast()); + let vm = VirtualMachine::new(engine, test_native_config()); + let pk = vm.keygen(); + let vk = pk.get_vk(); + let mut segments = vm + .executor + .execute_metered(program.clone(), input.clone(), &vk.num_interactions()) + .unwrap(); + assert_eq!(segments.len(), 1); + let segment = segments.pop().unwrap(); - let mut segment = ExecutionSegment::new( + let chip_complex = create_and_initialize_chip_complex( &test_native_config(), program, - vec![vec![F::ONE, F::TWO], vec![F::TWO, F::ONE]].into(), None, + Some(&segment.trace_heights), + ) + .unwrap(); + + let mut executor = VmSegmentExecutor::::new( + chip_complex, vec![], Default::default(), + TracegenExecutionControl, ); - segment.execute_from_pc(0).unwrap(); - assert_eq!( - segment + + let ctx = TracegenCtx::new(Some(segment.num_insns)); + let mut exec_state = VmSegmentState::new(0, 0, None, input.into(), rng, ctx); + executor.execute_from_state(&mut exec_state).unwrap(); + + let [read] = unsafe { + executor .chip_complex .memory_controller() - .unsafe_read_cell(F::from_canonical_usize(4), F::from_canonical_usize(32)), - F::ZERO - ); - let streams = segment.chip_complex.take_streams(); + .memory + .data + .read::(4, 32) + }; + assert_eq!(read, F::ZERO); + let streams = exec_state.streams; assert!(streams.input_stream.is_empty()); assert_eq!(streams.hint_stream, VecDeque::from(vec![F::ONE])); assert_eq!( @@ -774,3 +854,222 @@ fn test_hint_load_2() { vec![vec![F::ONE, F::TWO], vec![F::TWO, F::ONE]] ); } + +#[test] +fn test_vm_pure_execution_non_continuation() { + type F = BabyBear; + let n = 6; + /* + Instruction 0 assigns word[0]_4 to n. + Instruction 4 terminates + The remainder is a loop that decrements word[0]_4 until it reaches 0, then terminates. + Instruction 1 checks if word[0]_4 is 0 yet, and if so sets pc to 5 in order to terminate + Instruction 2 decrements word[0]_4 (using word[1]_4) + Instruction 3 uses JAL as a simple jump to go back to instruction 1 (repeating the loop). + */ + let instructions = vec![ + // word[0]_4 <- word[n]_0 + Instruction::large_from_isize(ADD.global_opcode(), 0, n, 0, 4, 0, 0, 0), + // if word[0]_4 == 0 then pc += 3 * DEFAULT_PC_STEP + Instruction::from_isize( + NativeBranchEqualOpcode(BEQ).global_opcode(), + 0, + 0, + 3 * DEFAULT_PC_STEP as isize, + 4, + 0, + ), + // word[0]_4 <- word[0]_4 - word[1]_4 + Instruction::large_from_isize(SUB.global_opcode(), 0, 0, 1, 4, 4, 0, 0), + // word[2]_4 <- pc + DEFAULT_PC_STEP, pc -= 2 * DEFAULT_PC_STEP + Instruction::from_isize( + JAL.global_opcode(), + 2, + -2 * DEFAULT_PC_STEP as isize, + 0, + 4, + 0, + ), + // terminate + Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let program = Program::from_instructions(&instructions); + + let executor = InterpretedInstance::::new(test_native_config(), program); + executor + .execute(E1Ctx::new(None), vec![]) + .expect("Failed to execute"); +} + +#[test] +fn test_vm_pure_execution_continuation() { + type F = BabyBear; + let instructions = vec![ + Instruction::large_from_isize(ADD.global_opcode(), 0, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 1, 0, 2, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 2, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 3, 0, 2, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 4, 0, 2, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 5, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 6, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(ADD.global_opcode(), 7, 0, 2, 4, 0, 0, 0), + Instruction::from_isize(FE4ADD.global_opcode(), 8, 0, 4, 4, 4), + Instruction::from_isize(FE4ADD.global_opcode(), 8, 0, 4, 4, 4), + Instruction::from_isize(FE4SUB.global_opcode(), 12, 0, 4, 4, 4), + Instruction::from_isize(BBE4MUL.global_opcode(), 12, 0, 4, 4, 4), + Instruction::from_isize(BBE4DIV.global_opcode(), 12, 0, 4, 4, 4), + Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let program = Program::from_instructions(&instructions); + let executor = InterpretedInstance::::new(test_native_continuations_config(), program); + executor + .execute(E1Ctx::new(None), vec![]) + .expect("Failed to execute"); +} + +#[test] +fn test_vm_e1_native_chips() { + type F = BabyBear; + + let instructions = vec![ + // Field Arithmetic operations (FieldArithmeticChip) + Instruction::large_from_isize(ADD.global_opcode(), 0, 0, 1, 4, 0, 0, 0), + Instruction::large_from_isize(SUB.global_opcode(), 1, 10, 2, 4, 0, 0, 0), + Instruction::large_from_isize(MUL.global_opcode(), 2, 3, 4, 4, 0, 0, 0), + Instruction::large_from_isize(DIV.global_opcode(), 3, 20, 5, 4, 0, 0, 0), + // Field Extension operations (FieldExtensionChip) + Instruction::from_isize(FE4ADD.global_opcode(), 8, 0, 4, 4, 4), + Instruction::from_isize(FE4SUB.global_opcode(), 12, 8, 4, 4, 4), + Instruction::from_isize(BBE4MUL.global_opcode(), 16, 12, 8, 4, 4), + Instruction::from_isize(BBE4DIV.global_opcode(), 20, 16, 12, 4, 4), + // Branch operations (NativeBranchEqChip) + Instruction::from_isize( + NativeBranchEqualOpcode(BEQ).global_opcode(), + 0, + 0, + DEFAULT_PC_STEP as isize, + 4, + 4, + ), + Instruction::from_isize( + NativeBranchEqualOpcode(BNE).global_opcode(), + 1, + 2, + DEFAULT_PC_STEP as isize, + 4, + 4, + ), + // JAL operation (JalRangeCheckChip) + Instruction::from_isize( + NativeJalOpcode::JAL.global_opcode(), + 24, + DEFAULT_PC_STEP as isize, + 0, + 4, + 0, + ), + // Range check operation (JalRangeCheckChip) + Instruction::from_isize( + NativeRangeCheckOpcode::RANGE_CHECK.global_opcode(), + 0, + 10, + 8, + 4, + 0, + ), + // Load/Store operations (NativeLoadStoreChip) + Instruction::from_isize(STOREW.global_opcode(), 0, 0, 28, 4, 4), + Instruction::from_isize(LOADW.global_opcode(), 32, 0, 28, 4, 4), + Instruction::from_isize( + PHANTOM.global_opcode(), + 0, + 0, + NativePhantom::HintInput as isize, + 0, + 0, + ), + Instruction::from_isize(HINT_STOREW.global_opcode(), 32, 0, 0, 4, 4), + // Cast to field operation (CastFChip) + Instruction::from_usize(CastfOpcode::CASTF.global_opcode(), [36, 40, 0, 2, 4]), + // Poseidon2 operations (Poseidon2Chip) + Instruction::new( + Poseidon2Opcode::PERM_POS2.global_opcode(), + F::from_canonical_usize(44), + F::from_canonical_usize(48), + F::ZERO, + F::from_canonical_usize(4), + F::from_canonical_usize(4), + F::ZERO, + F::ZERO, + ), + Instruction::new( + Poseidon2Opcode::COMP_POS2.global_opcode(), + F::from_canonical_usize(52), + F::from_canonical_usize(44), + F::from_canonical_usize(48), + F::from_canonical_usize(4), + F::from_canonical_usize(4), + F::ZERO, + F::ZERO, + ), + // FRI operation (FriReducedOpeningChip) + Instruction::large_from_isize(ADD.global_opcode(), 60, 64, 0, 4, 4, 0, 0), /* a_pointer_pointer, */ + Instruction::large_from_isize(ADD.global_opcode(), 64, 68, 0, 4, 4, 0, 0), /* b_pointer_pointer, */ + Instruction::large_from_isize(ADD.global_opcode(), 68, 2, 0, 4, 0, 0, 0), /* length_pointer (value 2), */ + Instruction::large_from_isize(ADD.global_opcode(), 72, 1, 0, 4, 0, 0, 0), //alpha_pointer + Instruction::large_from_isize(ADD.global_opcode(), 76, 80, 0, 4, 4, 0, 0), /* result_pointer, */ + Instruction::large_from_isize(ADD.global_opcode(), 80, 1, 0, 4, 0, 0, 0), /* is_init (value 1) , */ + Instruction::from_usize( + FriOpcode::FRI_REDUCED_OPENING.global_opcode(), + [60, 64, 68, 72, 76, 0, 80], + ), + // Terminate + Instruction::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), + ]; + + let program = Program::from_instructions(&instructions); + let input_stream: Vec> = vec![vec![]]; + + let executor = InterpretedInstance::::new(test_rv32_with_kernels_config(), program); + executor + .execute(E1Ctx::new(None), input_stream) + .expect("Failed to execute"); +} + +#[test] +fn test_single_segment_executor_no_segmentation() { + setup_tracing(); + type F = BabyBear; + + let mut config = test_native_config(); + + config.system.set_segmentation_strategy(Arc::new( + DefaultSegmentationStrategy::new_with_max_segment_len(1), + )); + + let engine = + BabyBearPoseidon2Engine::new(standard_fri_params_with_100_bits_conjectured_security(3)); + let vm = VirtualMachine::new(engine, config.clone()); + let pk = vm.keygen(); + let vk = pk.get_vk(); + let instructions: Vec<_> = (0..1000) + .map(|_| Instruction::large_from_isize(ADD.global_opcode(), 0, 0, 1, 4, 0, 0, 0)) + .chain(std::iter::once(Instruction::from_isize( + TERMINATE.global_opcode(), + 0, + 0, + 0, + 0, + 0, + ))) + .collect(); + + let program = Program::from_instructions(&instructions); + let single_vm = SingleSegmentVmExecutor::::new(config); + + let _ = single_vm + .execute_metered(program.clone().into(), vec![], &vk.num_interactions()) + .unwrap(); +} diff --git a/docs/crates/metrics.md b/docs/crates/metrics.md index 362bce47e0..8874f8d7be 100644 --- a/docs/crates/metrics.md +++ b/docs/crates/metrics.md @@ -7,14 +7,20 @@ We describe the metrics that are collected for a single VM circuit proof, which To scope metrics from different proofs, we use the [`metrics_tracing_context`](https://docs.rs/metrics-tracing-context/latest/metrics_tracing_context/) crate to provide context-dependent labels. With the exception of the `segment` label, all other labels must be set by the caller. -For a single segment proof, the following metrics are collected: +For a segment proof, the following metrics are collected: -- `execute_time_ms` (gauge): The runtime execution time of the segment in milliseconds. +- `execute_metered_time_ms` (gauge): The metered execution time of the segment in milliseconds. This is timed across **all** segments in the group. +- `execute_e3_time_ms` (gauge): The preflight execution time of the segment in milliseconds. - If this is a segment in a VM with continuations enabled, a `segment: segment_idx` label is added to the metric. - `trace_gen_time_ms` (gauge): The time to generate non-cached trace matrices from execution records. - If this is a segment in a VM with continuations enabled, a `segment: segment_idx` label is added to the metric. + - `memory_finalize_time_ms` (gauge): The time in trace generation spent on memory finalization. + - `boundary_finalize_time_ms` (gauge): The time in memory finalization spent on boundary finalization. + - `merkle_finalize_time_ms` (gauge): The time in memory finalization spent on merkle tree finalization. - All metrics collected by [`openvm-stark-backend`](https://github.com/openvm-org/stark-backend/blob/main/docs/metrics.md), in particular `stark_prove_excluding_trace_time_ms` (gauge). - - The total proving time of the proof is the sum of `execute_time_ms + trace_gen_time_ms + stark_prove_excluding_trace_time_ms`. +- The `total_proof_time_ms` of the proof is: + - The sum `execute_e3_time_ms + trace_gen_time_ms + stark_prove_excluding_trace_time_ms` for app proofs. The `execute_metered_time_ms` is excluded for app proofs because it is not run on a per-segment basis. + - The sum `execute_metered_time_ms + execute_e3_time_ms + trace_gen_time_ms + stark_prove_excluding_trace_time_ms` for non-app proofs. - `total_cycles` (counter): The total number of cycles in the segment. - `main_cells_used` (counter): The total number of main trace cells used by all chips in the segment. This does not include cells needed to pad rows to power-of-two matrix heights. Only main trace cells, not preprocessed or permutation trace cells, are counted. diff --git a/docs/crates/vm.md b/docs/crates/vm.md index 989e9c8a88..3e988148f9 100644 --- a/docs/crates/vm.md +++ b/docs/crates/vm.md @@ -59,6 +59,7 @@ pub trait PhantomSubExecutor { &mut self, memory: &MemoryController, streams: &mut Streams, + rng: &mut StdRng, discriminant: PhantomDiscriminant, a: F, b: F, diff --git a/docs/specs/ISA.md b/docs/specs/ISA.md index 2755243cdf..1ca0091f98 100644 --- a/docs/specs/ISA.md +++ b/docs/specs/ISA.md @@ -6,7 +6,7 @@ This specification describes the overall architecture and default VM extensions - [RV32IM](#rv32im-extension): An extension supporting the 32-bit RISC-V ISA with multiplication. - [Native](#native-extension): An extension supporting native field arithmetic for proof recursion and aggregation. - [Keccak-256](#keccak-extension): An extension implementing the Keccak-256 hash function compatibly with RISC-V memory. -- [SHA2-256](#sha2-256-extension): An extension implementing the SHA2-256 hash function compatibly with RISC-V memory. +- [SHA2](#sha-2-extension): An extension implementing the SHA-256, SHA-512, and SHA-384 hash functions compatibly with RISC-V memory. - [BigInt](#bigint-extension): An extension supporting 256-bit signed and unsigned integer arithmetic, including multiplication. This extension respects the RISC-V memory format. - [Algebra](#algebra-extension): An extension supporting modular arithmetic over arbitrary fields and their complex @@ -31,8 +31,8 @@ OpenVM depends on the following parameters, some of which are fixed and some of | `PC_BITS` | The number of bits in the program counter. | Fixed to 30. | | `DEFAULT_PC_STEP` | The default program counter step size. | Fixed to 4. | | `LIMB_BITS` | The number of bits in a limb for RISC-V memory emulation. | Fixed to 8. | -| `as_offset` | The index of the first writable address space. | Fixed to 1. | -| `as_height` | The base 2 log of the number of writable address spaces supported. | Configurable, must satisfy `as_height <= F::bits() - 2` | +| `ADDR_SPACE_OFFSET` | The index of the first writable address space. | Fixed to 1. | +| `addr_space_height` | The base 2 log of the number of writable address spaces supported. | Configurable, must satisfy `addr_space_height <= F::bits() - 2` | | `pointer_max_bits` | The maximum number of bits in a pointer. | Configurable, must satisfy `pointer_max_bits <= F::bits() - 2` | | `num_public_values` | The number of user public values. | Configurable. If continuation is enabled, it must equal `8` times a power of two(which is nonzero). | @@ -113,12 +113,12 @@ Data memory is a random access memory (RAM) which supports read and write operat cells which represent a single field element indexed by **address space** and **pointer**. The number of supported address spaces and the size of each address space are configurable constants. -- Valid address spaces not used for immediates lie in `[1, 1 + 2^as_height)` for configuration constant `as_height`. +- Valid address spaces not used for immediates lie in `[1, 1 + 2^addr_space_height)` for configuration constant `addr_space_height`. - Valid pointers are field elements that lie in `[0, 2^pointer_max_bits)`, for configuration constant `pointer_max_bits`. When accessing an address out of `[0, 2^pointer_max_bits)`, the VM should panic. - -These configuration constants must satisfy `as_height, pointer_max_bits <= F::bits() - 2`. We use the following notation +- For the register address space (address space `1`), valid pointers lie in `[0, 128)`, corresponding to 32 registers with 4 byte limbs each. +These configuration constants must satisfy `addr_space_height, pointer_max_bits <= F::bits() - 2`. We use the following notation to denote cells in memory: - `[a]_d` denotes the single-cell value at pointer location `a` in address space `d`. This is a single @@ -171,7 +171,7 @@ structures during runtime execution: - `hint_space`: a vector of vectors of field elements used to store hints during runtime execution via [phantom sub-instructions](#phantom-sub-instructions) such as `NativeHintLoad`. The outer `hint_space` vector is append-only, but each internal `hint_space[hint_id]` vector may be mutated, including deletions, by the host. -- `kv_store`: a read-only key-value store for hints. Executors(e.g. `Rv32HintLoadByKey`) can read data from `kv_store` +- `kv_store`: a read-only key-value store for hints. Executors(e.g. `Rv32HintLoadByKey`) can read data from `kv_store` at runtime. `kv_store` is designed for general purposes so both key and value are byte arrays. Encoding of key/value are decided by each executor. Users need to use the corresponding encoding when adding data to `kv_store`. @@ -327,7 +327,7 @@ unsigned integer, and convert to field element. In the instructions below, `[c:4 #### Load/Store For all load/store instructions, we assume the operand `c` is in `[0, 2^16)`, and we fix address spaces `d = 1`. -The address space `e` can be `0`, `1`, or `2` for load instructions, and `2`, `3`, or `4` for store instructions. +The address space `e` is `2` for load instructions, and can be `2`, `3`, or `4` for store instructions. The operand `g` must be a boolean. We let `sign_extend(decompose(c)[0:2], g)` denote the `i32` defined by first taking the unsigned integer encoding of `c` as 16 bits, then sign extending it to 32 bits using the sign bit `g`, and considering the 32 bits as the 2's complement of an `i32`. @@ -454,7 +454,7 @@ reads but not allowed for writes. When using immediates, we interpret `[a]_0` as | STOREW | `a,b,c,4,4` | Set `[[c]_4 + b]_4 = [a]_4`. | | LOADW4 | `a,b,c,4,4` | Set `[a:4]_4 = [[c]_4 + b:4]_4`. | | STOREW4 | `a,b,c,4,4` | Set `[[c]_4 + b:4]_4 = [a:4]_4`. | -| JAL | `a,b,_,4` | Jump to address and link: set `[a]_4 = (pc + DEFAULT_PC_STEP)` and `pc = pc + b`. | +| JAL | `a,b,_,4` | Jump to address and link: set `[a]_4 = (pc + DEFAU````LT````_PC_STEP)` and `pc = pc + b`. | | RANGE_CHECK | `a,b,c,4` | Assert that `[a]_4 = x + y * 2^16` for some `x < 2^b` and `y < 2^c`. `b` must be in [0,16] and `c` must be in [0, 14]. | | BEQ | `a,b,c,d,e` | If `[a]_d == [b]_e`, then set `pc = pc + c`. | | BNE | `a,b,c,d,e` | If `[a]_d != [b]_e`, then set `pc = pc + c`. | @@ -464,7 +464,7 @@ reads but not allowed for writes. When using immediates, we interpret `[a]_0` as #### Field Arithmetic -This instruction set does native field operations. Below, `e,f` may be any address space. +This instruction set does native field operations. Below, `e,f` must be either `0` or `4`. When either `e` or `f` is zero, `[b]_0` and `[c]_0` should be interpreted as the immediates `b` and `c`, respectively. @@ -541,14 +541,16 @@ all memory cells are constrained to be bytes. | -------------- | ----------- | ----------------------------------------------------------------------------------------------------------------- | | KECCAK256_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = keccak256([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Performs memory accesses with block size `4`. | -### SHA2-256 Extension +### SHA-2 Extension -The SHA2-256 extension supports the SHA2-256 hash function. The extension operates on address spaces `1` and `2`, +The SHA-2 extension supports the SHA-256 and SHA-512 hash functions. The extension operates on address spaces `1` and `2`, meaning all memory cells are constrained to be bytes. | Name | Operands | Description | | ----------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | SHA256_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = sha256([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `16` and writes with block size `32`. | +| SHA512_RV32 | `a,b,c,1,2` | `[r32{0}(a):64]_2 = sha512([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `32` and writes with block size `32`. | +| SHA384_RV32 | `a,b,c,1,2` | `[r32{0}(a):64]_2 = sha384([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `32` and writes with block size `32`. Writes 64 bytes to memory: the first 48 are the SHA-384 digest and the last 16 are zeros. | ### BigInt Extension diff --git a/docs/specs/RISCV.md b/docs/specs/RISCV.md index 32d0cc63fa..db1dbfb8f2 100644 --- a/docs/specs/RISCV.md +++ b/docs/specs/RISCV.md @@ -5,7 +5,7 @@ The default VM extensions that support transpilation are: - [RV32IM](#rv32im-extension): An extension supporting the 32-bit RISC-V ISA with multiplication. - [Keccak-256](#keccak-extension): An extension implementing the Keccak-256 hash function compatibly with RISC-V memory. -- [SHA2-256](#sha2-256-extension): An extension implementing the SHA2-256 hash function compatibly with RISC-V memory. +- [SHA2](#sha-2-extension): An extension implementing the SHA-256, SHA-512, and SHA-384 hash functions compatibly with RISC-V memory. - [BigInt](#bigint-extension): An extension supporting 256-bit signed and unsigned integer arithmetic, including multiplication. This extension respects the RISC-V memory format. - [Algebra](#algebra-extension): An extension supporting modular arithmetic over arbitrary fields and their complex field extensions. This extension respects the RISC-V memory format. - [Elliptic curve](#elliptic-curve-extension): An extension for elliptic curve operations over Weierstrass curves, including addition and doubling. This can be used to implement multi-scalar multiplication and ECDSA scalar multiplication. This extension respects the RISC-V memory format. @@ -85,11 +85,13 @@ implementation is here. But we use `funct3 = 111` because the native extension h | ----------- | --- | ----------- | ------ | ------ | ------------------------------------------- | | keccak256 | R | 0001011 | 100 | 0x0 | `[rd:32]_2 = keccak256([rs1..rs1 + rs2]_2)` | -## SHA2-256 Extension +## SHA-2 Extension | RISC-V Inst | FMT | opcode[6:0] | funct3 | funct7 | RISC-V description and notes | | ----------- | --- | ----------- | ------ | ------ | ---------------------------------------- | | sha256 | R | 0001011 | 100 | 0x1 | `[rd:32]_2 = sha256([rs1..rs1 + rs2]_2)` | +| sha512 | R | 0001011 | 100 | 0x2 | `[rd:64]_2 = sha512([rs1..rs1 + rs2]_2)` | +| sha384 | R | 0001011 | 100 | 0x3 | `[rd:64]_2 = sha384([rs1..rs1 + rs2]_2)`. Last 16 bytes will be set to zeros. | ## BigInt Extension diff --git a/docs/specs/circuit.md b/docs/specs/circuit.md index 4238c7c27b..bd34344674 100644 --- a/docs/specs/circuit.md +++ b/docs/specs/circuit.md @@ -104,7 +104,7 @@ The chips that fall into these categories are: | FriReducedOpeningChip | – | – | Case 1. | | NativePoseidon2Chip | – | – | Case 1. | | Rv32HintStoreChip | – | – | Case 1. | -| Sha256VmChip | – | – | Case 1. | +| Sha2VmChip | – | – | Case 1. | The PhantomChip satisfies the condition because `1 < 3`. diff --git a/docs/specs/continuations.md b/docs/specs/continuations.md index 87d83e5f19..233d2f3506 100644 --- a/docs/specs/continuations.md +++ b/docs/specs/continuations.md @@ -270,9 +270,9 @@ multiple accesses. Persistent memory requires three chips: the `PersistentBoundaryChip`, the `MemoryMerkleChip`, and a chip to assist in hashing, which is by default the `Poseidon2Chip`. To simplify the discussion, define constants `C` equal to the number -of field elements in a hash value, `L` where the addresses in an address space are $0..2^L$, `M` and `AS_OFFSET` where -the address spaces are `AS_OFFSET..AS_OFFSET + 2^M`, and `H = M + L - log2(C)`. `H` is the height of the Merkle tree in -the sense that the leaves are at distance `H` from the root. We define the following interactions: +of field elements in a hash value, `L` where the addresses in an address space are $0..2^L$, `M` and `ADDR_SPACE_OFFSET` +where the address spaces are `ADDR_SPACE_OFFSET..ADDR_SPACE_OFFSET + 2^M`, and `H = M + L - log2(C)`. `H` is the height +of the Merkle tree in the sense that the leaves are at distance `H` from the root. We define the following interactions: On the MERKLE_BUS, we have interactions of the form **(expand_direction: {-1, 0, 1}, height: F, labels: (F, F), hash: [F; C])**, where @@ -309,8 +309,8 @@ The `PersistentBoundaryChip` has rows of the form `(expand_direction, address_space, leaf_label, values, hash, timestamp)` and has the following interactions on the MERKLE_BUS: -- Send **(1, 0, (as - AS_OFFSET) \* 2^L, node\*label, hash_initial)** -- Receive **(-1, 0, (as - AS_OFFSET) \* 2^L, node_label, hash_final)** +- Send **(1, 0, (as - ADDR_SPACE_OFFSET) \* 2^L, node\*label, hash_initial)** +- Receive **(-1, 0, (as - ADDR_SPACE_OFFSET) \* 2^L, node_label, hash_final)** It receives `values` from the `MEMORY_BUS` and constrains `hash = compress(values, 0)` via the `POSEIDON2_DIRECT_BUS`. The aggregation program takes a variable number of consecutive segment proofs and consolidates them into a single proof diff --git a/docs/specs/isa-table.md b/docs/specs/isa-table.md index 7b7f374065..fc76462a00 100644 --- a/docs/specs/isa-table.md +++ b/docs/specs/isa-table.md @@ -130,13 +130,15 @@ In the tables below, we provide the mapping between the `LocalOpcode` and `Phant | ------------- | ---------- | ------------- | | Keccak | `Rv32KeccakOpcode::KECCAK256` | KECCAK256_RV32 | -## SHA2-256 Extension +## SHA-2 Extension #### Instructions | VM Extension | `LocalOpcode` | ISA Instruction | | ------------- | ---------- | ------------- | -| SHA2-256 | `Rv32Sha256Opcode::SHA256` | SHA256_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA256` | SHA256_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA512` | SHA512_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA384` | SHA384_RV32 | ## BigInt Extension diff --git a/docs/specs/memory.md b/docs/specs/memory.md index 3d8ea1dc1d..87a7ed1316 100644 --- a/docs/specs/memory.md +++ b/docs/specs/memory.md @@ -163,7 +163,7 @@ Both boundary chips perform, for every subsegment ever existed in our nice set, The following invariants **must** be maintained by the memory architecture: 1. In the MEMORY_BUS, the `timestamp` is always in range `[0, 2^timestamp_max_bits)` where `timestamp_max_bits <= F::bits() - 2` is a configuration constant. -2. In the MEMORY_BUS, the `address_space` is always in range `[0, 1 + 2^as_height)` where `as_height` is a configuration constant satisfying `as_height < F::bits() - 2`. (Our current implementation only supports `as_height` less than the max bits supported by the VariableRangeCheckerBus). +2. In the MEMORY_BUS, the `address_space` is always in range `[0, 1 + 2^addr_space_height)` where `addr_space_height` is a configuration constant satisfying `addr_space_height < F::bits() - 2`. (Our current implementation only supports `addr_space_height` less than the max bits supported by the VariableRangeCheckerBus). 3. In the MEMORY_BUS, the `pointer` is always in range `[0, 2^pointer_max_bits)` where `pointer_max_bits <= F::bits() - 2` is a configuration constant. Invariant 1 is guaranteed by [time goes forward](#time-goes-forward) under the [assumption](./circuit.md#instruction-executors) that the timestamp increase during instruction execution is bounded by the number of AIR interactions. diff --git a/docs/specs/transpiler.md b/docs/specs/transpiler.md index fded65b6d8..1cdefb8b77 100644 --- a/docs/specs/transpiler.md +++ b/docs/specs/transpiler.md @@ -151,11 +151,13 @@ Each VM extension's behavior is specified below. | ----------- | -------------------------------------------------- | | keccak256 | KECCAK256_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | -### SHA2-256 Extension +### SHA-2 Extension | RISC-V Inst | OpenVM Instruction | | ----------- | ----------------------------------------------- | | sha256 | SHA256_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| sha512 | SHA512_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| sha384 | SHA384_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | ### BigInt Extension diff --git a/examples/i256/src/main.rs b/examples/i256/src/main.rs index 8f008f40a0..ec911bc1cd 100644 --- a/examples/i256/src/main.rs +++ b/examples/i256/src/main.rs @@ -1,4 +1,6 @@ #![allow(clippy::needless_range_loop)] +openvm::entry!(main); + use core::array; use alloy_primitives::I256; diff --git a/examples/keccak/src/main.rs b/examples/keccak/src/main.rs index 7b98d36ed1..0d138d5694 100644 --- a/examples/keccak/src/main.rs +++ b/examples/keccak/src/main.rs @@ -1,3 +1,5 @@ +openvm::entry!(main); + // ANCHOR: imports use core::hint::black_box; diff --git a/examples/sha256/Cargo.toml b/examples/sha2/Cargo.toml similarity index 89% rename from examples/sha256/Cargo.toml rename to examples/sha2/Cargo.toml index 0b5a44bc3e..adfc269750 100644 --- a/examples/sha256/Cargo.toml +++ b/examples/sha2/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "sha256-example" +name = "sha2-example" version = "0.0.0" edition = "2021" @@ -7,6 +7,7 @@ edition = "2021" members = [] [dependencies] +# TODO: update rev after PR is merged openvm = { git = "https://github.com/openvm-org/openvm.git", features = [ "std", ] } diff --git a/examples/sha2/openvm.toml b/examples/sha2/openvm.toml new file mode 100644 index 0000000000..35f92b7195 --- /dev/null +++ b/examples/sha2/openvm.toml @@ -0,0 +1,4 @@ +[app_vm_config.rv32i] +[app_vm_config.rv32m] +[app_vm_config.io] +[app_vm_config.sha2] diff --git a/examples/sha2/src/main.rs b/examples/sha2/src/main.rs new file mode 100644 index 0000000000..4fa1539ab6 --- /dev/null +++ b/examples/sha2/src/main.rs @@ -0,0 +1,39 @@ +// ANCHOR: imports +use core::hint::black_box; + +use hex::FromHex; +use openvm_sha2::{sha256, sha384, sha512}; +// ANCHOR_END: imports + +// ANCHOR: main +openvm::entry!(main); + +pub fn main() { + let test_vectors = [( + "", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", + "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", + )]; + for (input, expected_output_sha256, expected_output_sha512, expected_output_sha384) in + test_vectors.iter() + { + let input = Vec::from_hex(input).unwrap(); + let expected_output_sha256 = Vec::from_hex(expected_output_sha256).unwrap(); + let expected_output_sha512 = Vec::from_hex(expected_output_sha512).unwrap(); + let expected_output_sha384 = Vec::from_hex(expected_output_sha384).unwrap(); + let output = sha256(black_box(&input)); + if output != *expected_output_sha256 { + panic!(); + } + let output = sha512(black_box(&input)); + if output != *expected_output_sha512 { + panic!(); + } + let output = sha384(black_box(&input)); + if output != *expected_output_sha384 { + panic!(); + } + } +} +// ANCHOR_END: main diff --git a/examples/sha256/src/main.rs b/examples/sha256/src/main.rs deleted file mode 100644 index a6195390a4..0000000000 --- a/examples/sha256/src/main.rs +++ /dev/null @@ -1,25 +0,0 @@ -// ANCHOR: imports -use core::hint::black_box; - -use hex::FromHex; -use openvm_sha2::sha256; -// ANCHOR_END: imports - -// ANCHOR: main -openvm::entry!(main); - -pub fn main() { - let test_vectors = [( - "", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", - )]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = sha256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} -// ANCHOR_END: main diff --git a/examples/u256/src/main.rs b/examples/u256/src/main.rs index 75b80afd3d..05319a2a17 100644 --- a/examples/u256/src/main.rs +++ b/examples/u256/src/main.rs @@ -1,4 +1,6 @@ #![allow(clippy::needless_range_loop)] +openvm::entry!(main); + use core::array; use openvm_ruint::aliases::U256; diff --git a/extensions/algebra/circuit/Cargo.toml b/extensions/algebra/circuit/Cargo.toml index 258bff450b..095128ee7b 100644 --- a/extensions/algebra/circuit/Cargo.toml +++ b/extensions/algebra/circuit/Cargo.toml @@ -29,7 +29,6 @@ strum = { workspace = true } derive-new = { workspace = true } serde.workspace = true serde_with = { workspace = true } -serde-big-array = { workspace = true } eyre = { workspace = true } [dev-dependencies] @@ -38,3 +37,4 @@ openvm-mod-circuit-builder = { workspace = true, features = ["test-utils"] } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } openvm-pairing-guest = { workspace = true, features = ["halo2curves"] } +test-case = {workspace = true} diff --git a/extensions/algebra/circuit/src/config.rs b/extensions/algebra/circuit/src/config.rs index 5b43163b77..5a641c1951 100644 --- a/extensions/algebra/circuit/src/config.rs +++ b/extensions/algebra/circuit/src/config.rs @@ -5,7 +5,10 @@ use openvm_rv32im_circuit::*; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; -use super::*; +use crate::{ + Fp2Extension, Fp2ExtensionExecutor, Fp2ExtensionPeriphery, ModularExtension, + ModularExtensionExecutor, ModularExtensionPeriphery, +}; #[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] pub struct Rv32ModularConfig { diff --git a/extensions/algebra/circuit/src/fp2_chip/addsub.rs b/extensions/algebra/circuit/src/fp2_chip/addsub.rs index 4eca1ad102..905acdc580 100644 --- a/extensions/algebra/circuit/src/fp2_chip/addsub.rs +++ b/extensions/algebra/circuit/src/fp2_chip/addsub.rs @@ -1,63 +1,26 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Fp2Opcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; +use super::{Fp2Air, Fp2Chip, Fp2Step}; use crate::Fp2; -// Input: Fp2 * 2 -// Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp2AddSubChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp2AddSubChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let (expr, is_add_flag, is_sub_flag) = fp2_addsub_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Fp2Opcode::ADD as usize, - Fp2Opcode::SUB as usize, - Fp2Opcode::SETUP_ADDSUB as usize, - ], - vec![is_add_flag, is_sub_flag], - range_checker, - "Fp2AddSub", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - pub fn fp2_addsub_expr( config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus, @@ -85,13 +48,79 @@ pub fn fp2_addsub_expr( ) } +// Input: Fp2 * 2 +// Output: Fp2 +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] +pub struct Fp2AddSubChip( + pub Fp2Chip, +); + +impl + Fp2AddSubChip +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker: SharedVariableRangeCheckerChip, + ) -> Self { + let (expr, is_add_flag, is_sub_flag) = fp2_addsub_expr(config, range_checker.bus()); + + let local_opcode_idx = vec![ + Fp2Opcode::ADD as usize, + Fp2Opcode::SUB as usize, + Fp2Opcode::SETUP_ADDSUB as usize, + ]; + let opcode_flag_idx = vec![is_add_flag, is_sub_flag]; + let air = Fp2Air::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new( + expr.clone(), + offset, + local_opcode_idx.clone(), + opcode_flag_idx.clone(), + ), + ); + + let step = Fp2Step::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + range_checker, + "Fp2AddSub", + false, + ); + Self(Fp2Chip::new(air, step, mem_helper)) + } + + pub fn expr(&self) -> &FieldExpr { + &self.0.step.0.expr + } +} + #[cfg(test)] mod tests { use halo2curves_axiom::{bn256::Fq2, ff::Field}; use itertools::Itertools; + use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; + use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + InsExecutorE1, + }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; @@ -101,52 +130,30 @@ mod tests { ExprBuilderConfig, }; use openvm_pairing_guest::bn254::BN254_MODULUS; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; + use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; + use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use super::Fp2AddSubChip; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; + const MAX_INS_CAPACITY: usize = 128; + const OFFSET: usize = Fp2Opcode::CLASS_OFFSET; type F = BabyBear; - #[test] - fn test_fp2_addsub() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let modulus = BN254_MODULUS.clone(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = Fp2AddSubChip::new( - adapter, - config, - Fp2Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(42); + fn set_and_execute_rand( + tester: &mut VmChipTestBuilder, + chip: &mut Fp2AddSubChip, + modulus: &BigUint, + ) { + let mut rng = create_seeded_rng(); let x = Fq2::random(&mut rng); let y = Fq2::random(&mut rng); let inputs = [x.c0, x.c1, y.c0, y.c1].map(bn254_fq_to_biguint); let expected_sum = bn254_fq2_to_biguint_vec(x + y); let r_sum = chip - .0 - .core .expr() .execute_with_output(inputs.to_vec(), vec![true, false]); assert_eq!(r_sum.len(), 2); @@ -155,8 +162,6 @@ mod tests { let expected_sub = bn254_fq2_to_biguint_vec(x - y); let r_sub = chip - .0 - .core .expr() .execute_with_output(inputs.to_vec(), vec![false, true]); assert_eq!(r_sub.len(), 2); @@ -177,30 +182,57 @@ mod tests { .map(BabyBear::from_canonical_u32) }) .collect_vec(); - let modulus = - biguint_to_limbs::(modulus, LIMB_BITS).map(BabyBear::from_canonical_u32); + let modulus = biguint_to_limbs::(modulus.clone(), LIMB_BITS) + .map(BabyBear::from_canonical_u32); let zero = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( - &mut tester, + tester, vec![modulus, zero], vec![zero; 2], - chip.0.core.air.offset + Fp2Opcode::SETUP_ADDSUB as usize, + OFFSET + Fp2Opcode::SETUP_ADDSUB as usize, ); let instruction1 = rv32_write_heap_default( - &mut tester, + tester, x_limbs.clone(), y_limbs.clone(), - chip.0.core.air.offset + Fp2Opcode::ADD as usize, + OFFSET + Fp2Opcode::ADD as usize, ); - let instruction2 = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.0.core.air.offset + Fp2Opcode::SUB as usize, + let instruction2 = + rv32_write_heap_default(tester, x_limbs, y_limbs, OFFSET + Fp2Opcode::SUB as usize); + + tester.execute(chip, &setup_instruction); + tester.execute(chip, &instruction1); + tester.execute(chip, &instruction2); + } + + #[test] + fn test_fp2_addsub() { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let modulus = BN254_MODULUS.clone(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Fp2AddSubChip::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config, + OFFSET, + bitwise_chip.clone(), + tester.range_checker(), ); - tester.execute(&mut chip, &setup_instruction); - tester.execute(&mut chip, &instruction1); - tester.execute(&mut chip, &instruction2); + chip.set_trace_height(MAX_INS_CAPACITY); + + let num_ops = 10; + for _ in 0..num_ops { + set_and_execute_rand(&mut tester, &mut chip, &modulus); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } diff --git a/extensions/algebra/circuit/src/fp2_chip/mod.rs b/extensions/algebra/circuit/src/fp2_chip/mod.rs index cd316fd70c..fff47acdb8 100644 --- a/extensions/algebra/circuit/src/fp2_chip/mod.rs +++ b/extensions/algebra/circuit/src/fp2_chip/mod.rs @@ -1,5 +1,26 @@ +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; +use openvm_mod_circuit_builder::FieldExpressionCoreAir; +use openvm_rv32_adapters::Rv32VecHeapAdapterAir; + +use crate::FieldExprVecHeapStep; + mod addsub; pub use addsub::*; mod muldiv; pub use muldiv::*; + +pub(crate) type Fp2Air = VmAirWrapper< + Rv32VecHeapAdapterAir<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>, + FieldExpressionCoreAir, +>; + +pub(crate) type Fp2Step = + FieldExprVecHeapStep<2, BLOCKS, BLOCK_SIZE>; + +pub(crate) type Fp2Chip = NewVmChipWrapper< + F, + Fp2Air, + Fp2Step, + MatrixRecordArena, +>; diff --git a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs index 83ef9565f3..c2c4c66291 100644 --- a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs @@ -1,63 +1,26 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Fp2Opcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, SymbolicExpr, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, SymbolicExpr, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; +use super::{Fp2Air, Fp2Chip, Fp2Step}; use crate::Fp2; -// Input: Fp2 * 2 -// Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp2MulDivChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp2MulDivChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let (expr, is_mul_flag, is_div_flag) = fp2_muldiv_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Fp2Opcode::MUL as usize, - Fp2Opcode::DIV as usize, - Fp2Opcode::SETUP_MULDIV as usize, - ], - vec![is_mul_flag, is_div_flag], - range_checker, - "Fp2MulDiv", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - pub fn fp2_muldiv_expr( config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus, @@ -124,13 +87,79 @@ pub fn fp2_muldiv_expr( ) } +// Input: Fp2 * 2 +// Output: Fp2 +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] +pub struct Fp2MulDivChip( + pub Fp2Chip, +); + +impl + Fp2MulDivChip +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker: SharedVariableRangeCheckerChip, + ) -> Self { + let (expr, is_mul_flag, is_div_flag) = fp2_muldiv_expr(config, range_checker.bus()); + + let local_opcode_idx = vec![ + Fp2Opcode::MUL as usize, + Fp2Opcode::DIV as usize, + Fp2Opcode::SETUP_MULDIV as usize, + ]; + let opcode_flag_idx = vec![is_mul_flag, is_div_flag]; + let air = Fp2Air::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new( + expr.clone(), + offset, + local_opcode_idx.clone(), + opcode_flag_idx.clone(), + ), + ); + + let step = Fp2Step::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + range_checker, + "Fp2MulDiv", + false, + ); + Self(Fp2Chip::new(air, step, mem_helper)) + } + + pub fn expr(&self) -> &FieldExpr { + &self.0.step.0.expr + } +} + #[cfg(test)] mod tests { use halo2curves_axiom::{bn256::Fq2, ff::Field}; use itertools::Itertools; + use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; + use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + InsExecutorE1, + }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; @@ -140,57 +169,30 @@ mod tests { ExprBuilderConfig, }; use openvm_pairing_guest::bn254::BN254_MODULUS; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; + use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; + use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; - use super::Fp2MulDivChip; + use crate::fp2_chip::Fp2MulDivChip; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; + const OFFSET: usize = Fp2Opcode::CLASS_OFFSET; + const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; - #[test] - fn test_fp2_muldiv() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let modulus = BN254_MODULUS.clone(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = Fp2MulDivChip::new( - adapter, - config, - Fp2Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!( - chip.0.core.expr().builder.num_variables, - 2, - "Fp2MulDiv should only introduce new z Fp2 variable (2 Fp var)" - ); - - let mut rng = StdRng::seed_from_u64(42); + fn set_and_execute_rand( + tester: &mut VmChipTestBuilder, + chip: &mut Fp2MulDivChip, + modulus: &BigUint, + ) { + let mut rng = create_seeded_rng(); let x = Fq2::random(&mut rng); let y = Fq2::random(&mut rng); let inputs = [x.c0, x.c1, y.c0, y.c1].map(bn254_fq_to_biguint); let expected_mul = bn254_fq2_to_biguint_vec(x * y); let r_mul = chip - .0 - .core .expr() .execute_with_output(inputs.to_vec(), vec![true, false]); assert_eq!(r_mul.len(), 2); @@ -199,8 +201,6 @@ mod tests { let expected_div = bn254_fq2_to_biguint_vec(x * y.invert().unwrap()); let r_div = chip - .0 - .core .expr() .execute_with_output(inputs.to_vec(), vec![false, true]); assert_eq!(r_div.len(), 2); @@ -221,30 +221,63 @@ mod tests { .map(BabyBear::from_canonical_u32) }) .collect_vec(); - let modulus = - biguint_to_limbs::(modulus, LIMB_BITS).map(BabyBear::from_canonical_u32); + let modulus = biguint_to_limbs::(modulus.clone(), LIMB_BITS) + .map(BabyBear::from_canonical_u32); let zero = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( - &mut tester, + tester, vec![modulus, zero], vec![zero; 2], - chip.0.core.air.offset + Fp2Opcode::SETUP_MULDIV as usize, + OFFSET + Fp2Opcode::SETUP_MULDIV as usize, ); let instruction1 = rv32_write_heap_default( - &mut tester, + tester, x_limbs.clone(), y_limbs.clone(), - chip.0.core.air.offset + Fp2Opcode::MUL as usize, + OFFSET + Fp2Opcode::MUL as usize, + ); + let instruction2 = + rv32_write_heap_default(tester, x_limbs, y_limbs, OFFSET + Fp2Opcode::DIV as usize); + tester.execute(chip, &setup_instruction); + tester.execute(chip, &instruction1); + tester.execute(chip, &instruction2); + } + + #[test] + fn test_fp2_muldiv() { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let modulus = BN254_MODULUS.clone(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Fp2MulDivChip::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config, + OFFSET, + bitwise_chip.clone(), + tester.range_checker(), ); - let instruction2 = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.0.core.air.offset + Fp2Opcode::DIV as usize, + chip.set_trace_height(MAX_INS_CAPACITY); + + assert_eq!( + chip.expr().builder.num_variables, + 2, + "Fp2MulDiv should only introduce new z Fp2 variable (2 Fp var)" ); - tester.execute(&mut chip, &setup_instruction); - tester.execute(&mut chip, &instruction1); - tester.execute(&mut chip, &instruction2); + + let num_ops = 10; + for _ in 0..num_ops { + set_and_execute_rand(&mut tester, &mut chip, &modulus); + } + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } diff --git a/extensions/algebra/circuit/src/fp2_extension.rs b/extensions/algebra/circuit/src/fp2_extension.rs index 37968081bd..6c2f39ca8b 100644 --- a/extensions/algebra/circuit/src/fp2_extension.rs +++ b/extensions/algebra/circuit/src/fp2_extension.rs @@ -2,17 +2,18 @@ use derive_more::derive::From; use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::{ - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, + arch::{ + ExecutionBridge, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{LocalOpcode, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; @@ -23,6 +24,8 @@ use crate::{ ModularExtension, }; +// TODO: this should be decided after e2 execution + #[serde_as] #[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] pub struct Fp2Extension { @@ -59,7 +62,9 @@ impl Fp2Extension { } } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, InsExecutorE2, AnyEnum, From, +)] pub enum Fp2ExtensionExecutor { // 32 limbs prime Fp2AddSubRv32_32(Fp2AddSubChip), @@ -90,6 +95,11 @@ impl VmExtension for Fp2Extension { program_bus, memory_bridge, } = builder.system_port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = builder.system_base().range_checker_chip.clone(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -101,9 +111,6 @@ impl VmExtension for Fp2Extension { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; let addsub_opcodes = (Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize); let muldiv_opcodes = (Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize); @@ -123,28 +130,17 @@ impl VmExtension for Fp2Extension { num_limbs: 48, limb_bits: 8, }; - let adapter_chip_32 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); - let adapter_chip_48 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); if bytes <= 32 { let addsub_chip = Fp2AddSubChip::new( - adapter_chip_32.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); inventory.add_executor( Fp2ExtensionExecutor::Fp2AddSubRv32_32(addsub_chip), @@ -153,11 +149,14 @@ impl VmExtension for Fp2Extension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let muldiv_chip = Fp2MulDivChip::new( - adapter_chip_32.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); inventory.add_executor( Fp2ExtensionExecutor::Fp2MulDivRv32_32(muldiv_chip), @@ -167,11 +166,14 @@ impl VmExtension for Fp2Extension { )?; } else if bytes <= 48 { let addsub_chip = Fp2AddSubChip::new( - adapter_chip_48.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); inventory.add_executor( Fp2ExtensionExecutor::Fp2AddSubRv32_48(addsub_chip), @@ -180,11 +182,14 @@ impl VmExtension for Fp2Extension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let muldiv_chip = Fp2MulDivChip::new( - adapter_chip_48.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); inventory.add_executor( Fp2ExtensionExecutor::Fp2MulDivRv32_48(muldiv_chip), diff --git a/extensions/algebra/circuit/src/lib.rs b/extensions/algebra/circuit/src/lib.rs index ffddacc61a..9f7ba737f9 100644 --- a/extensions/algebra/circuit/src/lib.rs +++ b/extensions/algebra/circuit/src/lib.rs @@ -1,3 +1,28 @@ +use std::{ + array::from_fn, + borrow::{Borrow, BorrowMut}, +}; + +use openvm_circuit::{ + arch::{ + execution::ExecuteFunc, + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + DynArray, E2PreCompute, + ExecutionError::InvalidInstruction, + Result, StepExecutorE1, StepExecutorE2, VmSegmentState, + }, + system::memory::POINTER_MAX_BITS, +}; +use openvm_circuit_derive::{TraceFiller, TraceStep}; +use openvm_circuit_primitives::{var_range::SharedVariableRangeCheckerChip, AlignedBytesBorrow}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_mod_circuit_builder::{ + run_field_expression_precomputed, FieldExpr, FieldExpressionStep, +}; +use openvm_rv32_adapters::Rv32VecHeapAdapterStep; +use openvm_stark_backend::p3_field::PrimeField32; + pub mod fp2_chip; pub mod modular_chip; @@ -9,3 +34,254 @@ mod fp2_extension; pub use fp2_extension::*; mod config; pub use config::*; + +#[derive(TraceStep, TraceFiller)] +pub struct FieldExprVecHeapStep< + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +>( + pub FieldExpressionStep< + Rv32VecHeapAdapterStep, + >, +); + +impl + FieldExprVecHeapStep +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + adapter: Rv32VecHeapAdapterStep, + expr: FieldExpr, + offset: usize, + local_opcode_idx: Vec, + opcode_flag_idx: Vec, + range_checker: SharedVariableRangeCheckerChip, + name: &str, + should_finalize: bool, + ) -> Self { + Self(FieldExpressionStep::new( + adapter, + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + range_checker, + name, + should_finalize, + )) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct FieldExpressionPreCompute<'a, const NUM_READS: usize> { + a: u8, + // NUM_READS <= 2 as in Rv32VecHeapAdapter + rs_addrs: [u8; NUM_READS], + expr: &'a FieldExpr, + flag_idx: u8, +} + +impl<'a, const NUM_READS: usize, const BLOCKS: usize, const BLOCK_SIZE: usize> + FieldExprVecHeapStep +{ + fn pre_compute_impl( + &'a self, + pc: u32, + inst: &Instruction, + data: &mut FieldExpressionPreCompute<'a, NUM_READS>, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + // Validate instruction format + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + + let local_opcode = opcode.local_opcode_idx(self.0.offset); + + // Pre-compute flag_idx + let needs_setup = self.0.expr.needs_setup(); + let mut flag_idx = self.0.expr.num_flags() as u8; + if needs_setup { + // Find which opcode this is in our local_opcode_idx list + if let Some(opcode_position) = self + .0 + .local_opcode_idx + .iter() + .position(|&idx| idx == local_opcode) + { + // If this is NOT the last opcode (setup), get the corresponding flag_idx + if opcode_position < self.0.opcode_flag_idx.len() { + flag_idx = self.0.opcode_flag_idx[opcode_position] as u8; + } + } + } + + let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8); + *data = FieldExpressionPreCompute { + a: a as u8, + rs_addrs, + expr: &self.0.expr, + flag_idx, + }; + + Ok(needs_setup) + } +} + +impl + StepExecutorE1 for FieldExprVecHeapStep +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::>() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + let pre_compute: &mut FieldExpressionPreCompute = data.borrow_mut(); + + let needs_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = if needs_setup { + execute_e1_impl::<_, _, NUM_READS, BLOCKS, BLOCK_SIZE, true> + } else { + execute_e1_impl::<_, _, NUM_READS, BLOCKS, BLOCK_SIZE, false> + }; + + Ok(fn_ptr) + } +} + +impl + StepExecutorE2 for FieldExprVecHeapStep +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + std::mem::size_of::>>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute> = + data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let needs_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = if needs_setup { + execute_e2_impl::<_, _, NUM_READS, BLOCKS, BLOCK_SIZE, true> + } else { + execute_e2_impl::<_, _, NUM_READS, BLOCKS, BLOCK_SIZE, false> + }; + + Ok(fn_ptr) + } +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const NEEDS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &FieldExpressionPreCompute = pre_compute.borrow(); + + execute_e12_impl::<_, _, NUM_READS, BLOCKS, BLOCK_SIZE, NEEDS_SETUP>(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const NEEDS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::<_, _, NUM_READS, BLOCKS, BLOCK_SIZE, NEEDS_SETUP>( + &pre_compute.data, + vm_state, + ); +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, + const NEEDS_SETUP: bool, +>( + pre_compute: &FieldExpressionPreCompute, + vm_state: &mut VmSegmentState, +) { + // Read register values + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + // Read memory values + let read_data: [[[u8; BLOCK_SIZE]; BLOCKS]; NUM_READS] = rs_vals.map(|address| { + debug_assert!(address as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + from_fn(|i| vm_state.vm_read(RV32_MEMORY_AS, address + (i * BLOCK_SIZE) as u32)) + }); + let read_data: DynArray = read_data.into(); + + let writes = run_field_expression_precomputed::( + pre_compute.expr, + pre_compute.flag_idx as usize, + &read_data.0, + ); + + let rd_val = u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)); + debug_assert!(rd_val as usize + BLOCK_SIZE * BLOCKS - 1 < (1 << POINTER_MAX_BITS)); + + // Write output data to memory + let data: [[u8; BLOCK_SIZE]; BLOCKS] = writes.into(); + for (i, block) in data.into_iter().enumerate() { + vm_state.vm_write(RV32_MEMORY_AS, rd_val + (i * BLOCK_SIZE) as u32, &block); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} diff --git a/extensions/algebra/circuit/src/modular_chip/addsub.rs b/extensions/algebra/circuit/src/modular_chip/addsub.rs index 34bede150f..67a31d798f 100644 --- a/extensions/algebra/circuit/src/modular_chip/addsub.rs +++ b/extensions/algebra/circuit/src/modular_chip/addsub.rs @@ -1,22 +1,25 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldVariable, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; +use super::{ModularAir, ModularChip, ModularStep}; + pub fn addsub_expr( config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus, @@ -43,39 +46,58 @@ pub fn addsub_expr( ) } -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] pub struct ModularAddSubChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, + pub ModularChip, ); impl ModularAddSubChip { + #[allow(clippy::too_many_arguments)] pub fn new( - adapter: Rv32VecHeapAdapterChip, + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, config: ExprBuilderConfig, offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, ) -> Self { let (expr, is_add_flag, is_sub_flag) = addsub_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( + + let local_opcode_idx = vec![ + Rv32ModularArithmeticOpcode::ADD as usize, + Rv32ModularArithmeticOpcode::SUB as usize, + Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize, + ]; + let opcode_flag_idx = vec![is_add_flag, is_sub_flag]; + let air = ModularAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new( + expr.clone(), + offset, + local_opcode_idx.clone(), + opcode_flag_idx.clone(), + ), + ); + + let step = ModularStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), expr, offset, - vec![ - Rv32ModularArithmeticOpcode::ADD as usize, - Rv32ModularArithmeticOpcode::SUB as usize, - Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize, - ], - vec![is_add_flag, is_sub_flag], + local_opcode_idx, + opcode_flag_idx, range_checker, "ModularAddSub", false, ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) + Self(ModularChip::new(air, step, mem_helper)) } } diff --git a/extensions/algebra/circuit/src/modular_chip/is_eq.rs b/extensions/algebra/circuit/src/modular_chip/is_eq.rs index fe91585466..3ee5ce7f88 100644 --- a/extensions/algebra/circuit/src/modular_chip/is_eq.rs +++ b/extensions/algebra/circuit/src/modular_chip/is_eq.rs @@ -5,32 +5,45 @@ use std::{ use num_bigint::BigUint; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, MinimalInstruction, NewVmChipWrapper, RecordArena, Result, + StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, VmAdapterInterface, VmAirWrapper, + VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory, POINTER_MAX_BITS}, }; +use openvm_circuit_derive::{TraceFiller, TraceStep}; use openvm_circuit_primitives::{ bigint::utils::big_uint_to_limbs, bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, is_equal_array::{IsEqArrayIo, IsEqArraySubAir}, - SubAir, TraceSubRowGenerator, + AlignedBytesBorrow, SubAir, TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32_adapters::{Rv32IsEqualModAdapterAir, Rv32IsEqualModeAdapterStep}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; // Given two numbers b and c, we want to prove that a) b == c or b != c, depending on // result of cmp_result and b) b, c < N for some modulus N that is passed into the AIR // at runtime (i.e. when chip is instantiated). #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct ModularIsEqualCoreCols { pub is_valid: T, pub is_setup: T, @@ -278,155 +291,403 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct ModularIsEqualCoreRecord { - #[serde(with = "BigArray")] - pub b: [T; READ_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; READ_LIMBS], - pub cmp_result: T, - #[serde(with = "BigArray")] - pub eq_marker: [T; READ_LIMBS], - pub b_diff_idx: usize, - pub c_diff_idx: usize, +#[derive(AlignedBytesBorrow, Debug)] +pub struct ModularIsEqualRecord { pub is_setup: bool, + pub b: [u8; READ_LIMBS], + pub c: [u8; READ_LIMBS], } -pub struct ModularIsEqualCoreChip< +#[derive(derive_new::new)] +pub struct ModularIsEqualStep< + A, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize, > { - pub air: ModularIsEqualCoreAir, + adapter: A, + pub modulus_limbs: [u8; READ_LIMBS], + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl - ModularIsEqualCoreChip -{ - pub fn new( - modulus: BigUint, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: ModularIsEqualCoreAir::new(modulus, bitwise_lookup_chip.bus(), offset), - bitwise_lookup_chip, - } - } -} - -impl< - F: PrimeField32, - I: VmAdapterInterface, - const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, - > VmCoreChip for ModularIsEqualCoreChip +impl + TraceStep for ModularIsEqualStep where - I::Reads: Into<[[F; READ_LIMBS]; 2]>, - I::Writes: From<[[F; WRITE_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; READ_LIMBS]; 2]>, + WriteData: From<[u8; WRITE_LIMBS]>, + >, { - type Record = ModularIsEqualCoreRecord; - type Air = ModularIsEqualCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut ModularIsEqualRecord); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let data: [[F; READ_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (b_cmp, b_diff_idx) = run_unsigned_less_than::(&b, &self.air.modulus_limbs); - let (c_cmp, c_diff_idx) = run_unsigned_less_than::(&c, &self.air.modulus_limbs); - let is_setup = instruction.opcode.local_opcode_idx(self.air.offset) + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let Instruction { opcode, .. } = instruction; + + let local_opcode = + Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + matches!( + local_opcode, + Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ + ); + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + [core_record.b, core_record.c] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.is_setup = instruction.opcode.local_opcode_idx(self.offset) == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize; - if !is_setup { - assert!(b_cmp, "{:?} >= {:?}", b, self.air.modulus_limbs); - } - assert!(c_cmp, "{:?} >= {:?}", c, self.air.modulus_limbs); - if !is_setup { - self.bitwise_lookup_chip.request_range( - self.air.modulus_limbs[b_diff_idx] - b[b_diff_idx] - 1, - self.air.modulus_limbs[c_diff_idx] - c[c_diff_idx] - 1, - ); - } + let mut write_data = [0u8; WRITE_LIMBS]; + write_data[0] = (core_record.b == core_record.c) as u8; - let mut eq_marker = [F::ZERO; READ_LIMBS]; - let mut cmp_result = F::ZERO; - self.air - .subair - .generate_subrow((&data[0], &data[1]), (&mut eq_marker, &mut cmp_result)); - - let mut writes = [F::ZERO; WRITE_LIMBS]; - writes[0] = cmp_result; - - let output = AdapterRuntimeContext::without_pc([writes]); - let record = ModularIsEqualCoreRecord { - is_setup, - b: data[0], - c: data[1], - cmp_result, - eq_marker, - b_diff_idx, - c_diff_idx, - }; + self.adapter.write( + state.memory, + instruction, + write_data.into(), + &mut adapter_record, + ); - Ok((output, record)) + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } fn get_opcode_name(&self, opcode: usize) -> String { format!( "{:?}", - Rv32ModularArithmeticOpcode::from_usize(opcode - self.air.offset) + Rv32ModularArithmeticOpcode::from_usize(opcode - self.offset) ) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut(); - row_slice.is_valid = F::ONE; - row_slice.is_setup = F::from_bool(record.is_setup); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.cmp_result = record.cmp_result; - - row_slice.eq_marker = record.eq_marker; +impl + TraceFiller for ModularIsEqualStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = row_slice.split_at_mut(A::WIDTH); + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &ModularIsEqualRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let cols: &mut ModularIsEqualCoreCols = core_row.borrow_mut(); + let (b_cmp, b_diff_idx) = + run_unsigned_less_than::(&record.b, &self.modulus_limbs); + let (c_cmp, c_diff_idx) = + run_unsigned_less_than::(&record.c, &self.modulus_limbs); if !record.is_setup { - row_slice.b_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.b_diff_idx]) - - record.b[record.b_diff_idx]; + assert!(b_cmp, "{:?} >= {:?}", record.b, self.modulus_limbs); } - row_slice.c_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.c_diff_idx]) - - record.c[record.c_diff_idx]; - row_slice.c_lt_mark = if record.b_diff_idx == record.c_diff_idx { + assert!(c_cmp, "{:?} >= {:?}", record.c, self.modulus_limbs); + + // Writing in reverse order + cols.c_lt_mark = if b_diff_idx == c_diff_idx { F::ONE } else { - F::from_canonical_u8(2) + F::TWO }; - row_slice.lt_marker = from_fn(|i| { - if i == record.b_diff_idx { + + cols.c_lt_diff = + F::from_canonical_u8(self.modulus_limbs[c_diff_idx] - record.c[c_diff_idx]); + if !record.is_setup { + cols.b_lt_diff = + F::from_canonical_u8(self.modulus_limbs[b_diff_idx] - record.b[b_diff_idx]); + self.bitwise_lookup_chip.request_range( + (self.modulus_limbs[b_diff_idx] - record.b[b_diff_idx] - 1) as u32, + (self.modulus_limbs[c_diff_idx] - record.c[c_diff_idx] - 1) as u32, + ); + } else { + cols.b_lt_diff = F::ZERO; + } + + cols.lt_marker = from_fn(|i| { + if i == b_diff_idx { F::ONE - } else if i == record.c_diff_idx { - row_slice.c_lt_mark + } else if i == c_diff_idx { + cols.c_lt_mark } else { F::ZERO } }); + + cols.c = record.c.map(F::from_canonical_u8); + cols.b = record.b.map(F::from_canonical_u8); + let sub_air = IsEqArraySubAir::; + sub_air.generate_subrow( + (&cols.b, &cols.c), + (&mut cols.eq_marker, &mut cols.cmp_result), + ); + + cols.is_setup = F::from_bool(record.is_setup); + cols.is_valid = F::ONE; } +} - fn air(&self) -> &Self::Air { - &self.air +#[derive(TraceStep, TraceFiller)] +pub struct VmModularIsEqualStep< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, +>( + ModularIsEqualStep< + Rv32IsEqualModeAdapterStep<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>, + TOTAL_LIMBS, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >, +); + +impl + VmModularIsEqualStep +{ + pub fn new( + adapter: Rv32IsEqualModeAdapterStep<2, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE>, + modulus_limbs: [u8; TOTAL_READ_SIZE], + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + ) -> Self { + Self(ModularIsEqualStep::new( + adapter, + modulus_limbs, + offset, + bitwise_lookup_chip, + )) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct ModularIsEqualPreCompute { + a: u8, + rs_addrs: [u8; 2], + modulus_limbs: [u8; READ_LIMBS], +} + +impl + VmModularIsEqualStep +{ + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut ModularIsEqualPreCompute, + ) -> Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + let local_opcode = + Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.0.offset)); + + // Validate instruction format + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + if d != RV32_REGISTER_AS || e != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + + if !matches!( + local_opcode, + Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ + ) { + return Err(InvalidInstruction(pc)); + } + + let rs_addrs = from_fn(|i| if i == 0 { b } else { c } as u8); + *data = ModularIsEqualPreCompute { + a: a as u8, + rs_addrs, + modulus_limbs: self.0.modulus_limbs, + }; + + let is_setup = local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ; + + Ok(is_setup) + } +} + +impl + StepExecutorE1 for VmModularIsEqualStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::mem::size_of::>() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut ModularIsEqualPreCompute = data.borrow_mut(); + + let is_setup = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = if is_setup { + execute_e1_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true> + } else { + execute_e1_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false> + }; + + Ok(fn_ptr) + } +} + +impl + StepExecutorE2 for VmModularIsEqualStep +where + F: PrimeField32, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + std::mem::size_of::>>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut E2PreCompute> = + data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let is_setup = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = if is_setup { + execute_e2_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, true> + } else { + execute_e2_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, false> + }; + + Ok(fn_ptr) } } +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_READ_SIZE: usize, + const IS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &ModularIsEqualPreCompute = pre_compute.borrow(); + + execute_e12_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, IS_SETUP>( + pre_compute, + vm_state, + ); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_READ_SIZE: usize, + const IS_SETUP: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = + pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::<_, _, NUM_LANES, LANE_SIZE, TOTAL_READ_SIZE, IS_SETUP>( + &pre_compute.data, + vm_state, + ); +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_READ_SIZE: usize, + const IS_SETUP: bool, +>( + pre_compute: &ModularIsEqualPreCompute, + vm_state: &mut VmSegmentState, +) { + // Read register values + let rs_vals = pre_compute + .rs_addrs + .map(|addr| u32::from_le_bytes(vm_state.vm_read(RV32_REGISTER_AS, addr as u32))); + + // Read memory values + let [b, c]: [[u8; TOTAL_READ_SIZE]; 2] = rs_vals.map(|address| { + debug_assert!(address as usize + TOTAL_READ_SIZE - 1 < (1 << POINTER_MAX_BITS)); + from_fn::<_, NUM_LANES, _>(|i| { + vm_state.vm_read::<_, LANE_SIZE>(RV32_MEMORY_AS, address + (i * LANE_SIZE) as u32) + }) + .concat() + .try_into() + .unwrap() + }); + + if !IS_SETUP { + let (b_cmp, _) = run_unsigned_less_than::(&b, &pre_compute.modulus_limbs); + debug_assert!(b_cmp, "{:?} >= {:?}", b, pre_compute.modulus_limbs); + } + + let (c_cmp, _) = run_unsigned_less_than::(&c, &pre_compute.modulus_limbs); + debug_assert!(c_cmp, "{:?} >= {:?}", c, pre_compute.modulus_limbs); + + // Compute result + let mut write_data = [0u8; RV32_REGISTER_NUM_LIMBS]; + write_data[0] = (b == c) as u8; + + // Write result to register + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + // Returns (cmp_result, diff_idx) +#[inline(always)] pub(super) fn run_unsigned_less_than( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize) { for i in (0..NUM_LIMBS).rev() { if x[i] != y[i] { @@ -435,3 +696,25 @@ pub(super) fn run_unsigned_less_than( } (false, NUM_LIMBS) } + +// Must have TOTAL_LIMBS = NUM_LANES * LANE_SIZE +pub type ModularIsEqualAir< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, +> = VmAirWrapper< + Rv32IsEqualModAdapterAir<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>, + ModularIsEqualCoreAir, +>; + +pub type ModularIsEqualChip< + F, + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, +> = NewVmChipWrapper< + F, + ModularIsEqualAir, + VmModularIsEqualStep, + MatrixRecordArena, +>; diff --git a/extensions/algebra/circuit/src/modular_chip/mod.rs b/extensions/algebra/circuit/src/modular_chip/mod.rs index 2dd9838206..4a46eba277 100644 --- a/extensions/algebra/circuit/src/modular_chip/mod.rs +++ b/extensions/algebra/circuit/src/modular_chip/mod.rs @@ -1,24 +1,39 @@ -mod addsub; -pub use addsub::*; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; +use openvm_mod_circuit_builder::FieldExpressionCoreAir; +use openvm_rv32_adapters::Rv32VecHeapAdapterAir; + +use crate::FieldExprVecHeapStep; + mod is_eq; pub use is_eq::*; +mod addsub; +pub use addsub::*; mod muldiv; pub use muldiv::*; -use openvm_circuit::arch::VmChipWrapper; -use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use openvm_rv32_adapters::Rv32IsEqualModAdapterChip; #[cfg(test)] mod tests; -// Must have TOTAL_LIMBS = NUM_LANES * LANE_SIZE -pub type ModularIsEqualChip< - F, - const NUM_LANES: usize, - const LANE_SIZE: usize, - const TOTAL_LIMBS: usize, -> = VmChipWrapper< +pub(crate) type ModularAir = VmAirWrapper< + Rv32VecHeapAdapterAir<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>, + FieldExpressionCoreAir, +>; + +pub(crate) type ModularStep = + FieldExprVecHeapStep<2, BLOCKS, BLOCK_SIZE>; + +pub(crate) type ModularChip = NewVmChipWrapper< F, - Rv32IsEqualModAdapterChip, - ModularIsEqualCoreChip, + ModularAir, + ModularStep, + MatrixRecordArena, >; + +#[cfg(test)] +pub(crate) type ModularDenseChip = + NewVmChipWrapper< + F, + ModularAir, + ModularStep, + openvm_circuit::arch::DenseRecordArena, + >; diff --git a/extensions/algebra/circuit/src/modular_chip/muldiv.rs b/extensions/algebra/circuit/src/modular_chip/muldiv.rs index 30f063e2b1..f67be1de2c 100644 --- a/extensions/algebra/circuit/src/modular_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/modular_chip/muldiv.rs @@ -1,22 +1,25 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldVariable, SymbolicExpr, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; +use super::{ModularAir, ModularChip, ModularStep}; + pub fn muldiv_expr( config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus, @@ -58,39 +61,58 @@ pub fn muldiv_expr( ) } -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] pub struct ModularMulDivChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, + pub ModularChip, ); impl ModularMulDivChip { + #[allow(clippy::too_many_arguments)] pub fn new( - adapter: Rv32VecHeapAdapterChip, + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, config: ExprBuilderConfig, offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, ) -> Self { let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( + + let local_opcode_idx = vec![ + Rv32ModularArithmeticOpcode::MUL as usize, + Rv32ModularArithmeticOpcode::DIV as usize, + Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize, + ]; + let opcode_flag_idx = vec![is_mul_flag, is_div_flag]; + let air = ModularAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new( + expr.clone(), + offset, + local_opcode_idx.clone(), + opcode_flag_idx.clone(), + ), + ); + + let step = ModularStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), expr, offset, - vec![ - Rv32ModularArithmeticOpcode::MUL as usize, - Rv32ModularArithmeticOpcode::DIV as usize, - Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize, - ], - vec![is_mul_flag, is_div_flag], + local_opcode_idx, + opcode_flag_idx, range_checker, "ModularMulDiv", false, ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) + Self(ModularChip::new(air, step, mem_helper)) } } diff --git a/extensions/algebra/circuit/src/modular_chip/tests.rs b/extensions/algebra/circuit/src/modular_chip/tests.rs index 1ad3310f76..6c511c6f9f 100644 --- a/extensions/algebra/circuit/src/modular_chip/tests.rs +++ b/extensions/algebra/circuit/src/modular_chip/tests.rs @@ -6,7 +6,6 @@ use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; use openvm_circuit::arch::{ instructions::LocalOpcode, testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - AdapterRuntimeContext, Result, VmAdapterInterface, VmChipWrapper, VmCoreChip, }; use openvm_circuit_primitives::{ bigint::utils::{big_uint_to_limbs, secp256k1_coord_prime, secp256k1_scalar_prime}, @@ -18,105 +17,66 @@ use openvm_mod_circuit_builder::{ ExprBuilderConfig, }; use openvm_pairing_guest::bls12_381::BLS12_381_MODULUS; -use openvm_rv32_adapters::{ - rv32_write_heap_default, write_ptr_reg, Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip, -}; +use openvm_rv32_adapters::{rv32_write_heap_default, write_ptr_reg}; use openvm_rv32im_circuit::adapters::RV32_REGISTER_NUM_LIMBS; -use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; +use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; use super::{ - ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreAir, ModularIsEqualCoreChip, - ModularIsEqualCoreCols, ModularIsEqualCoreRecord, ModularMulDivChip, + ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreAir, ModularIsEqualCoreCols, + ModularMulDivChip, }; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 32; +const _BLOCK_SIZE: usize = 32; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; -const ADD_LOCAL: usize = Rv32ModularArithmeticOpcode::ADD as usize; -const MUL_LOCAL: usize = Rv32ModularArithmeticOpcode::MUL as usize; +#[cfg(test)] +mod addsubtests { + use openvm_circuit::arch::InstructionExecutor; + use openvm_mod_circuit_builder::FieldExpressionCoreRecordMut; + use openvm_rv32_adapters::Rv32VecHeapAdapterRecord; + use test_case::test_case; -#[test] -fn test_coord_addsub() { - let opcode_offset = 0; - let modulus = secp256k1_coord_prime(); - test_addsub(opcode_offset, modulus); -} + use super::*; + use crate::modular_chip::ModularDenseChip; -#[test] -fn test_scalar_addsub() { - let opcode_offset = 4; - let modulus = secp256k1_scalar_prime(); - test_addsub(opcode_offset, modulus); -} + const ADD_LOCAL: usize = Rv32ModularArithmeticOpcode::ADD as usize; + + fn set_and_execute_addsub>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + modulus: &BigUint, + is_setup: bool, + offset: usize, + ) { + let mut rng = create_seeded_rng(); + + let (a, b, op) = if is_setup { + (modulus.clone(), BigUint::zero(), ADD_LOCAL + 2) + } else { + let a_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut a = BigUint::new(a_digits.clone()); + let b_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut b = BigUint::new(b_digits.clone()); + + let op = rng.gen_range(0..2) + ADD_LOCAL; // 0 for add, 1 for sub + a %= modulus; + b %= modulus; + (a, b, op) + }; -fn test_addsub(opcode_offset: usize, modulus: BigUint) { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - // doing 1xNUM_LIMBS reads and writes - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = ModularAddSubChip::new( - adapter, - config, - Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - let mut rng = create_seeded_rng(); - let num_tests = 50; - let mut all_ops = vec![ADD_LOCAL + 2]; // setup - let mut all_a = vec![modulus.clone()]; - let mut all_b = vec![BigUint::zero()]; - - // First loop: generate all random test data. - for _ in 0..num_tests { - let a_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut a = BigUint::new(a_digits.clone()); - let b_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut b = BigUint::new(b_digits.clone()); - - let op = rng.gen_range(0..2) + ADD_LOCAL; // 0 for add, 1 for sub - a %= &modulus; - b %= &modulus; - - all_ops.push(op); - all_a.push(a); - all_b.push(b); - } - // Second loop: actually run the tests. - for i in 0..=num_tests { - let op = all_ops[i]; - let a = all_a[i].clone(); - let b = all_b[i].clone(); - if i > 0 { - // if not setup - assert!(a < modulus); - assert!(b < modulus); - } let expected_answer = match op - ADD_LOCAL { - 0 => (&a + &b) % &modulus, - 1 => (&a + &modulus - &b) % &modulus, - 2 => a.clone() % &modulus, + 0 => (&a + &b) % modulus, + 1 => (&a + modulus - &b) % modulus, + 2 => a.clone() % modulus, _ => panic!(), }; @@ -133,11 +93,11 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { let data_as = 2; let address1 = 0u32; let address2 = 128u32; - let address3 = (1 << 28) + 1234; // a large memory address to test heap adapter + let address3 = (1 << 28) + 1228; // a large memory address to test heap adapter - write_ptr_reg(&mut tester, ptr_as, addr_ptr1, address1); - write_ptr_reg(&mut tester, ptr_as, addr_ptr2, address2); - write_ptr_reg(&mut tester, ptr_as, addr_ptr3, address3); + write_ptr_reg(tester, ptr_as, addr_ptr1, address1); + write_ptr_reg(tester, ptr_as, addr_ptr2, address2); + write_ptr_reg(tester, ptr_as, addr_ptr3, address3); let a_limbs: [BabyBear; NUM_LIMBS] = biguint_to_limbs(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); @@ -147,105 +107,159 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { tester.write(data_as, address2 as usize, b_limbs); let instruction = Instruction::from_isize( - VmOpcode::from_usize(chip.0.core.air.offset + op), + VmOpcode::from_usize(offset + op), addr_ptr3 as isize, addr_ptr1 as isize, addr_ptr2 as isize, ptr_as as isize, data_as as isize, ); - tester.execute(&mut chip, &instruction); + tester.execute(chip, &instruction); let expected_limbs = biguint_to_limbs::(expected_answer, LIMB_BITS); - for (i, expected) in expected_limbs.into_iter().enumerate() { - let address = address3 as usize + i; - let read_val = tester.read_cell(data_as, address); - assert_eq!(BabyBear::from_canonical_u32(expected), read_val); - } + let read_vals = tester.read::(data_as, address3 as usize); + assert_eq!(read_vals, expected_limbs.map(F::from_canonical_u32)); } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + #[test_case(0, secp256k1_coord_prime(), 50)] + #[test_case(4, secp256k1_scalar_prime(), 50)] + fn test_addsub(opcode_offset: usize, modulus: BigUint, num_ops: usize) { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let offset = Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); -#[test] -fn test_coord_muldiv() { - let opcode_offset = 0; - let modulus = secp256k1_coord_prime(); - test_muldiv(opcode_offset, modulus); -} + // doing 1xNUM_LIMBS reads and writes + let mut chip = ModularAddSubChip::::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config, + offset, + bitwise_chip.clone(), + tester.range_checker(), + ); + chip.0.set_trace_buffer_height(MAX_INS_CAPACITY); -#[test] -fn test_scalar_muldiv() { - let opcode_offset = 4; - let modulus = secp256k1_scalar_prime(); - test_muldiv(opcode_offset, modulus); -} + for i in 0..num_ops { + set_and_execute_addsub(&mut tester, &mut chip, &modulus, i == 0, offset); + } -fn test_muldiv(opcode_offset: usize, modulus: BigUint) { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - // doing 1xNUM_LIMBS reads and writes - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = ModularMulDivChip::new( - adapter, - config, - Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - let mut rng = create_seeded_rng(); - let num_tests = 50; - let mut all_ops = vec![MUL_LOCAL + 2]; - let mut all_a = vec![modulus.clone()]; - let mut all_b = vec![BigUint::zero()]; - - // First loop: generate all random test data. - for _ in 0..num_tests { - let a_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut a = BigUint::new(a_digits.clone()); - let b_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut b = BigUint::new(b_digits.clone()); - - // let op = rng.gen_range(2..4); // 2 for mul, 3 for div - let op = MUL_LOCAL; - a %= &modulus; - b %= &modulus; - - all_ops.push(op); - all_a.push(a); - all_b.push(b); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } - // Second loop: actually run the tests. - for i in 0..=num_tests { - let op = all_ops[i]; - let a = all_a[i].clone(); - let b = all_b[i].clone(); - if i > 0 { - // if not setup - assert!(a < modulus); - assert!(b < modulus); + + #[test_case(0, secp256k1_coord_prime(), 50)] + #[test_case(4, secp256k1_scalar_prime(), 50)] + fn dense_record_arena_test(opcode_offset: usize, modulus: BigUint, num_ops: usize) { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let offset = Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset; + + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let mut sparse_chip = ModularAddSubChip::::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config.clone(), + offset, + bitwise_chip.clone(), + tester.range_checker(), + ); + sparse_chip.0.set_trace_buffer_height(MAX_INS_CAPACITY); + + { + // Using a trick to create a dense chip using the air and step of the sparse chip + // doing 1xNUM_LIMBS reads and writes + let tmp_chip = ModularAddSubChip::::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config, + offset, + bitwise_chip.clone(), + tester.range_checker(), + ); + + let mut dense_chip = + ModularDenseChip::new(tmp_chip.0.air, tmp_chip.0.step, tester.memory_helper()); + dense_chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + for i in 0..num_ops { + set_and_execute_addsub(&mut tester, &mut dense_chip, &modulus, i == 0, offset); + } + + type Record<'a> = ( + &'a mut Rv32VecHeapAdapterRecord<2, 1, 1, NUM_LIMBS, NUM_LIMBS>, + FieldExpressionCoreRecordMut<'a>, + ); + let mut record_interpreter = dense_chip.arena.get_record_seeker::(); + record_interpreter.transfer_to_matrix_arena( + &mut sparse_chip.0.arena, + dense_chip.step.0.get_record_layout::(), + ); } + + let tester = tester + .build() + .load(sparse_chip) + .load(bitwise_chip) + .finalize(); + tester.simple_test().expect("Verification failed"); + } +} + +#[cfg(test)] +mod muldivtests { + use test_case::test_case; + + use super::*; + + const MUL_LOCAL: usize = Rv32ModularArithmeticOpcode::MUL as usize; + + fn set_and_execute_muldiv( + tester: &mut VmChipTestBuilder, + chip: &mut ModularMulDivChip, + modulus: &BigUint, + is_setup: bool, + ) { + let mut rng = create_seeded_rng(); + + let (a, b, op) = if is_setup { + (modulus.clone(), BigUint::zero(), MUL_LOCAL + 2) + } else { + let a_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut a = BigUint::new(a_digits.clone()); + let b_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut b = BigUint::new(b_digits.clone()); + + let op = rng.gen_range(0..2) + MUL_LOCAL; // 0 for add, 1 for sub + a %= modulus; + b %= modulus; + (a, b, op) + }; + let expected_answer = match op - MUL_LOCAL { - 0 => (&a * &b) % &modulus, - 1 => (&a * b.modinv(&modulus).unwrap()) % &modulus, - 2 => a.clone() % &modulus, + 0 => (&a * &b) % modulus, + 1 => (&a * b.modinv(modulus).unwrap()) % modulus, + 2 => a.clone() % modulus, _ => panic!(), }; @@ -264,307 +278,370 @@ fn test_muldiv(opcode_offset: usize, modulus: BigUint) { let address2 = 128; let address3 = 256; - write_ptr_reg(&mut tester, ptr_as, addr_ptr1, address1); - write_ptr_reg(&mut tester, ptr_as, addr_ptr2, address2); - write_ptr_reg(&mut tester, ptr_as, addr_ptr3, address3); + write_ptr_reg(tester, ptr_as, addr_ptr1, address1); + write_ptr_reg(tester, ptr_as, addr_ptr2, address2); + write_ptr_reg(tester, ptr_as, addr_ptr3, address3); - let a_limbs: [BabyBear; NUM_LIMBS] = - biguint_to_limbs(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let a_limbs: [F; NUM_LIMBS] = + biguint_to_limbs(a.clone(), LIMB_BITS).map(F::from_canonical_u32); tester.write(data_as, address1 as usize, a_limbs); - let b_limbs: [BabyBear; NUM_LIMBS] = - biguint_to_limbs(b.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let b_limbs: [F; NUM_LIMBS] = + biguint_to_limbs(b.clone(), LIMB_BITS).map(F::from_canonical_u32); tester.write(data_as, address2 as usize, b_limbs); let instruction = Instruction::from_isize( - VmOpcode::from_usize(chip.0.core.air.offset + op), + VmOpcode::from_usize(chip.0.step.0.offset + op), addr_ptr3 as isize, addr_ptr1 as isize, addr_ptr2 as isize, ptr_as as isize, data_as as isize, ); - tester.execute(&mut chip, &instruction); + tester.execute(chip, &instruction); let expected_limbs = biguint_to_limbs::(expected_answer, LIMB_BITS); - for (i, expected) in expected_limbs.into_iter().enumerate() { - let address = address3 as usize + i; - let read_val = tester.read_cell(data_as, address); - assert_eq!(BabyBear::from_canonical_u32(expected), read_val); - } + let read_vals = tester.read::(data_as, address3 as usize); + assert_eq!(read_vals, expected_limbs.map(F::from_canonical_u32)); } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - - tester.simple_test().expect("Verification failed"); -} -fn test_is_equal( - opcode_offset: usize, - modulus: BigUint, - num_tests: usize, -) { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ModularIsEqualChip::::new( - Rv32IsEqualModAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), + #[test_case(0, secp256k1_coord_prime(), 50)] + #[test_case(4, secp256k1_scalar_prime(), 50)] + fn test_muldiv(opcode_offset: usize, modulus: BigUint, num_ops: usize) { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + // doing 1xNUM_LIMBS reads and writes + let mut chip = ModularMulDivChip::new( + tester.execution_bridge(), tester.memory_bridge(), + tester.memory_helper(), tester.address_bits(), + config, + Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, bitwise_chip.clone(), - ), - ModularIsEqualCoreChip::new(modulus.clone(), bitwise_chip.clone(), opcode_offset), - tester.offline_memory_mutex_arc(), - ); + tester.range_checker(), + ); + chip.0.set_trace_buffer_height(MAX_INS_CAPACITY); - { - let vec = big_uint_to_limbs(&modulus, LIMB_BITS); - let modulus_limbs: [F; TOTAL_LIMBS] = std::array::from_fn(|i| { - if i < vec.len() { - F::from_canonical_usize(vec[i]) - } else { - F::ZERO - } - }); + for i in 0..num_ops { + set_and_execute_muldiv(&mut tester, &mut chip, &modulus, i == 0); + } + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - let setup_instruction = rv32_write_heap_default::( - &mut tester, - vec![modulus_limbs], - vec![[F::ZERO; TOTAL_LIMBS]], - opcode_offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, + tester.simple_test().expect("Verification failed"); + } +} + +#[cfg(test)] +mod is_equal_tests { + use openvm_rv32_adapters::{Rv32IsEqualModAdapterAir, Rv32IsEqualModeAdapterStep}; + use openvm_stark_backend::{ + p3_air::BaseAir, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, + }; + + use super::*; + use crate::modular_chip::{ModularIsEqualAir, VmModularIsEqualStep}; + + fn create_test_chips< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, + >( + tester: &mut VmChipTestBuilder, + modulus: &BigUint, + modulus_limbs: [u8; TOTAL_LIMBS], + offset: usize, + ) -> ( + ModularIsEqualChip, + SharedBitwiseOperationLookupChip, + ) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = ModularIsEqualChip::::new( + ModularIsEqualAir::new( + Rv32IsEqualModAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + ModularIsEqualCoreAir::new(modulus.clone(), bitwise_bus, offset), + ), + VmModularIsEqualStep::new( + Rv32IsEqualModeAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), + modulus_limbs, + offset, + bitwise_chip.clone(), + ), + tester.memory_helper(), ); - tester.execute(&mut chip, &setup_instruction); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) } - for _ in 0..num_tests { - let b = generate_field_element::(&modulus, &mut rng); - let c = if rng.gen_bool(0.5) { - b + + #[allow(clippy::too_many_arguments)] + fn set_and_execute_is_equal< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, + >( + tester: &mut VmChipTestBuilder, + chip: &mut ModularIsEqualChip, + rng: &mut StdRng, + modulus: &BigUint, + offset: usize, + modulus_limbs: [F; TOTAL_LIMBS], + is_setup: bool, + b: Option<[F; TOTAL_LIMBS]>, + c: Option<[F; TOTAL_LIMBS]>, + ) { + let instruction = if is_setup { + rv32_write_heap_default::( + tester, + vec![modulus_limbs], + vec![[F::ZERO; TOTAL_LIMBS]], + offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, + ) } else { - generate_field_element::(&modulus, &mut rng) + let b = b.unwrap_or( + generate_field_element::(modulus, rng) + .map(F::from_canonical_u32), + ); + let c = c.unwrap_or(if rng.gen_bool(0.5) { + b + } else { + generate_field_element::(modulus, rng) + .map(F::from_canonical_u32) + }); + + rv32_write_heap_default::( + tester, + vec![b], + vec![c], + offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, + ) }; + tester.execute(chip, &instruction); + } - let instruction = rv32_write_heap_default::( - &mut tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - opcode_offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, - ); - tester.execute(&mut chip, &instruction); + ////////////////////////////////////////////////////////////////////////////////////// + // POSITIVE TESTS + // + // Randomly generate computations and execute, ensuring that the generated trace + // passes all constraints. + ////////////////////////////////////////////////////////////////////////////////////// + + #[test] + fn test_modular_is_equal_1x32() { + test_is_equal::<1, 32, 32>(17, secp256k1_coord_prime(), 100); } - // Special case where b == c are close to the prime - let b_vec = big_uint_to_limbs(&modulus, LIMB_BITS); - let mut b = from_fn(|i| if i < b_vec.len() { b_vec[i] as u32 } else { 0 }); - b[0] -= 1; - let instruction = rv32_write_heap_default::( - &mut tester, - vec![b.map(F::from_canonical_u32)], - vec![b.map(F::from_canonical_u32)], - opcode_offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, - ); - tester.execute(&mut chip, &instruction); - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + #[test] + fn test_modular_is_equal_3x16() { + test_is_equal::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 100); + } -#[test] -fn test_modular_is_equal_1x32() { - test_is_equal::<1, 32, 32>(17, secp256k1_coord_prime(), 100); -} + fn test_is_equal( + opcode_offset: usize, + modulus: BigUint, + num_tests: usize, + ) { + let mut rng = create_seeded_rng(); + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); -#[test] -fn test_modular_is_equal_3x16() { - test_is_equal::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 100); -} + let vec = big_uint_to_limbs(&modulus, LIMB_BITS); + let modulus_limbs: [u8; TOTAL_LIMBS] = + from_fn(|i| if i < vec.len() { vec[i] as u8 } else { 0 }); -// Wrapper chip for testing a bad setup row -type BadModularIsEqualChip< - F, - const NUM_LANES: usize, - const LANE_SIZE: usize, - const TOTAL_LIMBS: usize, -> = VmChipWrapper< - F, - Rv32IsEqualModAdapterChip, - BadModularIsEqualCoreChip, ->; - -// Wrapper chip for testing a bad setup row -struct BadModularIsEqualCoreChip< - const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, -> { - chip: ModularIsEqualCoreChip, -} + let (mut chip, bitwise_chip) = create_test_chips::( + &mut tester, + &modulus, + modulus_limbs, + opcode_offset, + ); -impl - BadModularIsEqualCoreChip -{ - pub fn new( - modulus: BigUint, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - chip: ModularIsEqualCoreChip::new(modulus, bitwise_lookup_chip, offset), + let modulus_limbs = modulus_limbs.map(F::from_canonical_u8); + + for i in 0..num_tests { + set_and_execute_is_equal( + &mut tester, + &mut chip, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + i == 0, // the first test is a setup test + None, + None, + ); } + + // Special case where b == c are close to the prime + let mut b = modulus_limbs; + b[0] -= F::ONE; + set_and_execute_is_equal( + &mut tester, + &mut chip, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + false, + Some(b), + Some(b), + ); + + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } -} -impl< - F: PrimeField32, - I: VmAdapterInterface, + ////////////////////////////////////////////////////////////////////////////////////// + // NEGATIVE TESTS + // + // Given a fake trace of a single operation, setup a chip and run the test. We replace + // part of the trace and check that the chip throws the expected error. + ////////////////////////////////////////////////////////////////////////////////////// + + /// Negative tests test for 3 "type" of errors determined by the value of b[0]: + fn run_negative_is_equal_test< + const NUM_LANES: usize, + const LANE_SIZE: usize, const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, - > VmCoreChip for BadModularIsEqualCoreChip -where - I::Reads: Into<[[F; READ_LIMBS]; 2]>, - I::Writes: From<[[F; WRITE_LIMBS]; 1]>, -{ - type Record = ModularIsEqualCoreRecord; - type Air = ModularIsEqualCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - // Override the b_diff_idx to be out of bounds. - // This will cause lt_marker to be all zeros except a 2. - // There was a bug in this case which allowed b to be less than N. - self.chip.execute_instruction(instruction, from_pc, reads) - } + >( + modulus: BigUint, + opcode_offset: usize, + test_case: usize, + expected_error: VerificationError, + ) { + let mut rng = create_seeded_rng(); + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - fn get_opcode_name(&self, opcode: usize) -> String { - as VmCoreChip>::get_opcode_name(&self.chip, opcode) - } + let vec = big_uint_to_limbs(&modulus, LIMB_BITS); + let modulus_limbs: [u8; READ_LIMBS] = + from_fn(|i| if i < vec.len() { vec[i] as u8 } else { 0 }); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - as VmCoreChip>::generate_trace_row(&self.chip, row_slice, record.clone()); - let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut(); - // decide which bug to test based on b[0] - if record.b[0] == F::ONE { - // test the constraint that c_lt_mark = 2 when is_setup = 1 - row_slice.c_lt_mark = F::ONE; - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::ONE; - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - row_slice.b_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.b[READ_LIMBS - 1]; - } else if record.b[0] == F::from_canonical_u32(2) { - // test the constraint that b[i] = N[i] for all i when prefix_sum is not 1 or - // lt_marker_sum - is_setup - row_slice.c_lt_mark = F::from_canonical_u8(2); - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - } else if record.b[0] == F::from_canonical_u32(3) { - // test the constraint that sum_i lt_marker[i] = 2 when is_setup = 1 - row_slice.c_lt_mark = F::from_canonical_u8(2); - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); - row_slice.lt_marker[0] = F::ONE; - row_slice.b_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[0]) - record.b[0]; - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - } - } + let (mut chip, bitwise_chip) = create_test_chips::( + &mut tester, + &modulus, + modulus_limbs, + opcode_offset, + ); - fn air(&self) -> &Self::Air { - as VmCoreChip>::air( - &self.chip, - ) - } -} + let modulus_limbs = modulus_limbs.map(F::from_canonical_u8); -// Test that passes the wrong modulus in the setup instruction. -// This proof should fail to verify. -fn test_is_equal_setup_bad< - const NUM_LANES: usize, - const LANE_SIZE: usize, - const TOTAL_LIMBS: usize, ->( - opcode_offset: usize, - modulus: BigUint, - b_val: u32, /* used to select which bug to test. currently only 1, 2, and 3 are supported - * (because there are three bugs to test) */ -) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = BadModularIsEqualChip::::new( - Rv32IsEqualModAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ), - BadModularIsEqualCoreChip::new(modulus.clone(), bitwise_chip.clone(), opcode_offset), - tester.offline_memory_mutex_arc(), - ); - - let mut b_limbs = [F::ZERO; TOTAL_LIMBS]; - b_limbs[0] = F::from_canonical_u32(b_val); - let setup_instruction = rv32_write_heap_default::( - &mut tester, - vec![b_limbs], - vec![[F::ZERO; TOTAL_LIMBS]], - opcode_offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, - ); - tester.execute(&mut chip, &setup_instruction); - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + set_and_execute_is_equal( + &mut tester, + &mut chip, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + true, + None, + None, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_1_1x32() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 1); -} + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); + let cols: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = + trace_row.split_at_mut(adapter_width).1.borrow_mut(); + if test_case == 1 { + // test the constraint that c_lt_mark = 2 when is_setup = 1 + cols.b[0] = F::from_canonical_u32(1); + cols.c_lt_mark = F::ONE; + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::ONE; + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + cols.b_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.b[READ_LIMBS - 1]; + } else if test_case == 2 { + // test the constraint that b[i] = N[i] for all i when prefix_sum is not 1 or + // lt_marker_sum - is_setup + cols.b[0] = F::from_canonical_u32(2); + cols.c_lt_mark = F::from_canonical_u8(2); + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + } else if test_case == 3 { + // test the constraint that sum_i lt_marker[i] = 2 when is_setup = 1 + cols.b[0] = F::from_canonical_u32(3); + cols.c_lt_mark = F::from_canonical_u8(2); + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); + cols.lt_marker[0] = F::ONE; + cols.b_lt_diff = modulus_limbs[0] - cols.b[0]; + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_2_1x32_2() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 2); -} + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .load(bitwise_chip) + .finalize(); + tester.simple_test_with_expected_error(expected_error); + } -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_3_1x32() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 3); -} + #[test] + fn negative_test_modular_is_equal_1x32() { + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 1, + VerificationError::OodEvaluationMismatch, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_1_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 1); -} + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 2, + VerificationError::OodEvaluationMismatch, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_2_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 2); -} + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 3, + VerificationError::OodEvaluationMismatch, + ); + } + + #[test] + fn negative_test_modular_is_equal_3x16() { + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 1, + VerificationError::OodEvaluationMismatch, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_3_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 3); + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 2, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 3, + VerificationError::OodEvaluationMismatch, + ); + } } diff --git a/extensions/algebra/circuit/src/modular_extension.rs b/extensions/algebra/circuit/src/modular_extension.rs index 99632d6ce3..b90628bab6 100644 --- a/extensions/algebra/circuit/src/modular_extension.rs +++ b/extensions/algebra/circuit/src/modular_extension.rs @@ -1,20 +1,25 @@ +use std::array; + use derive_more::derive::From; use num_bigint::{BigUint, RandBigInt}; use num_traits::{FromPrimitive, One}; use openvm_algebra_transpiler::{ModularPhantom, Rv32ModularArithmeticOpcode}; use openvm_circuit::{ self, - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, + arch::{ + ExecutionBridge, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::{ + bigint::utils::big_uint_to_limbs, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{LocalOpcode, PhantomDiscriminant, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::{Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip}; +use openvm_rv32_adapters::{Rv32IsEqualModAdapterAir, Rv32IsEqualModeAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -22,9 +27,12 @@ use serde_with::{serde_as, DisplayFromStr}; use strum::EnumCount; use crate::modular_chip::{ - ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivChip, + ModularAddSubChip, ModularIsEqualAir, ModularIsEqualChip, ModularIsEqualCoreAir, + ModularMulDivChip, VmModularIsEqualStep, }; +// TODO: this should be decided after e2 execution + #[serde_as] #[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] pub struct ModularExtension { @@ -46,7 +54,9 @@ impl ModularExtension { } } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From, InsExecutorE1, InsExecutorE2, +)] pub enum ModularExtensionExecutor { // 32 limbs prime ModularAddSubRv32_32(ModularAddSubChip), @@ -79,7 +89,11 @@ impl VmExtension for ModularExtension { program_bus, memory_bridge, } = builder.system_port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); let range_checker = builder.system_base().range_checker_chip.clone(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -91,8 +105,6 @@ impl VmExtension for ModularExtension { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; let addsub_opcodes = (Rv32ModularArithmeticOpcode::ADD as usize) ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize); @@ -117,28 +129,19 @@ impl VmExtension for ModularExtension { num_limbs: 48, limb_bits: 8, }; - let adapter_chip_32 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); - let adapter_chip_48 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); + + let modulus_limbs = big_uint_to_limbs(modulus, 8); if bytes <= 32 { let addsub_chip = ModularAddSubChip::new( - adapter_chip_32.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); inventory.add_executor( ModularExtensionExecutor::ModularAddSubRv32_32(addsub_chip), @@ -147,11 +150,14 @@ impl VmExtension for ModularExtension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let muldiv_chip = ModularMulDivChip::new( - adapter_chip_32.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); inventory.add_executor( ModularExtensionExecutor::ModularMulDivRv32_32(muldiv_chip), @@ -159,20 +165,35 @@ impl VmExtension for ModularExtension { .clone() .map(|x| VmOpcode::from_usize(x + start_offset)), )?; + + let modulus_limbs = array::from_fn(|i| { + if i < modulus_limbs.len() { + modulus_limbs[i] as u8 + } else { + 0 + } + }); let isequal_chip = ModularIsEqualChip::new( - Rv32IsEqualModAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + ModularIsEqualAir::new( + Rv32IsEqualModAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + ModularIsEqualCoreAir::new( + modulus.clone(), + bitwise_lu_chip.bus(), + start_offset, + ), ), - ModularIsEqualCoreChip::new( - modulus.clone(), - bitwise_lu_chip.clone(), + VmModularIsEqualStep::new( + Rv32IsEqualModeAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + modulus_limbs, start_offset, + bitwise_lu_chip.clone(), ), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( ModularExtensionExecutor::ModularIsEqualRv32_32(isequal_chip), @@ -182,11 +203,14 @@ impl VmExtension for ModularExtension { )?; } else if bytes <= 48 { let addsub_chip = ModularAddSubChip::new( - adapter_chip_48.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); inventory.add_executor( ModularExtensionExecutor::ModularAddSubRv32_48(addsub_chip), @@ -195,11 +219,14 @@ impl VmExtension for ModularExtension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let muldiv_chip = ModularMulDivChip::new( - adapter_chip_48.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); inventory.add_executor( ModularExtensionExecutor::ModularMulDivRv32_48(muldiv_chip), @@ -207,20 +234,34 @@ impl VmExtension for ModularExtension { .clone() .map(|x| VmOpcode::from_usize(x + start_offset)), )?; + let modulus_limbs = array::from_fn(|i| { + if i < modulus_limbs.len() { + modulus_limbs[i] as u8 + } else { + 0 + } + }); let isequal_chip = ModularIsEqualChip::new( - Rv32IsEqualModAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + ModularIsEqualAir::new( + Rv32IsEqualModAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + ModularIsEqualCoreAir::new( + modulus.clone(), + bitwise_lu_chip.bus(), + start_offset, + ), ), - ModularIsEqualCoreChip::new( - modulus.clone(), - bitwise_lu_chip.clone(), + VmModularIsEqualStep::new( + Rv32IsEqualModeAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + modulus_limbs, start_offset, + bitwise_lu_chip.clone(), ), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( ModularExtensionExecutor::ModularIsEqualRv32_48(isequal_chip), @@ -258,10 +299,10 @@ pub(crate) mod phantom { use num_bigint::BigUint; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant}; - use openvm_rv32im_circuit::adapters::unsafe_read_rv32_register; + use openvm_rv32im_circuit::adapters::read_rv32_register; use openvm_stark_backend::p3_field::PrimeField32; use rand::{rngs::StdRng, SeedableRng}; @@ -282,12 +323,13 @@ pub(crate) mod phantom { // Note that non_qr is fixed for each modulus. impl PhantomSubExecutor for SqrtHintSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - _: F, + a: u32, + _: u32, c_upper: u16, ) -> eyre::Result<()> { let mod_idx = c_upper as usize; @@ -306,15 +348,12 @@ pub(crate) mod phantom { bail!("Modulus too large") }; - let rs1 = unsafe_read_rv32_register(memory, a); - let mut x_limbs: Vec = Vec::with_capacity(num_limbs); - for i in 0..num_limbs { - let limb = memory.unsafe_read_cell( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1 + i as u32), - ); - x_limbs.push(limb.as_canonical_u32() as u8); - } + let rs1 = read_rv32_register(memory, a); + // SAFETY: + // - MEMORY_AS consists of `u8`s + // - MEMORY_AS is in bounds + let x_limbs: Vec = + unsafe { memory.memory.get_slice((RV32_MEMORY_AS, rs1), num_limbs) }.to_vec(); let x = BigUint::from_bytes_le(&x_limbs); let (success, sqrt) = match mod_sqrt(&x, modulus, &self.non_qrs[mod_idx]) { @@ -372,12 +411,13 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NonQrHintSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, c_upper: u16, ) -> eyre::Result<()> { let mod_idx = c_upper as usize; diff --git a/extensions/algebra/moduli-macros/src/lib.rs b/extensions/algebra/moduli-macros/src/lib.rs index fc30341195..0dc7128588 100644 --- a/extensions/algebra/moduli-macros/src/lib.rs +++ b/extensions/algebra/moduli-macros/src/lib.rs @@ -965,7 +965,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { let ModuliDefine { items } = parse_macro_input!(input as ModuliDefine); let mut externs = Vec::new(); - let mut openvm_section = Vec::new(); // List of all modular limbs in one (that is, with a compile-time known size) array. let mut two_modular_limbs_flattened_list = Vec::::new(); @@ -976,8 +975,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { for (mod_idx, item) in items.into_iter().enumerate() { let modulus = item.value(); - println!("[init] modulus #{} = {}", mod_idx, modulus); - let modulus_bytes = string_to_bytes(&modulus); let mut limbs = modulus_bytes.len(); let mut block_size = 32; @@ -1012,31 +1009,11 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { .collect::>() .join(""); - let serialized_modulus = - core::iter::once(1) // 1 for "modulus" - .chain(core::iter::once(mod_idx as u8)) // mod_idx is u8 for now (can make it u32), because we don't know the order of - // variables in the elf - .chain((modulus_bytes.len() as u32).to_le_bytes().iter().copied()) - .chain(modulus_bytes.iter().copied()) - .collect::>(); - let serialized_name = syn::Ident::new( - &format!("OPENVM_SERIALIZED_MODULUS_{}", mod_idx), - span.into(), - ); - let serialized_len = serialized_modulus.len(); let setup_extern_func = syn::Ident::new( &format!("moduli_setup_extern_func_{}", modulus_hex), span.into(), ); - openvm_section.push(quote::quote_spanned! { span.into() => - #[cfg(target_os = "zkvm")] - #[link_section = ".openvm"] - #[no_mangle] - #[used] - static #serialized_name: [u8; #serialized_len] = [#(#serialized_modulus),*]; - }); - for op_type in ["add", "sub", "mul", "div"] { let func_name = syn::Ident::new( &format!("{}_extern_func_{}", op_type, modulus_hex), @@ -1126,19 +1103,12 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { extern "C" fn #setup_extern_func() { #[cfg(target_os = "zkvm")] { - let mut ptr = 0; - assert_eq!(super::#serialized_name[ptr], 1); - ptr += 1; - assert_eq!(super::#serialized_name[ptr], #mod_idx as u8); - ptr += 1; - assert_eq!(super::#serialized_name[ptr..ptr+4].iter().rev().fold(0, |acc, &x| acc * 256 + x as usize), #limbs); - ptr += 4; - let remaining = &super::#serialized_name[ptr..]; - // To avoid importing #struct_name, we create a placeholder struct with the same size and alignment. #[repr(C, align(#block_size))] struct AlignedPlaceholder([u8; #limbs]); + const MODULUS_BYTES: AlignedPlaceholder = AlignedPlaceholder([#(#modulus_bytes),*]); + // We are going to use the numeric representation of the `rs2` register to distinguish the chip to setup. // The transpiler will transform this instruction, based on whether `rs2` is `x0`, `x1` or `x2`, into a `SETUP_ADDSUB`, `SETUP_MULDIV` or `SETUP_ISEQ` instruction. let mut uninit: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); @@ -1149,7 +1119,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = In uninit.as_mut_ptr(), - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_ADDMOD ); openvm::platform::custom_insn_r!( @@ -1159,7 +1129,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = In uninit.as_mut_ptr(), - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x1" // will be parsed as 1 and therefore transpiled to SETUP_MULDIV ); unsafe { @@ -1172,7 +1142,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = InOut tmp, - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x2" // will be parsed as 2 and therefore transpiled to SETUP_ISEQ ); // rd = inout(reg) is necessary because this instruction will write to `rd` register @@ -1185,7 +1155,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { let total_limbs_cnt = two_modular_limbs_flattened_list.len(); let cnt_limbs_list_len = limb_list_borders.len(); TokenStream::from(quote::quote_spanned! { span.into() => - #(#openvm_section)* #[allow(non_snake_case)] #[cfg(target_os = "zkvm")] mod openvm_intrinsics_ffi { diff --git a/extensions/algebra/tests/src/lib.rs b/extensions/algebra/tests/src/lib.rs index 181f592544..9c6ab2cb63 100644 --- a/extensions/algebra/tests/src/lib.rs +++ b/extensions/algebra/tests/src/lib.rs @@ -8,7 +8,7 @@ mod tests { Fp2Extension, ModularExtension, Rv32ModularConfig, Rv32ModularWithFp2Config, }; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; - use openvm_circuit::{arch::SystemConfig, utils::air_test}; + use openvm_circuit::utils::{air_test, test_system_config_with_continuations}; use openvm_ecc_circuit::SECP256K1_CONFIG; use openvm_instructions::exe::VmExe; use openvm_rv32im_transpiler::{ @@ -20,11 +20,27 @@ mod tests { type F = BabyBear; + #[cfg(test)] + fn test_rv32modular_config(moduli: Vec) -> Rv32ModularConfig { + let mut config = Rv32ModularConfig::new(moduli); + config.system = test_system_config_with_continuations(); + config + } + + #[cfg(test)] + fn test_rv32modularwithfp2_config( + moduli_with_names: Vec<(String, BigUint)>, + ) -> Rv32ModularWithFp2Config { + let mut config = Rv32ModularWithFp2Config::new(moduli_with_names); + config.system = test_system_config_with_continuations(); + config + } + #[test] fn test_moduli_setup() -> Result<()> { let moduli = ["4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787", "1000000000000000003", "2305843009213693951"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path(get_programs_dir!(), "moduli_setup", &config)?; let openvm_exe = VmExe::from_elf( elf, @@ -41,7 +57,7 @@ mod tests { #[test] fn test_modular() -> Result<()> { - let config = Rv32ModularConfig::new(vec![SECP256K1_CONFIG.modulus.clone()]); + let config = test_rv32modular_config(vec![SECP256K1_CONFIG.modulus.clone()]); let elf = build_example_program_at_path(get_programs_dir!(), "little", &config)?; let openvm_exe = VmExe::from_elf( elf, @@ -57,7 +73,7 @@ mod tests { #[test] fn test_complex_two_moduli() -> Result<()> { - let config = Rv32ModularWithFp2Config::new(vec![ + let config = test_rv32modularwithfp2_config(vec![ ( "Complex1".to_string(), BigUint::from_str("998244353").unwrap(), @@ -85,7 +101,7 @@ mod tests { #[test] fn test_complex_redundant_modulus() -> Result<()> { let config = Rv32ModularWithFp2Config { - system: SystemConfig::default().with_continuations(), + system: test_system_config_with_continuations(), base: Default::default(), mul: Default::default(), io: Default::default(), @@ -120,7 +136,7 @@ mod tests { #[test] fn test_complex() -> Result<()> { - let config = Rv32ModularWithFp2Config::new(vec![( + let config = test_rv32modularwithfp2_config(vec![( "Complex".to_string(), SECP256K1_CONFIG.modulus.clone(), )]); @@ -141,7 +157,7 @@ mod tests { #[test] #[should_panic] fn test_invalid_setup() { - let config = Rv32ModularConfig::new(vec![ + let config = test_rv32modular_config(vec![ BigUint::from_str("998244353").unwrap(), BigUint::from_str("1000000007").unwrap(), ]); @@ -168,7 +184,7 @@ mod tests { #[test] fn test_sqrt() -> Result<()> { - let config = Rv32ModularConfig::new(vec![SECP256K1_CONFIG.modulus.clone()]); + let config = test_rv32modular_config(vec![SECP256K1_CONFIG.modulus.clone()]); let elf = build_example_program_at_path(get_programs_dir!(), "sqrt", &config)?; let openvm_exe = VmExe::from_elf( elf, diff --git a/extensions/bigint/circuit/Cargo.toml b/extensions/bigint/circuit/Cargo.toml index 09d68a9d1b..aa9114c34a 100644 --- a/extensions/bigint/circuit/Cargo.toml +++ b/extensions/bigint/circuit/Cargo.toml @@ -29,6 +29,8 @@ serde.workspace = true openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } +test-case.workspace = true +alloy-primitives = { version = "1.2.1" } [features] default = ["parallel", "jemalloc"] diff --git a/extensions/bigint/circuit/src/base_alu.rs b/extensions/bigint/circuit/src/base_alu.rs new file mode 100644 index 0000000000..491787d020 --- /dev/null +++ b/extensions/bigint/circuit/src/base_alu.rs @@ -0,0 +1,262 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::transmute, +}; + +use openvm_bigint_transpiler::Rv32BaseAlu256Opcode; +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, StepExecutorE1, StepExecutorE2, VmAirWrapper, + VmSegmentState, +}; +use openvm_circuit_derive::{TraceFiller, TraceStep}; +use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::{Rv32HeapAdapterAir, Rv32HeapAdapterStep}; +use openvm_rv32im_circuit::{BaseAluCoreAir, BaseAluStep}; +use openvm_rv32im_transpiler::BaseAluOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{INT256_NUM_LIMBS, RV32_CELL_BITS}; + +pub type Rv32BaseAlu256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + BaseAluCoreAir, +>; + +#[derive(TraceStep, TraceFiller)] +pub struct Rv32BaseAlu256Step(BaseStep); +pub type Rv32BaseAlu256Chip = + NewVmChipWrapper>; + +type BaseStep = BaseAluStep; +type AdapterStep = Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; + +impl Rv32BaseAlu256Step { + pub fn new( + adapter: AdapterStep, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + offset: usize, + ) -> Self { + Self(BaseAluStep::new(adapter, bitwise_lookup_chip, offset)) + } +} + +#[derive(AlignedBytesBorrow)] +struct BaseAluPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl StepExecutorE1 for Rv32BaseAlu256Step { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut BaseAluPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + BaseAluOpcode::ADD => execute_e1_impl::<_, _, AddOp>, + BaseAluOpcode::SUB => execute_e1_impl::<_, _, SubOp>, + BaseAluOpcode::XOR => execute_e1_impl::<_, _, XorOp>, + BaseAluOpcode::OR => execute_e1_impl::<_, _, OrOp>, + BaseAluOpcode::AND => execute_e1_impl::<_, _, AndOp>, + }; + Ok(fn_ptr) + } +} + +impl StepExecutorE2 for Rv32BaseAlu256Step { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + BaseAluOpcode::ADD => execute_e2_impl::<_, _, AddOp>, + BaseAluOpcode::SUB => execute_e2_impl::<_, _, SubOp>, + BaseAluOpcode::XOR => execute_e2_impl::<_, _, XorOp>, + BaseAluOpcode::OR => execute_e2_impl::<_, _, OrOp>, + BaseAluOpcode::AND => execute_e2_impl::<_, _, AndOp>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BaseAluPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let rd_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let rd = ::compute(rs1, rs2); + vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &BaseAluPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32BaseAlu256Step { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BaseAluPreCompute, + ) -> openvm_circuit::arch::Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + *data = BaseAluPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + let local_opcode = + BaseAluOpcode::from_usize(opcode.local_opcode_idx(Rv32BaseAlu256Opcode::CLASS_OFFSET)); + Ok(local_opcode) + } +} + +trait AluOp { + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS]; +} +struct AddOp; +struct SubOp; +struct XorOp; +struct OrOp; +struct AndOp; +impl AluOp for AddOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + let (res, mut carry) = rs1_u64[0].overflowing_add(rs2_u64[0]); + rd_u64[0] = res; + for i in 1..4 { + let (res1, c1) = rs1_u64[i].overflowing_add(rs2_u64[i]); + let (res2, c2) = res1.overflowing_add(carry as u64); + carry = c1 || c2; + rd_u64[i] = res2; + } + unsafe { transmute(rd_u64) } + } +} +impl AluOp for SubOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + let (res, mut borrow) = rs1_u64[0].overflowing_sub(rs2_u64[0]); + rd_u64[0] = res; + for i in 1..4 { + let (res1, c1) = rs1_u64[i].overflowing_sub(rs2_u64[i]); + let (res2, c2) = res1.overflowing_sub(borrow as u64); + borrow = c1 || c2; + rd_u64[i] = res2; + } + unsafe { transmute(rd_u64) } + } +} +impl AluOp for XorOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + // Compiler will expand this loop. + for i in 0..4 { + rd_u64[i] = rs1_u64[i] ^ rs2_u64[i]; + } + unsafe { transmute(rd_u64) } + } +} +impl AluOp for OrOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + // Compiler will expand this loop. + for i in 0..4 { + rd_u64[i] = rs1_u64[i] | rs2_u64[i]; + } + unsafe { transmute(rd_u64) } + } +} +impl AluOp for AndOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { transmute(rs2) }; + let mut rd_u64 = [0u64; 4]; + // Compiler will expand this loop. + for i in 0..4 { + rd_u64[i] = rs1_u64[i] & rs2_u64[i]; + } + unsafe { transmute(rd_u64) } + } +} diff --git a/extensions/bigint/circuit/src/branch_eq.rs b/extensions/bigint/circuit/src/branch_eq.rs new file mode 100644 index 0000000000..3607b4a681 --- /dev/null +++ b/extensions/bigint/circuit/src/branch_eq.rs @@ -0,0 +1,188 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32BranchEqual256Opcode; +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, StepExecutorE1, StepExecutorE2, VmAirWrapper, + VmSegmentState, +}; +use openvm_circuit_derive::{TraceFiller, TraceStep}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::{Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterStep}; +use openvm_rv32im_circuit::{BranchEqualCoreAir, BranchEqualStep}; +use openvm_rv32im_transpiler::BranchEqualOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::INT256_NUM_LIMBS; + +/// BranchEqual256 +pub type Rv32BranchEqual256Air = VmAirWrapper< + Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, + BranchEqualCoreAir, +>; +#[derive(TraceStep, TraceFiller)] +pub struct Rv32BranchEqual256Step(BaseStep); +pub type Rv32BranchEqual256Chip = + NewVmChipWrapper>; + +type BaseStep = BranchEqualStep, INT256_NUM_LIMBS>; +type AdapterStep = Rv32HeapBranchAdapterStep<2, INT256_NUM_LIMBS>; + +impl Rv32BranchEqual256Step { + pub fn new(adapter_step: AdapterStep, offset: usize, pc_step: u32) -> Self { + Self(BaseStep::new(adapter_step, offset, pc_step)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BranchEqPreCompute { + imm: isize, + a: u8, + b: u8, +} + +impl StepExecutorE1 for Rv32BranchEqual256Step { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut BranchEqPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + BranchEqualOpcode::BEQ => execute_e1_impl::<_, _, false>, + BranchEqualOpcode::BNE => execute_e1_impl::<_, _, true>, + }; + Ok(fn_ptr) + } +} + +impl StepExecutorE2 for Rv32BranchEqual256Step { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + BranchEqualOpcode::BEQ => execute_e2_impl::<_, _, false>, + BranchEqualOpcode::BNE => execute_e2_impl::<_, _, true>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BranchEqPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let cmp_result = u256_eq(rs1, rs2); + if cmp_result ^ IS_NE { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + } + + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &BranchEqPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32BranchEqual256Step { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BranchEqPreCompute, + ) -> openvm_circuit::arch::Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + *data = BranchEqPreCompute { + imm, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + let local_opcode = BranchEqualOpcode::from_usize( + opcode.local_opcode_idx(Rv32BranchEqual256Opcode::CLASS_OFFSET), + ); + Ok(local_opcode) + } +} + +fn u256_eq(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + for i in 0..4 { + if rs1_u64[i] != rs2_u64[i] { + return false; + } + } + true +} diff --git a/extensions/bigint/circuit/src/branch_lt.rs b/extensions/bigint/circuit/src/branch_lt.rs new file mode 100644 index 0000000000..36ae6aebb8 --- /dev/null +++ b/extensions/bigint/circuit/src/branch_lt.rs @@ -0,0 +1,225 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32BranchLessThan256Opcode; +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, StepExecutorE1, StepExecutorE2, VmAirWrapper, + VmSegmentState, +}; +use openvm_circuit_derive::{TraceFiller, TraceStep}; +use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::{Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterStep}; +use openvm_rv32im_circuit::{BranchLessThanCoreAir, BranchLessThanStep}; +use openvm_rv32im_transpiler::BranchLessThanOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + common::{i256_lt, u256_lt}, + INT256_NUM_LIMBS, RV32_CELL_BITS, +}; + +/// BranchLessThan256 +pub type Rv32BranchLessThan256Air = VmAirWrapper< + Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, + BranchLessThanCoreAir, +>; +#[derive(TraceStep, TraceFiller)] +pub struct Rv32BranchLessThan256Step(BaseStep); +pub type Rv32BranchLessThan256Chip = + NewVmChipWrapper>; + +type BaseStep = BranchLessThanStep< + Rv32HeapBranchAdapterStep<2, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, +>; +type AdapterStep = Rv32HeapBranchAdapterStep<2, INT256_NUM_LIMBS>; + +impl Rv32BranchLessThan256Step { + pub fn new( + adapter: AdapterStep, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + offset: usize, + ) -> Self { + Self(BaseStep::new(adapter, bitwise_lookup_chip, offset)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BranchLtPreCompute { + imm: isize, + a: u8, + b: u8, +} + +impl StepExecutorE1 for Rv32BranchLessThan256Step { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut BranchLtPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + BranchLessThanOpcode::BLT => execute_e1_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => execute_e1_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => execute_e1_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => execute_e1_impl::<_, _, BgeuOp>, + }; + Ok(fn_ptr) + } +} + +impl StepExecutorE2 for Rv32BranchLessThan256Step { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + BranchLessThanOpcode::BLT => execute_e2_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => execute_e2_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => execute_e2_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => execute_e2_impl::<_, _, BgeuOp>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BranchLtPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let cmp_result = OP::compute(rs1, rs2); + if cmp_result { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + } + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &BranchLtPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32BranchLessThan256Step { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BranchLtPreCompute, + ) -> openvm_circuit::arch::Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + *data = BranchLtPreCompute { + imm, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + let local_opcode = BranchLessThanOpcode::from_usize( + opcode.local_opcode_idx(Rv32BranchLessThan256Opcode::CLASS_OFFSET), + ); + Ok(local_opcode) + } +} + +trait BranchLessThanOp { + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool; +} +struct BltOp; +struct BltuOp; +struct BgeOp; +struct BgeuOp; + +impl BranchLessThanOp for BltOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + i256_lt(rs1, rs2) + } +} +impl BranchLessThanOp for BltuOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + u256_lt(rs1, rs2) + } +} +impl BranchLessThanOp for BgeOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + !i256_lt(rs1, rs2) + } +} +impl BranchLessThanOp for BgeuOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + !u256_lt(rs1, rs2) + } +} diff --git a/extensions/bigint/circuit/src/common.rs b/extensions/bigint/circuit/src/common.rs new file mode 100644 index 0000000000..14c49ce68c --- /dev/null +++ b/extensions/bigint/circuit/src/common.rs @@ -0,0 +1,66 @@ +use crate::{INT256_NUM_LIMBS, RV32_CELL_BITS}; + +#[inline(always)] +pub(crate) fn u256_lt(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + for i in (0..4).rev() { + if rs1_u64[i] != rs2_u64[i] { + return rs1_u64[i] < rs2_u64[i]; + } + } + false +} + +#[inline(always)] +pub(crate) fn i256_lt(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> bool { + // true for negative. false for positive + let rs1_sign = rs1[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) == 1; + let rs2_sign = rs2[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) == 1; + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + for i in (0..4).rev() { + if rs1_u64[i] != rs2_u64[i] { + return (rs1_u64[i] < rs2_u64[i]) ^ rs1_sign ^ rs2_sign; + } + } + false +} + +#[cfg(test)] +mod tests { + use alloy_primitives::{I256, U256}; + use rand::{prelude::StdRng, Rng, SeedableRng}; + + use crate::{ + common::{i256_lt, u256_lt}, + INT256_NUM_LIMBS, + }; + + #[test] + fn test_u256_lt() { + let mut rng = StdRng::from_seed([42; 32]); + for _ in 0..10000 { + let limbs_a: [u64; 4] = rng.gen(); + let limbs_b: [u64; 4] = rng.gen(); + let a = U256::from_limbs(limbs_a); + let b = U256::from_limbs(limbs_b); + let a_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_a) }; + let b_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_b) }; + assert_eq!(u256_lt(a_u8, b_u8), a < b); + } + } + #[test] + fn test_i256_lt() { + let mut rng = StdRng::from_seed([42; 32]); + for _ in 0..10000 { + let limbs_a: [u64; 4] = rng.gen(); + let limbs_b: [u64; 4] = rng.gen(); + let a = I256::from_limbs(limbs_a); + let b = I256::from_limbs(limbs_b); + let a_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_a) }; + let b_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_b) }; + assert_eq!(i256_lt(a_u8, b_u8), a < b); + } + } +} diff --git a/extensions/bigint/circuit/src/extension.rs b/extensions/bigint/circuit/src/extension.rs index 390b79cc63..65b55b19d5 100644 --- a/extensions/bigint/circuit/src/extension.rs +++ b/extensions/bigint/circuit/src/extension.rs @@ -5,26 +5,34 @@ use openvm_bigint_transpiler::{ }; use openvm_circuit::{ arch::{ - InitFileGenerator, SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, - VmInventoryError, + ExecutionBridge, InitFileGenerator, SystemConfig, SystemPort, VmAirWrapper, VmExtension, + VmInventory, VmInventoryBuilder, VmInventoryError, }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_rv32_adapters::{ + Rv32HeapAdapterAir, Rv32HeapAdapterStep, Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterStep, +}; use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, + BaseAluCoreAir, BranchEqualCoreAir, BranchLessThanCoreAir, LessThanCoreAir, + MultiplicationCoreAir, Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, + Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, ShiftCoreAir, }; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; -use crate::*; +use crate::{ + shift::{Rv32Shift256Chip, Rv32Shift256Step}, + *, +}; +// TODO: this should be decided after e2 execution #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] pub struct Int256Rv32Config { @@ -73,7 +81,9 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { [1 << 8, 32 * (1 << 8)] } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, InsExecutorE2, From, AnyEnum, +)] pub enum Int256Executor { BaseAlu256(Rv32BaseAlu256Chip), LessThan256(Rv32LessThan256Chip), @@ -105,6 +115,8 @@ impl VmExtension for Int256 { program_bus, memory_bridge, } = builder.system_port(); + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker_chip = builder.system_base().range_checker_chip.clone(); let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() @@ -117,8 +129,8 @@ impl VmExtension for Int256 { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; + + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; let range_tuple_chip = if let Some(chip) = builder .find_chip::>() @@ -137,66 +149,93 @@ impl VmExtension for Int256 { }; let base_alu_chip = Rv32BaseAlu256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BaseAluCoreAir::new(bitwise_lu_chip.bus(), Rv32BaseAlu256Opcode::CLASS_OFFSET), + ), + Rv32BaseAlu256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + Rv32BaseAlu256Opcode::CLASS_OFFSET, ), - BaseAluCoreChip::new(bitwise_lu_chip.clone(), Rv32BaseAlu256Opcode::CLASS_OFFSET), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( base_alu_chip, Rv32BaseAlu256Opcode::iter().map(|x| x.global_opcode()), )?; let less_than_chip = Rv32LessThan256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + LessThanCoreAir::new(bitwise_lu_chip.bus(), Rv32LessThan256Opcode::CLASS_OFFSET), + ), + Rv32LessThan256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + Rv32LessThan256Opcode::CLASS_OFFSET, ), - LessThanCoreChip::new(bitwise_lu_chip.clone(), Rv32LessThan256Opcode::CLASS_OFFSET), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( less_than_chip, Rv32LessThan256Opcode::iter().map(|x| x.global_opcode()), )?; let branch_equal_chip = Rv32BranchEqual256Chip::new( - Rv32HeapBranchAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BranchEqualCoreAir::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, DEFAULT_PC_STEP), ), - BranchEqualCoreChip::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + Rv32BranchEqual256Step::new( + Rv32HeapBranchAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + Rv32BranchEqual256Opcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( branch_equal_chip, Rv32BranchEqual256Opcode::iter().map(|x| x.global_opcode()), )?; let branch_less_than_chip = Rv32BranchLessThan256Chip::new( - Rv32HeapBranchAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BranchLessThanCoreAir::new( + bitwise_lu_chip.bus(), + Rv32BranchLessThan256Opcode::CLASS_OFFSET, + ), ), - BranchLessThanCoreChip::new( + Rv32BranchLessThan256Step::new( + Rv32HeapBranchAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), Rv32BranchLessThan256Opcode::CLASS_OFFSET, ), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( branch_less_than_chip, @@ -204,36 +243,51 @@ impl VmExtension for Int256 { )?; let multiplication_chip = Rv32Multiplication256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + MultiplicationCoreAir::new(*range_tuple_chip.bus(), Rv32Mul256Opcode::CLASS_OFFSET), ), - MultiplicationCoreChip::new(range_tuple_chip, Rv32Mul256Opcode::CLASS_OFFSET), - offline_memory.clone(), + Rv32Multiplication256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + range_tuple_chip.clone(), + Rv32Mul256Opcode::CLASS_OFFSET, + ), + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( multiplication_chip, Rv32Mul256Opcode::iter().map(|x| x.global_opcode()), )?; let shift_chip = Rv32Shift256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + ShiftCoreAir::new( + bitwise_lu_chip.bus(), + range_checker_chip.bus(), + Rv32Shift256Opcode::CLASS_OFFSET, + ), ), - ShiftCoreChip::new( + Rv32Shift256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), - range_checker_chip, + range_checker_chip.clone(), Rv32Shift256Opcode::CLASS_OFFSET, ), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( shift_chip, Rv32Shift256Opcode::iter().map(|x| x.global_opcode()), diff --git a/extensions/bigint/circuit/src/less_than.rs b/extensions/bigint/circuit/src/less_than.rs new file mode 100644 index 0000000000..df46676c27 --- /dev/null +++ b/extensions/bigint/circuit/src/less_than.rs @@ -0,0 +1,180 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32LessThan256Opcode; +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, StepExecutorE1, StepExecutorE2, VmAirWrapper, + VmSegmentState, +}; +use openvm_circuit_derive::{TraceFiller, TraceStep}; +use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::{Rv32HeapAdapterAir, Rv32HeapAdapterStep}; +use openvm_rv32im_circuit::{LessThanCoreAir, LessThanStep}; +use openvm_rv32im_transpiler::LessThanOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{common, INT256_NUM_LIMBS, RV32_CELL_BITS}; + +/// LessThan256 +pub type Rv32LessThan256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + LessThanCoreAir, +>; +#[derive(TraceStep, TraceFiller)] +pub struct Rv32LessThan256Step(BaseStep); +pub type Rv32LessThan256Chip = + NewVmChipWrapper>; + +type BaseStep = LessThanStep; +type AdapterStep = Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; + +impl Rv32LessThan256Step { + pub fn new( + adapter: AdapterStep, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + offset: usize, + ) -> Self { + Self(BaseStep::new(adapter, bitwise_lookup_chip, offset)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct LessThanPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl StepExecutorE1 for Rv32LessThan256Step { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut LessThanPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + LessThanOpcode::SLT => execute_e1_impl::<_, _, false>, + LessThanOpcode::SLTU => execute_e1_impl::<_, _, true>, + }; + Ok(fn_ptr) + } +} + +impl StepExecutorE2 for Rv32LessThan256Step { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + LessThanOpcode::SLT => execute_e2_impl::<_, _, false>, + LessThanOpcode::SLTU => execute_e2_impl::<_, _, true>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &LessThanPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let rd_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let cmp_result = if IS_U256 { + common::u256_lt(rs1, rs2) + } else { + common::i256_lt(rs1, rs2) + }; + let mut rd = [0u8; INT256_NUM_LIMBS]; + rd[0] = cmp_result as u8; + vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &LessThanPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32LessThan256Step { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut LessThanPreCompute, + ) -> openvm_circuit::arch::Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + *data = LessThanPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + let local_opcode = LessThanOpcode::from_usize( + opcode.local_opcode_idx(Rv32LessThan256Opcode::CLASS_OFFSET), + ); + Ok(local_opcode) + } +} diff --git a/extensions/bigint/circuit/src/lib.rs b/extensions/bigint/circuit/src/lib.rs index 295ef73db2..2bbe6a2052 100644 --- a/extensions/bigint/circuit/src/lib.rs +++ b/extensions/bigint/circuit/src/lib.rs @@ -1,49 +1,22 @@ -use openvm_circuit::{self, arch::VmChipWrapper}; -use openvm_rv32_adapters::{Rv32HeapAdapterChip, Rv32HeapBranchAdapterChip}; -use openvm_rv32im_circuit::{ - adapters::{INT256_NUM_LIMBS, RV32_CELL_BITS}, - BaseAluCoreChip, BranchEqualCoreChip, BranchLessThanCoreChip, LessThanCoreChip, - MultiplicationCoreChip, ShiftCoreChip, -}; - mod extension; pub use extension::*; +mod base_alu; +mod branch_eq; +mod branch_lt; +pub(crate) mod common; +mod less_than; +mod mult; +mod shift; #[cfg(test)] mod tests; -pub type Rv32BaseAlu256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - BaseAluCoreChip, ->; - -pub type Rv32LessThan256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - LessThanCoreChip, ->; - -pub type Rv32Multiplication256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - MultiplicationCoreChip, ->; - -pub type Rv32Shift256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - ShiftCoreChip, ->; - -pub type Rv32BranchEqual256Chip = VmChipWrapper< - F, - Rv32HeapBranchAdapterChip, - BranchEqualCoreChip, ->; +pub use base_alu::*; +pub use branch_eq::*; +pub use branch_lt::*; +pub use less_than::*; +pub use mult::*; +pub use shift::*; -pub type Rv32BranchLessThan256Chip = VmChipWrapper< - F, - Rv32HeapBranchAdapterChip, - BranchLessThanCoreChip, ->; +pub(crate) const INT256_NUM_LIMBS: usize = 32; +pub(crate) const RV32_CELL_BITS: usize = 8; diff --git a/extensions/bigint/circuit/src/mult.rs b/extensions/bigint/circuit/src/mult.rs new file mode 100644 index 0000000000..29b03151c4 --- /dev/null +++ b/extensions/bigint/circuit/src/mult.rs @@ -0,0 +1,207 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32Mul256Opcode; +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, StepExecutorE1, StepExecutorE2, VmAirWrapper, + VmSegmentState, +}; +use openvm_circuit_derive::{TraceFiller, TraceStep}; +use openvm_circuit_primitives::range_tuple::SharedRangeTupleCheckerChip; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::{Rv32HeapAdapterAir, Rv32HeapAdapterStep}; +use openvm_rv32im_circuit::{MultiplicationCoreAir, MultiplicationStep}; +use openvm_rv32im_transpiler::MulOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{INT256_NUM_LIMBS, RV32_CELL_BITS}; + +/// Multiplication256 +pub type Rv32Multiplication256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + MultiplicationCoreAir, +>; +#[derive(TraceStep, TraceFiller)] +pub struct Rv32Multiplication256Step(BaseStep); +pub type Rv32Multiplication256Chip = + NewVmChipWrapper>; + +type BaseStep = MultiplicationStep; +type AdapterStep = Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; + +impl Rv32Multiplication256Step { + pub fn new( + adapter: AdapterStep, + range_tuple_chip: SharedRangeTupleCheckerChip<2>, + offset: usize, + ) -> Self { + Self(BaseStep::new(adapter, range_tuple_chip, offset)) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MultPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl StepExecutorE1 for Rv32Multiplication256Step { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut MultPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl) + } +} + +impl StepExecutorE2 for Rv32Multiplication256Step { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MultPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let rd_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let rd = u256_mul(rs1, rs2); + vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &MultPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl(&pre_compute.data, vm_state); +} + +impl Rv32Multiplication256Step { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut MultPreCompute, + ) -> openvm_circuit::arch::Result<()> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + let local_opcode = + MulOpcode::from_usize(opcode.local_opcode_idx(Rv32Mul256Opcode::CLASS_OFFSET)); + assert_eq!(local_opcode, MulOpcode::MUL); + *data = MultPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + Ok(()) + } +} + +#[inline(always)] +pub(crate) fn u256_mul( + rs1: [u8; INT256_NUM_LIMBS], + rs2: [u8; INT256_NUM_LIMBS], +) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u32; 8] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u32; 8] = unsafe { std::mem::transmute(rs2) }; + let mut rd = [0u32; 8]; + for i in 0..8 { + let mut carry = 0u64; + for j in 0..(8 - i) { + let res = rs1_u64[i] as u64 * rs2_u64[j] as u64 + rd[i + j] as u64 + carry; + rd[i + j] = res as u32; + carry = res >> 32; + } + } + unsafe { std::mem::transmute(rd) } +} + +#[cfg(test)] +mod tests { + use alloy_primitives::U256; + use rand::{prelude::StdRng, Rng, SeedableRng}; + + use crate::{mult::u256_mul, INT256_NUM_LIMBS}; + + #[test] + fn test_u256_mul() { + let mut rng = StdRng::from_seed([42; 32]); + for _ in 0..10000 { + let limbs_a: [u64; 4] = rng.gen(); + let limbs_b: [u64; 4] = rng.gen(); + let a = U256::from_limbs(limbs_a); + let b = U256::from_limbs(limbs_b); + let a_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_a) }; + let b_u8: [u8; INT256_NUM_LIMBS] = unsafe { std::mem::transmute(limbs_b) }; + assert_eq!(U256::from_le_bytes(u256_mul(a_u8, b_u8)), a.wrapping_mul(b)); + } + } +} diff --git a/extensions/bigint/circuit/src/shift.rs b/extensions/bigint/circuit/src/shift.rs new file mode 100644 index 0000000000..31f05cf255 --- /dev/null +++ b/extensions/bigint/circuit/src/shift.rs @@ -0,0 +1,289 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_bigint_transpiler::Rv32Shift256Opcode; +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, StepExecutorE1, StepExecutorE2, VmAirWrapper, + VmSegmentState, +}; +use openvm_circuit_derive::{TraceFiller, TraceStep}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, var_range::SharedVariableRangeCheckerChip, +}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_rv32_adapters::{Rv32HeapAdapterAir, Rv32HeapAdapterStep}; +use openvm_rv32im_circuit::{ShiftCoreAir, ShiftStep}; +use openvm_rv32im_transpiler::ShiftOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{INT256_NUM_LIMBS, RV32_CELL_BITS}; + +/// Shift256 +pub type Rv32Shift256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + ShiftCoreAir, +>; +#[derive(TraceStep, TraceFiller)] +pub struct Rv32Shift256Step(BaseStep); +pub type Rv32Shift256Chip = + NewVmChipWrapper>; + +type BaseStep = ShiftStep; +type AdapterStep = Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>; + +impl Rv32Shift256Step { + pub fn new( + adapter: AdapterStep, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker_chip: SharedVariableRangeCheckerChip, + offset: usize, + ) -> Self { + Self(BaseStep::new( + adapter, + bitwise_lookup_chip, + range_checker_chip, + offset, + )) + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct ShiftPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl StepExecutorE1 for Rv32Shift256Step { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut ShiftPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + ShiftOpcode::SLL => execute_e1_impl::<_, _, SllOp>, + ShiftOpcode::SRA => execute_e1_impl::<_, _, SraOp>, + ShiftOpcode::SRL => execute_e1_impl::<_, _, SrlOp>, + }; + Ok(fn_ptr) + } +} + +impl StepExecutorE2 for Rv32Shift256Step { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> openvm_circuit::arch::Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + ShiftOpcode::SLL => execute_e2_impl::<_, _, SllOp>, + ShiftOpcode::SRA => execute_e2_impl::<_, _, SraOp>, + ShiftOpcode::SRL => execute_e2_impl::<_, _, SrlOp>, + }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &ShiftPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let rd_ptr = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs1 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs1_ptr)); + let rs2 = vm_state.vm_read::(RV32_MEMORY_AS, u32::from_le_bytes(rs2_ptr)); + let rd = OP::compute(rs1, rs2); + vm_state.vm_write(RV32_MEMORY_AS, u32::from_le_bytes(rd_ptr), &rd); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &ShiftPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32Shift256Step { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut ShiftPreCompute, + ) -> openvm_circuit::arch::Result { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + *data = ShiftPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + let local_opcode = + ShiftOpcode::from_usize(opcode.local_opcode_idx(Rv32Shift256Opcode::CLASS_OFFSET)); + Ok(local_opcode) + } +} + +trait ShiftOp { + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS]; +} +struct SllOp; +struct SrlOp; +struct SraOp; +impl ShiftOp for SllOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + let mut rd = [0u64; 4]; + // Only use the first 8 bits. + let shift = (rs2_u64[0] & 0xff) as u32; + let index_offset = (shift / u64::BITS) as usize; + let bit_offset = shift % u64::BITS; + let mut carry = 0u64; + for i in index_offset..4 { + let curr = rs1_u64[i - index_offset]; + rd[i] = (curr << bit_offset) + carry; + if bit_offset > 0 { + carry = curr >> (u64::BITS - bit_offset); + } + } + unsafe { std::mem::transmute(rd) } + } +} +impl ShiftOp for SrlOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + // Logical right shift - fill with 0 + shift_right(rs1, rs2, 0) + } +} +impl ShiftOp for SraOp { + #[inline(always)] + fn compute(rs1: [u8; INT256_NUM_LIMBS], rs2: [u8; INT256_NUM_LIMBS]) -> [u8; INT256_NUM_LIMBS] { + // Arithmetic right shift - fill with sign bit + if rs1[INT256_NUM_LIMBS - 1] & 0x80 > 0 { + shift_right(rs1, rs2, u64::MAX) + } else { + shift_right(rs1, rs2, 0) + } + } +} + +#[inline(always)] +fn shift_right( + rs1: [u8; INT256_NUM_LIMBS], + rs2: [u8; INT256_NUM_LIMBS], + init_value: u64, +) -> [u8; INT256_NUM_LIMBS] { + let rs1_u64: [u64; 4] = unsafe { std::mem::transmute(rs1) }; + let rs2_u64: [u64; 4] = unsafe { std::mem::transmute(rs2) }; + let mut rd = [init_value; 4]; + let shift = (rs2_u64[0] & 0xff) as u32; + let index_offset = (shift / u64::BITS) as usize; + let bit_offset = shift % u64::BITS; + let mut carry = if bit_offset > 0 { + init_value << (u64::BITS - bit_offset) + } else { + 0 + }; + for i in (index_offset..4).rev() { + let curr = rs1_u64[i]; + rd[i - index_offset] = (curr >> bit_offset) + carry; + if bit_offset > 0 { + carry = curr << (u64::BITS - bit_offset); + } + } + unsafe { std::mem::transmute(rd) } +} + +#[cfg(test)] +mod tests { + use alloy_primitives::U256; + use rand::{prelude::StdRng, Rng, SeedableRng}; + + use crate::{ + shift::{ShiftOp, SllOp, SraOp, SrlOp}, + INT256_NUM_LIMBS, + }; + + #[test] + fn test_shift_op() { + let mut rng = StdRng::from_seed([42; 32]); + for _ in 0..10000 { + let limbs_a: [u8; INT256_NUM_LIMBS] = rng.gen(); + let mut limbs_b: [u8; INT256_NUM_LIMBS] = [0; INT256_NUM_LIMBS]; + let shift: u8 = rng.gen(); + limbs_b[0] = shift; + let a = U256::from_le_bytes(limbs_a); + { + let res = SllOp::compute(limbs_a, limbs_b); + assert_eq!(U256::from_le_bytes(res), a << shift); + } + { + let res = SraOp::compute(limbs_a, limbs_b); + assert_eq!(U256::from_le_bytes(res), a.arithmetic_shr(shift as usize)); + } + { + let res = SrlOp::compute(limbs_a, limbs_b); + assert_eq!(U256::from_le_bytes(res), a >> shift); + } + } + } +} diff --git a/extensions/bigint/circuit/src/tests.rs b/extensions/bigint/circuit/src/tests.rs index 0e26352410..a12b98caf9 100644 --- a/extensions/bigint/circuit/src/tests.rs +++ b/extensions/bigint/circuit/src/tests.rs @@ -5,7 +5,7 @@ use openvm_bigint_transpiler::{ use openvm_circuit::{ arch::{ testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS}, - InstructionExecutor, + InsExecutorE1, InstructionExecutor, VmAirWrapper, }, utils::generate_long_number, }; @@ -13,171 +13,173 @@ use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; -use openvm_instructions::{program::PC_BITS, riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_instructions::{ + program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_CELL_BITS, + LocalOpcode, +}; use openvm_rv32_adapters::{ - rv32_heap_branch_default, rv32_write_heap_default, Rv32HeapAdapterChip, - Rv32HeapBranchAdapterChip, + rv32_heap_branch_default, rv32_write_heap_default, Rv32HeapAdapterAir, Rv32HeapAdapterStep, + Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterStep, }; use openvm_rv32im_circuit::{ adapters::{INT256_NUM_LIMBS, RV_B_TYPE_IMM_BITS}, - BaseAluCoreChip, BranchEqualCoreChip, BranchLessThanCoreChip, LessThanCoreChip, - MultiplicationCoreChip, ShiftCoreChip, + BaseAluCoreAir, BranchEqualCoreAir, BranchLessThanCoreAir, LessThanCoreAir, + MultiplicationCoreAir, ShiftCoreAir, }; use openvm_rv32im_transpiler::{ - BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, ShiftOpcode, + BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, MulOpcode, ShiftOpcode, }; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{ - Rv32BaseAlu256Chip, Rv32BranchEqual256Chip, Rv32BranchLessThan256Chip, Rv32LessThan256Chip, - Rv32Multiplication256Chip, Rv32Shift256Chip, +use crate::{ + Rv32BaseAlu256Chip, Rv32BaseAlu256Step, Rv32BranchEqual256Chip, Rv32BranchEqual256Step, + Rv32BranchLessThan256Chip, Rv32BranchLessThan256Step, Rv32LessThan256Chip, Rv32LessThan256Step, + Rv32Multiplication256Chip, Rv32Multiplication256Step, Rv32Shift256Chip, Rv32Shift256Step, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); #[allow(clippy::type_complexity)] -fn run_int_256_rand_execute>( - opcode: usize, - num_ops: usize, - executor: &mut E, +fn set_and_execute_rand>( tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: usize, branch_fn: Option bool>, ) { - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); - - let mut rng = create_seeded_rng(); let branch = branch_fn.is_some(); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - if branch { - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - let instruction = rv32_heap_branch_default( - tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - imm as isize, - opcode, - ); - - tester.execute_with_pc( - executor, - &instruction, - rng.gen_range((ABS_MAX_BRANCH as u32)..(1 << (PC_BITS - 1))), - ); - - let cmp_result = branch_fn.unwrap()(opcode, &b, &c); - let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; - let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; - assert_eq!(to_pc, from_pc + if cmp_result { imm } else { 4 }); - } else { - let instruction = rv32_write_heap_default( - tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - opcode, - ); - tester.execute(executor, &instruction); - } + let b = generate_long_number::(rng); + let c = generate_long_number::(rng); + if branch { + let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); + let instruction = rv32_heap_branch_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + imm as isize, + opcode, + ); + + tester.execute_with_pc( + chip, + &instruction, + rng.gen_range((ABS_MAX_BRANCH as u32)..(1 << (PC_BITS - 1))), + ); + + let cmp_result = branch_fn.unwrap()(opcode, &b, &c); + let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; + let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; + assert_eq!(to_pc, from_pc + if cmp_result { imm } else { 4 }); + } else { + let instruction = rv32_write_heap_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + opcode, + ); + tester.execute(chip, &instruction); } } +#[test_case(BaseAluOpcode::ADD, 24)] +#[test_case(BaseAluOpcode::SUB, 24)] +#[test_case(BaseAluOpcode::XOR, 24)] +#[test_case(BaseAluOpcode::OR, 24)] +#[test_case(BaseAluOpcode::AND, 24)] fn run_alu_256_rand_test(opcode: BaseAluOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BaseAlu256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAlu256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BaseAluCoreAir::new(bitwise_bus, offset), + ), + Rv32BaseAlu256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), + offset, ), - BaseAluCoreChip::new(bitwise_chip.clone(), Rv32BaseAlu256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + chip.set_trace_height(MAX_INS_CAPACITY); - run_int_256_rand_execute( - opcode.local_usize() + Rv32BaseAlu256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn alu_256_add_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::ADD, 24); -} - -#[test] -fn alu_256_sub_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::SUB, 24); -} - -#[test] -fn alu_256_xor_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::XOR, 24); -} - -#[test] -fn alu_256_or_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::OR, 24); -} - -#[test] -fn alu_256_and_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::AND, 24); -} - +#[test_case(LessThanOpcode::SLT, 24)] +#[test_case(LessThanOpcode::SLTU, 24)] fn run_lt_256_rand_test(opcode: LessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32LessThan256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32LessThan256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + LessThanCoreAir::new(bitwise_bus, offset), + ), + Rv32LessThan256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), + offset, ), - LessThanCoreChip::new(bitwise_chip.clone(), Rv32LessThan256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + chip.set_trace_height(MAX_INS_CAPACITY); - run_int_256_rand_execute( - opcode.local_usize() + Rv32LessThan256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn lt_256_slt_rand_test() { - run_lt_256_rand_test(LessThanOpcode::SLT, 24); -} - -#[test] -fn lt_256_sltu_rand_test() { - run_lt_256_rand_test(LessThanOpcode::SLTU, 24); -} +#[test_case(MulOpcode::MUL, 24)] +fn run_mul_256_rand_test(opcode: MulOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32Mul256Opcode::CLASS_OFFSET; -fn run_mul_256_rand_test(num_ops: usize) { let range_tuple_bus = RangeTupleCheckerBus::new( RANGE_TUPLE_CHECKER_BUS, [ @@ -185,106 +187,121 @@ fn run_mul_256_rand_test(num_ops: usize) { (INT256_NUM_LIMBS * (1 << RV32_CELL_BITS)) as u32, ], ); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); + let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32Multiplication256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + MultiplicationCoreAir::new(range_tuple_bus, offset), + ), + Rv32Multiplication256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), + range_tuple_chip.clone(), + offset, ), - MultiplicationCoreChip::new(range_tuple_checker.clone(), Rv32Mul256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + chip.set_trace_height(MAX_INS_CAPACITY); - run_int_256_rand_execute( - Rv32Mul256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester .build() .load(chip) - .load(range_tuple_checker) + .load(range_tuple_chip) .load(bitwise_chip) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn mul_256_rand_test() { - run_mul_256_rand_test(24); -} - +#[test_case(ShiftOpcode::SLL, 24)] +#[test_case(ShiftOpcode::SRL, 24)] +#[test_case(ShiftOpcode::SRA, 24)] fn run_shift_256_rand_test(opcode: ShiftOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32Shift256Opcode::CLASS_OFFSET; + + let range_checker_chip = tester.range_checker(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32Shift256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + ShiftCoreAir::new(bitwise_bus, range_checker_chip.bus(), offset), ), - ShiftCoreChip::new( + Rv32Shift256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), - tester.memory_controller().borrow().range_checker.clone(), - Rv32Shift256Opcode::CLASS_OFFSET, + range_checker_chip.clone(), + offset, ), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + chip.set_trace_height(MAX_INS_CAPACITY); - run_int_256_rand_execute( - opcode.local_usize() + Rv32Shift256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } + + drop(range_checker_chip); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn shift_256_sll_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SLL, 24); -} - -#[test] -fn shift_256_srl_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SRL, 24); -} - -#[test] -fn shift_256_sra_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SRA, 24); -} - +#[test_case(BranchEqualOpcode::BEQ, 24)] +#[test_case(BranchEqualOpcode::BNE, 24)] fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BranchEqual256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut chip = Rv32BranchEqual256Chip::::new( - Rv32HeapBranchAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BranchEqualCoreAir::new(offset, DEFAULT_PC_STEP), + ), + Rv32BranchEqual256Step::new( + Rv32HeapBranchAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), + offset, + DEFAULT_PC_STEP, ), - BranchEqualCoreChip::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + chip.set_trace_height(MAX_INS_CAPACITY); let branch_fn = |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| { x.iter() @@ -294,93 +311,80 @@ fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { == BranchEqualOpcode::BNE.local_usize() + Rv32BranchEqual256Opcode::CLASS_OFFSET) }; - run_int_256_rand_execute( - opcode.local_usize() + Rv32BranchEqual256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - Some(branch_fn), - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + Some(branch_fn), + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn beq_256_beq_rand_test() { - run_beq_256_rand_test(BranchEqualOpcode::BEQ, 24); -} - -#[test] -fn beq_256_bne_rand_test() { - run_beq_256_rand_test(BranchEqualOpcode::BNE, 24); -} - +#[test_case(BranchLessThanOpcode::BLT, 24)] +#[test_case(BranchLessThanOpcode::BLTU, 24)] +#[test_case(BranchLessThanOpcode::BGE, 24)] +#[test_case(BranchLessThanOpcode::BGEU, 24)] fn run_blt_256_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BranchLessThan256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32BranchLessThan256Chip::::new( - Rv32HeapBranchAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BranchLessThanCoreAir::new(bitwise_bus, offset), ), - BranchLessThanCoreChip::new( + Rv32BranchLessThan256Step::new( + Rv32HeapBranchAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), - Rv32BranchLessThan256Opcode::CLASS_OFFSET, + offset, ), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + chip.set_trace_height(MAX_INS_CAPACITY); - let branch_fn = |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| { - let opcode = - BranchLessThanOpcode::from_usize(opcode - Rv32BranchLessThan256Opcode::CLASS_OFFSET); - let (is_ge, is_signed) = match opcode { - BranchLessThanOpcode::BLT => (false, true), - BranchLessThanOpcode::BLTU => (false, false), - BranchLessThanOpcode::BGE => (true, true), - BranchLessThanOpcode::BGEU => (true, false), - }; - let x_sign = x[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; - let y_sign = y[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; - for (x, y) in x.iter().rev().zip(y.iter().rev()) { - if x != y { - return (x < y) ^ x_sign ^ y_sign ^ is_ge; + let branch_fn = + |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| -> bool { + let opcode = BranchLessThanOpcode::from_usize( + opcode - Rv32BranchLessThan256Opcode::CLASS_OFFSET, + ); + let (is_ge, is_signed) = match opcode { + BranchLessThanOpcode::BLT => (false, true), + BranchLessThanOpcode::BLTU => (false, false), + BranchLessThanOpcode::BGE => (true, true), + BranchLessThanOpcode::BGEU => (true, false), + }; + let x_sign = x[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; + let y_sign = y[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; + for (x, y) in x.iter().rev().zip(y.iter().rev()) { + if x != y { + return (x < y) ^ x_sign ^ y_sign ^ is_ge; + } } - } - is_ge - }; + is_ge + }; - run_int_256_rand_execute( - opcode.local_usize() + Rv32BranchLessThan256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - Some(branch_fn), - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + Some(branch_fn), + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } - -#[test] -fn blt_256_blt_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BLT, 24); -} - -#[test] -fn blt_256_bltu_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BLTU, 24); -} - -#[test] -fn blt_256_bge_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BGE, 24); -} - -#[test] -fn blt_256_bgeu_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BGEU, 24); -} diff --git a/extensions/ecc/circuit/Cargo.toml b/extensions/ecc/circuit/Cargo.toml index dca4fb91e9..81798e207a 100644 --- a/extensions/ecc/circuit/Cargo.toml +++ b/extensions/ecc/circuit/Cargo.toml @@ -26,6 +26,7 @@ strum = { workspace = true } derive_more = { workspace = true } derive-new = { workspace = true } once_cell = { workspace = true, features = ["std"] } +rand = { workspace = true } serde = { workspace = true } serde_with = { workspace = true } lazy_static = { workspace = true } @@ -37,3 +38,6 @@ openvm-mod-circuit-builder = { workspace = true, features = ["test-utils"] } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } lazy_static = { workspace = true } + +[package.metadata.cargo-shear] +ignored = ["rand"] diff --git a/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs b/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs index 24bcc52ef3..f45766c6a4 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs @@ -1,7 +1,24 @@ use std::{cell::RefCell, rc::Rc}; -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, +}; +use openvm_ecc_transpiler::Rv32WeierstrassOpcode; +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_mod_circuit_builder::{ + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, +}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{WeierstrassAir, WeierstrassChip, WeierstrassStep}; // Assumes that (x1, y1), (x2, y2) both lie on the curve and are not the identity point. // Further assumes that x1, x2 are not equal in the coordinate field. @@ -26,3 +43,58 @@ pub fn ec_add_ne_expr( let builder = builder.borrow().clone(); FieldExpr::new(builder, range_bus, true) } + +/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. +/// BLOCKS: how many blocks do we need to represent one input or output +/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per +/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. + +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] +pub struct EcAddNeChip( + pub WeierstrassChip, +); + +impl + EcAddNeChip +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker: SharedVariableRangeCheckerChip, + ) -> Self { + let expr = ec_add_ne_expr(config, range_checker.bus()); + + let local_opcode_idx = vec![ + Rv32WeierstrassOpcode::EC_ADD_NE as usize, + Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, + ]; + + let air = WeierstrassAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ); + + let step = WeierstrassStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + vec![], + range_checker, + "EcAddNe", + false, + ); + Self(WeierstrassChip::new(air, step, mem_helper)) + } +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/double.rs b/extensions/ecc/circuit/src/weierstrass_chip/double.rs index 0ae55f2df7..b804ba8931 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/double.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/double.rs @@ -2,8 +2,25 @@ use std::{cell::RefCell, rc::Rc}; use num_bigint::BigUint; use num_traits::One; -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr, FieldVariable}; +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InsExecutorE2, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, +}; +use openvm_ecc_transpiler::Rv32WeierstrassOpcode; +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_mod_circuit_builder::{ + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldVariable, +}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{WeierstrassAir, WeierstrassChip, WeierstrassStep}; pub fn ec_double_ne_expr( config: ExprBuilderConfig, // The coordinate field. @@ -34,3 +51,59 @@ pub fn ec_double_ne_expr( let builder = builder.borrow().clone(); FieldExpr::new_with_setup_values(builder, range_bus, true, vec![a_biguint]) } + +/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. +/// BLOCKS: how many blocks do we need to represent one input or output +/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per +/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. + +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1, InsExecutorE2)] +pub struct EcDoubleChip( + pub WeierstrassChip, +); + +impl + EcDoubleChip +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker: SharedVariableRangeCheckerChip, + a_biguint: BigUint, + ) -> Self { + let expr = ec_double_ne_expr(config, range_checker.bus(), a_biguint); + + let local_opcode_idx = vec![ + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + ]; + + let air = WeierstrassAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ); + + let step = WeierstrassStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + vec![], + range_checker, + "EcDouble", + true, + ); + Self(WeierstrassChip::new(air, step, mem_helper)) + } +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs index 0bcee1facf..5ed1516934 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs @@ -1,99 +1,40 @@ mod add_ne; mod double; -use std::sync::Arc; - pub use add_ne::*; pub use double::*; +use openvm_algebra_circuit::FieldExprVecHeapStep; #[cfg(test)] mod tests; -use std::sync::Mutex; - -use num_bigint::BigUint; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::SharedVariableRangeCheckerChip; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_ecc_transpiler::Rv32WeierstrassOpcode; -use openvm_mod_circuit_builder::{ExprBuilderConfig, FieldExpressionCoreChip}; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. -/// BLOCKS: how many blocks do we need to represent one input or output -/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per -/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcAddNeChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - EcAddNeChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = ec_add_ne_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Rv32WeierstrassOpcode::EC_ADD_NE as usize, - Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, - ], - vec![], - range_checker, - "EcAddNe", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcDoubleChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - EcDoubleChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - offset: usize, - a: BigUint, - offline_memory: Arc>>, - ) -> Self { - let expr = ec_double_ne_expr(config, range_checker.bus(), a); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Rv32WeierstrassOpcode::EC_DOUBLE as usize, - Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, - ], - vec![], - range_checker, - "EcDouble", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; +use openvm_mod_circuit_builder::FieldExpressionCoreAir; +use openvm_rv32_adapters::Rv32VecHeapAdapterAir; + +pub(crate) type WeierstrassAir< + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +> = VmAirWrapper< + Rv32VecHeapAdapterAir, + FieldExpressionCoreAir, +>; + +pub(crate) type WeierstrassStep< + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +> = FieldExprVecHeapStep; + +pub(crate) type WeierstrassChip< + F, + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +> = NewVmChipWrapper< + F, + WeierstrassAir, + WeierstrassStep, + MatrixRecordArena, +>; diff --git a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs index 213918ec2e..99051550c8 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs @@ -10,7 +10,7 @@ use openvm_circuit_primitives::{ use openvm_ecc_transpiler::Rv32WeierstrassOpcode; use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; use openvm_mod_circuit_builder::{test_utils::biguint_to_limbs, ExprBuilderConfig, FieldExpr}; -use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; +use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -19,6 +19,7 @@ use super::{EcAddNeChip, EcDoubleChip}; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; const BLOCK_SIZE: usize = 32; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; lazy_static::lazy_static! { @@ -87,21 +88,20 @@ fn test_add_ne() { }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + + let mut chip = EcAddNeChip::::new( + tester.execution_bridge(), tester.memory_bridge(), + tester.memory_helper(), tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcAddNeChip::new( - adapter, config, Rv32WeierstrassOpcode::CLASS_OFFSET, + bitwise_chip.clone(), tester.range_checker(), - tester.offline_memory_mutex_arc(), ); - assert_eq!(chip.0.core.expr().builder.num_variables, 3); // lambda, x3, y3 + chip.0.set_trace_buffer_height(MAX_INS_CAPACITY); + + assert_eq!(chip.0.step.0.expr.builder.num_variables, 3); // lambda, x3, y3 let (p1_x, p1_y) = SampleEcPoints[0].clone(); let (p2_x, p2_y) = SampleEcPoints[1].clone(); @@ -117,21 +117,22 @@ fn test_add_ne() { let r = chip .0 - .core - .expr() + .step + .0 + .expr .execute(vec![p1_x, p1_y, p2_x, p2_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 assert_eq!(r[1], SampleEcPoints[2].0); assert_eq!(r[2], SampleEcPoints[2].1); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(chip.0.core.expr()).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.step.0.expr).try_into().unwrap(); let mut one_limbs = [BabyBear::ONE; NUM_LIMBS]; one_limbs[0] = BabyBear::ONE; let setup_instruction = rv32_write_heap_default( &mut tester, vec![prime_limbs, one_limbs], // inputs[0] = prime, others doesn't matter vec![one_limbs, one_limbs], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -139,7 +140,7 @@ fn test_add_ne() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![p2_x_limbs, p2_y_limbs], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_ADD_NE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::EC_ADD_NE as usize, ); tester.execute(&mut chip, &instruction); @@ -159,13 +160,18 @@ fn test_double() { }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let mut chip = EcDoubleChip::::new( + tester.execution_bridge(), tester.memory_bridge(), + tester.memory_helper(), tester.address_bits(), + config, + Rv32WeierstrassOpcode::CLASS_OFFSET, bitwise_chip.clone(), + tester.range_checker(), + BigUint::zero(), ); + chip.0.set_trace_buffer_height(MAX_INS_CAPACITY); let (p1_x, p1_y) = SampleEcPoints[1].clone(); let p1_x_limbs = @@ -173,29 +179,21 @@ fn test_double() { let p1_y_limbs = biguint_to_limbs::(p1_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); - let mut chip = EcDoubleChip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - config, - Rv32WeierstrassOpcode::CLASS_OFFSET, - BigUint::zero(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!(chip.0.core.air.expr.builder.num_variables, 3); // lambda, x3, y3 + assert_eq!(chip.0.step.0.expr.builder.num_variables, 3); // lambda, x3, y3 - let r = chip.0.core.air.expr.execute(vec![p1_x, p1_y], vec![true]); + let r = chip.0.step.0.expr.execute(vec![p1_x, p1_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 assert_eq!(r[1], SampleEcPoints[3].0); assert_eq!(r[2], SampleEcPoints[3].1); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.core.air.expr).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.step.0.expr).try_into().unwrap(); let a_limbs = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( &mut tester, vec![prime_limbs, a_limbs], /* inputs[0] = prime, inputs[1] = a coeff of weierstrass * equation */ vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -203,7 +201,7 @@ fn test_double() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, ); tester.execute(&mut chip, &instruction); @@ -227,13 +225,19 @@ fn test_p256_double() { .unwrap(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + + let mut chip = EcDoubleChip::::new( + tester.execution_bridge(), tester.memory_bridge(), + tester.memory_helper(), tester.address_bits(), + config, + Rv32WeierstrassOpcode::CLASS_OFFSET, bitwise_chip.clone(), + tester.range_checker(), + a.clone(), ); + chip.0.set_trace_buffer_height(MAX_INS_CAPACITY); // Testing data from: http://point-at-infinity.org/ecc/nisttv let p1_x = BigUint::from_str_radix( @@ -251,17 +255,9 @@ fn test_p256_double() { let p1_y_limbs = biguint_to_limbs::(p1_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); - let mut chip = EcDoubleChip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - config, - Rv32WeierstrassOpcode::CLASS_OFFSET, - a.clone(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!(chip.0.core.air.expr.builder.num_variables, 3); // lambda, x3, y3 + assert_eq!(chip.0.step.0.expr.builder.num_variables, 3); // lambda, x3, y3 - let r = chip.0.core.air.expr.execute(vec![p1_x, p1_y], vec![true]); + let r = chip.0.step.0.expr.execute(vec![p1_x, p1_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 let expected_double_x = BigUint::from_str_radix( "7CF27B188D034F7E8A52380304B51AC3C08969E277F21B35A60B48FC47669978", @@ -276,7 +272,7 @@ fn test_p256_double() { assert_eq!(r[1], expected_double_x); assert_eq!(r[2], expected_double_y); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.core.air.expr).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.step.0.expr).try_into().unwrap(); let a_limbs = biguint_to_limbs::(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); let setup_instruction = rv32_write_heap_default( @@ -284,7 +280,7 @@ fn test_p256_double() { vec![prime_limbs, a_limbs], /* inputs[0] = prime, inputs[1] = a coeff of weierstrass * equation */ vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -292,9 +288,12 @@ fn test_p256_double() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + chip.0.step.0.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, ); + tester.execute(&mut chip, &instruction); + // Adding another row to make sure there are dummy rows, and that the dummy row constraints are + // satisfied tester.execute(&mut chip, &instruction); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); diff --git a/extensions/ecc/circuit/src/weierstrass_extension.rs b/extensions/ecc/circuit/src/weierstrass_extension.rs index f0ec35e688..a0a41fbe18 100644 --- a/extensions/ecc/circuit/src/weierstrass_extension.rs +++ b/extensions/ecc/circuit/src/weierstrass_extension.rs @@ -5,10 +5,12 @@ use num_bigint::BigUint; use num_traits::{FromPrimitive, Zero}; use once_cell::sync::Lazy; use openvm_circuit::{ - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, + arch::{ + ExecutionBridge, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; @@ -16,7 +18,6 @@ use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_ecc_transpiler::Rv32WeierstrassOpcode; use openvm_instructions::{LocalOpcode, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; @@ -24,6 +25,8 @@ use strum::EnumCount; use super::{EcAddNeChip, EcDoubleChip}; +// TODO: this should be decided after e2 execution + #[serde_as] #[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] pub struct CurveConfig { @@ -77,7 +80,7 @@ impl WeierstrassExtension { } } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, InsExecutorE1, InsExecutorE2)] pub enum WeierstrassExtensionExecutor { // 32 limbs prime EcAddNeRv32_32(EcAddNeChip), @@ -107,6 +110,11 @@ impl VmExtension for WeierstrassExtension { program_bus, memory_bridge, } = builder.system_port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = builder.system_base().range_checker_chip.clone(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -118,9 +126,7 @@ impl VmExtension for WeierstrassExtension { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - let pointer_bits = builder.system_config().memory_config.pointer_max_bits; + let ec_add_ne_opcodes = (Rv32WeierstrassOpcode::EC_ADD_NE as usize) ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize); let ec_double_opcodes = (Rv32WeierstrassOpcode::EC_DOUBLE as usize) @@ -142,18 +148,16 @@ impl VmExtension for WeierstrassExtension { }; if bytes <= 32 { let add_ne_chip = EcAddNeChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( WeierstrassExtensionExecutor::EcAddNeRv32_32(add_ne_chip), ec_add_ne_opcodes @@ -161,18 +165,15 @@ impl VmExtension for WeierstrassExtension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let double_chip = EcDoubleChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), - range_checker.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), curve.a.clone(), - offline_memory.clone(), ); inventory.add_executor( WeierstrassExtensionExecutor::EcDoubleRv32_32(double_chip), @@ -182,18 +183,16 @@ impl VmExtension for WeierstrassExtension { )?; } else if bytes <= 48 { let add_ne_chip = EcAddNeChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), ); + inventory.add_executor( WeierstrassExtensionExecutor::EcAddNeRv32_48(add_ne_chip), ec_add_ne_opcodes @@ -201,18 +200,15 @@ impl VmExtension for WeierstrassExtension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let double_chip = EcDoubleChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), - range_checker.clone(), + execution_bridge, + memory_bridge, + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), curve.a.clone(), - offline_memory.clone(), ); inventory.add_executor( WeierstrassExtensionExecutor::EcDoubleRv32_48(double_chip), diff --git a/extensions/ecc/tests/src/lib.rs b/extensions/ecc/tests/src/lib.rs index 1bd01eb936..b9ae366d82 100644 --- a/extensions/ecc/tests/src/lib.rs +++ b/extensions/ecc/tests/src/lib.rs @@ -11,7 +11,7 @@ mod tests { use openvm_algebra_transpiler::ModularTranspilerExtension; use openvm_circuit::{ arch::instructions::exe::VmExe, - utils::{air_test, air_test_with_min_segments}, + utils::{air_test, air_test_with_min_segments, test_system_config_with_continuations}, }; use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, P256_CONFIG, SECP256K1_CONFIG}; use openvm_ecc_transpiler::EccTranspilerExtension; @@ -35,9 +35,16 @@ mod tests { type F = BabyBear; + #[cfg(test)] + fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { + let mut config = Rv32WeierstrassConfig::new(curves); + config.system = test_system_config_with_continuations(); + config + } + #[test] fn test_ec() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec", @@ -59,7 +66,7 @@ mod tests { #[test] fn test_ec_nonzero_a() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec_nonzero_a", @@ -82,7 +89,7 @@ mod tests { #[test] fn test_ec_two_curves() -> Result<()> { let config = - Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()]); + test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()]); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "ec_two_curves", @@ -106,8 +113,7 @@ mod tests { fn test_decompress() -> Result<()> { use halo2curves_axiom::{group::Curve, secp256k1::Secp256k1Affine}; - let config = - Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone(), + let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone(), CurveConfig { struct_name: "CurvePoint5mod8".to_string(), modulus: BigUint::from_str("115792089237316195423570985008687907853269984665640564039457584007913129639501") @@ -261,7 +267,7 @@ mod tests { ) .unwrap(); let config = - Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()]); + test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone(), P256_CONFIG.clone()]); air_test(config, openvm_exe); } } diff --git a/extensions/keccak256/circuit/Cargo.toml b/extensions/keccak256/circuit/Cargo.toml index 941303ab39..2299a0599a 100644 --- a/extensions/keccak256/circuit/Cargo.toml +++ b/extensions/keccak256/circuit/Cargo.toml @@ -23,12 +23,10 @@ p3-keccak-air = { workspace = true } strum.workspace = true tiny-keccak.workspace = true itertools.workspace = true -tracing.workspace = true derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true serde.workspace = true -serde-big-array.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/extensions/keccak256/circuit/src/extension.rs b/extensions/keccak256/circuit/src/extension.rs index 5993f69eda..eddf3d8a0c 100644 --- a/extensions/keccak256/circuit/src/extension.rs +++ b/extensions/keccak256/circuit/src/extension.rs @@ -1,3 +1,5 @@ +use std::result::Result; + use derive_more::derive::From; use openvm_circuit::{ arch::{ @@ -6,7 +8,7 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupBus; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::*; @@ -20,6 +22,8 @@ use strum::IntoEnumIterator; use crate::*; +// TODO: this should be decided after e2 execution + #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] pub struct Keccak256Rv32Config { #[system] @@ -52,7 +56,9 @@ impl InitFileGenerator for Keccak256Rv32Config {} #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Keccak256; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, InsExecutorE1, InsExecutorE2, +)] pub enum Keccak256Executor { Keccak256(KeccakVmChip), } @@ -72,11 +78,8 @@ impl VmExtension for Keccak256 { builder: &mut VmInventoryBuilder, ) -> Result, VmInventoryError> { let mut inventory = VmInventory::new(); - let SystemPort { - execution_bus, - program_bus, - memory_bridge, - } = builder.system_port(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -88,17 +91,26 @@ impl VmExtension for Keccak256 { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; - let keccak_chip = KeccakVmChip::new( + let SystemPort { execution_bus, program_bus, memory_bridge, - address_bits, - bitwise_lu_chip, - Rv32KeccakOpcode::CLASS_OFFSET, - offline_memory, + } = builder.system_port(); + let keccak_chip = KeccakVmChip::new( + KeccakVmAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + Rv32KeccakOpcode::CLASS_OFFSET, + ), + KeccakVmStep::new( + bitwise_lu_chip.clone(), + Rv32KeccakOpcode::CLASS_OFFSET, + pointer_max_bits, + ), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( keccak_chip, diff --git a/extensions/keccak256/circuit/src/lib.rs b/extensions/keccak256/circuit/src/lib.rs index c9fd1c9f5a..a422535310 100644 --- a/extensions/keccak256/circuit/src/lib.rs +++ b/extensions/keccak256/circuit/src/lib.rs @@ -1,17 +1,9 @@ //! Stateful keccak256 hasher. Handles full keccak sponge (padding, absorb, keccak-f) on //! variable length inputs read from VM memory. -use std::{ - array::from_fn, - cmp::min, - sync::{Arc, Mutex}, -}; use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; use openvm_stark_backend::p3_field::PrimeField32; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; -use tiny_keccak::{Hasher, Keccak}; -use utils::num_keccak_f; +use p3_keccak_air::NUM_ROUNDS; pub mod air; pub mod columns; @@ -24,19 +16,25 @@ pub use extension::*; #[cfg(test)] mod tests; +use std::borrow::{Borrow, BorrowMut}; + pub use air::KeccakVmAir; -use openvm_circuit::{ - arch::{ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor}, - system::{ - memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory, RecordId}, - program::ProgramBus, - }, +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, ExecutionBridge, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, Result, StepExecutorE1, StepExecutorE2, VmSegmentState, }; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; use openvm_instructions::{ - instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, }; use openvm_keccak256_transpiler::Rv32KeccakOpcode; -use openvm_rv32im_circuit::adapters::read_rv32_register; + +use crate::utils::{keccak256, num_keccak_f}; // ==== Constants for register/memory adapter ==== /// Register reads to get dst, src, len @@ -69,218 +67,163 @@ pub const KECCAK_DIGEST_BYTES: usize = 32; /// Number of 64-bit digest limbs. pub const KECCAK_DIGEST_U64S: usize = KECCAK_DIGEST_BYTES / 8; -pub struct KeccakVmChip { - pub air: KeccakVmAir, - /// IO and memory data necessary for each opcode call - pub records: Vec>, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - - offset: usize, - - offline_memory: Arc>>, -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct KeccakRecord { - pub pc: F, - pub dst_read: RecordId, - pub src_read: RecordId, - pub len_read: RecordId, - pub input_blocks: Vec, - pub digest_writes: [RecordId; KECCAK_DIGEST_WRITES], -} +pub type KeccakVmChip = NewVmChipWrapper>; -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct KeccakInputBlock { - /// Memory reads for non-padding bytes in this block. Length is at most [KECCAK_RATE_BYTES / - /// KECCAK_WORD_SIZE]. - pub reads: Vec, - /// Index in `reads` of the memory read for < KECCAK_WORD_SIZE bytes, if any. - pub partial_read_idx: Option, - /// Bytes with padding. Can be derived from `bytes_read` but we store for convenience. - #[serde(with = "BigArray")] - pub padded_bytes: [u8; KECCAK_RATE_BYTES], - pub remaining_len: usize, - pub src: usize, - pub is_new_start: bool, +//#[derive(derive_new::new)] +pub struct KeccakVmStep { + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + pub offset: usize, + pub pointer_max_bits: usize, } -impl KeccakVmChip { +impl KeccakVmStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, offset: usize, - offline_memory: Arc>>, + pointer_max_bits: usize, ) -> Self { Self { - air: KeccakVmAir::new( - ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_chip.bus(), - address_bits, - offset, - ), bitwise_lookup_chip, - records: Vec::new(), offset, - offline_memory, + pointer_max_bits, } } } -impl InstructionExecutor for KeccakVmChip { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let &Instruction { - opcode, - a, - b, - c, - d, - e, - .. - } = instruction; - let local_opcode = Rv32KeccakOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - debug_assert_eq!(local_opcode, Rv32KeccakOpcode::KECCAK256); - - let mut timestamp_delta = 3; - let (dst_read, dst) = read_rv32_register(memory, d, a); - let (src_read, src) = read_rv32_register(memory, d, b); - let (len_read, len) = read_rv32_register(memory, d, c); - #[cfg(debug_assertions)] - { - assert!(dst < (1 << self.air.ptr_max_bits)); - assert!(src < (1 << self.air.ptr_max_bits)); - assert!(len < (1 << self.air.ptr_max_bits)); - } - - let mut remaining_len = len as usize; - let num_blocks = num_keccak_f(remaining_len); - let mut input_blocks = Vec::with_capacity(num_blocks); - let mut hasher = Keccak::v256(); - let mut src = src as usize; +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct KeccakPreCompute { + a: u8, + b: u8, + c: u8, +} - for block_idx in 0..num_blocks { - if block_idx != 0 { - memory.increment_timestamp_by(KECCAK_REGISTER_READS as u32); - timestamp_delta += KECCAK_REGISTER_READS as u32; - } - let mut reads = Vec::with_capacity(KECCAK_RATE_BYTES); +impl StepExecutorE1 for KeccakVmStep { + fn pre_compute_size(&self) -> usize { + size_of::() + } - let mut partial_read_idx = None; - let mut bytes = [0u8; KECCAK_RATE_BYTES]; - for i in (0..KECCAK_RATE_BYTES).step_by(KECCAK_WORD_SIZE) { - if i < remaining_len { - let read = - memory.read::(e, F::from_canonical_usize(src + i)); + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut KeccakPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _>) + } +} - let chunk = read.1.map(|x| { - x.as_canonical_u32() - .try_into() - .expect("Memory cell not a byte") - }); - let copy_len = min(KECCAK_WORD_SIZE, remaining_len - i); - if copy_len != KECCAK_WORD_SIZE { - partial_read_idx = Some(reads.len()); - } - bytes[i..i + copy_len].copy_from_slice(&chunk[..copy_len]); - reads.push(read.0); - } else { - memory.increment_timestamp(); - } - timestamp_delta += 1; - } +impl StepExecutorE2 for KeccakVmStep { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } - let mut block = KeccakInputBlock { - reads, - partial_read_idx, - padded_bytes: bytes, - remaining_len, - src, - is_new_start: block_idx == 0, - }; - if block_idx != num_blocks - 1 { - src += KECCAK_RATE_BYTES; - remaining_len -= KECCAK_RATE_BYTES; - hasher.update(&block.padded_bytes); - } else { - // handle padding here since it is convenient - debug_assert!(remaining_len < KECCAK_RATE_BYTES); - hasher.update(&block.padded_bytes[..remaining_len]); + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _>) + } +} - if remaining_len == KECCAK_RATE_BYTES - 1 { - block.padded_bytes[remaining_len] = 0b1000_0001; - } else { - block.padded_bytes[remaining_len] = 0x01; - block.padded_bytes[KECCAK_RATE_BYTES - 1] = 0x80; - } - } - input_blocks.push(block); - } - let mut output = [0u8; 32]; - hasher.finalize(&mut output); - let dst = dst as usize; - let digest_writes: [_; KECCAK_DIGEST_WRITES] = from_fn(|i| { - timestamp_delta += 1; - memory - .write::( - e, - F::from_canonical_usize(dst + i * KECCAK_WORD_SIZE), - from_fn(|j| F::from_canonical_u8(output[i * KECCAK_WORD_SIZE + j])), +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &KeccakPreCompute, + vm_state: &mut VmSegmentState, +) -> u32 { + let dst = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32); + let src = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let len = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let dst_u32 = u32::from_le_bytes(dst); + let src_u32 = u32::from_le_bytes(src); + let len_u32 = u32::from_le_bytes(len); + + let (output, height) = if IS_E1 { + // SAFETY: RV32_MEMORY_AS is memory address space of type u8 + let message = vm_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize); + let output = keccak256(message); + (output, 0) + } else { + let num_reads = (len_u32 as usize).div_ceil(KECCAK_WORD_SIZE); + let message: Vec<_> = (0..num_reads) + .flat_map(|i| { + vm_state.vm_read::( + RV32_MEMORY_AS, + src_u32 + (i * KECCAK_WORD_SIZE) as u32, ) - .0 - }); - tracing::trace!("[runtime] keccak256 output: {:?}", output); - - let record = KeccakRecord { - pc: F::from_canonical_u32(from_state.pc), - dst_read, - src_read, - len_read, - input_blocks, - digest_writes, - }; - - // Add the events to chip state for later trace generation usage - self.records.push(record); - - // NOTE: Check this is consistent with KeccakVmAir::timestamp_change (we don't use it to - // avoid unnecessary conversions here) - let total_timestamp_delta = - len + (KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES) as u32; - memory.increment_timestamp_by(total_timestamp_delta - timestamp_delta); + }) + .collect(); + let output = keccak256(&message[..len_u32 as usize]); + let height = (num_keccak_f(len_u32 as usize) * NUM_ROUNDS) as u32; + (output, height) + }; + vm_state.vm_write(RV32_MEMORY_AS, dst_u32, &output); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + height +} - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: from_state.timestamp + total_timestamp_delta, - }) - } +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &KeccakPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} - fn get_opcode_name(&self, _: usize) -> String { - "KECCAK256".to_string() - } +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); } -impl Default for KeccakInputBlock { - fn default() -> Self { - // Padding for empty byte array so padding constraints still hold - let mut padded_bytes = [0u8; KECCAK_RATE_BYTES]; - padded_bytes[0] = 0x01; - *padded_bytes.last_mut().unwrap() = 0x80; - Self { - padded_bytes, - partial_read_idx: None, - remaining_len: 0, - is_new_start: true, - reads: Vec::new(), - src: 0, +impl KeccakVmStep { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut KeccakPreCompute, + ) -> Result<()> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); } + *data = KeccakPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + assert_eq!(&Rv32KeccakOpcode::KECCAK256.global_opcode(), opcode); + Ok(()) } } diff --git a/extensions/keccak256/circuit/src/tests.rs b/extensions/keccak256/circuit/src/tests.rs index 65a34491b8..206bffefa7 100644 --- a/extensions/keccak256/circuit/src/tests.rs +++ b/extensions/keccak256/circuit/src/tests.rs @@ -1,12 +1,22 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; use hex::FromHex; -use openvm_circuit::arch::testing::{VmChipTestBuilder, VmChipTester, BITWISE_OP_LOOKUP_BUS}; +use openvm_circuit::{ + arch::{ + testing::{memory::gen_pointer, VmChipTestBuilder, VmChipTester, BITWISE_OP_LOOKUP_BUS}, + DenseRecordArena, InstructionExecutor, NewVmChipWrapper, + }, + utils::get_random_message, +}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_keccak256_transpiler::Rv32KeccakOpcode; +use openvm_instructions::{ + instruction::Instruction, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_keccak256_transpiler::Rv32KeccakOpcode::{self, *}; use openvm_stark_backend::{ p3_field::FieldAlgebra, utils::disable_debug_builder, verifier::VerificationError, }; @@ -15,38 +25,107 @@ use openvm_stark_sdk::{ utils::create_seeded_rng, }; use p3_keccak_air::NUM_ROUNDS; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; use tiny_keccak::Hasher; use super::{columns::KeccakVmCols, utils::num_keccak_f, KeccakVmChip}; +use crate::{trace::KeccakVmRecordLayout, utils::keccak256, KeccakVmAir, KeccakVmStep}; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 8192; + +fn create_test_chips( + tester: &mut VmChipTestBuilder, +) -> (KeccakVmChip, SharedBitwiseOperationLookupChip<8>) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::<8>::new(bitwise_bus); + let mut chip = KeccakVmChip::new( + KeccakVmAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + Rv32KeccakOpcode::CLASS_OFFSET, + ), + KeccakVmStep::new( + bitwise_chip.clone(), + Rv32KeccakOpcode::CLASS_OFFSET, + tester.address_bits(), + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + +fn set_and_execute>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: Rv32KeccakOpcode, + message: Option<&[u8]>, + len: Option, +) { + let len = len.unwrap_or(rng.gen_range(1..3000)); + let tmp = get_random_message(rng, len); + let message: &[u8] = message.unwrap_or(&tmp); + + let rd = gen_pointer(rng, 4); + let rs1 = gen_pointer(rng, 4); + let rs2 = gen_pointer(rng, 4); + + let max_mem_ptr: u32 = 1 << tester.address_bits(); + let dst_ptr = rng.gen_range(0..max_mem_ptr); + let dst_ptr = dst_ptr ^ (dst_ptr & 3); + tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); + let src_ptr = rng.gen_range(0..(max_mem_ptr - len as u32)); + let src_ptr = src_ptr ^ (src_ptr & 3); + tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); + + message.chunks(4).enumerate().for_each(|(i, chunk)| { + let chunk: [&u8; 4] = array::from_fn(|i| chunk.get(i).unwrap_or(&0)); + tester.write( + 2, + src_ptr as usize + i * 4, + chunk.map(|&x| F::from_canonical_u8(x)), + ); + }); + + tester.execute( + chip, + &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), + ); + + let output = keccak256(message); + assert_eq!( + output.map(F::from_canonical_u8), + tester.read::<32>(2, dst_ptr as usize) + ); +} + // io is vector of (input, expected_output, prank_output) where prank_output is Some if the trace // will be replaced #[allow(clippy::type_complexity)] fn build_keccak256_test( io: Vec<(Vec, Option<[u8; 32]>, Option<[u8; 32]>)>, ) -> VmChipTester { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::<8>::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = KeccakVmChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - Rv32KeccakOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chips(&mut tester); - let mut dst = 0; - let src = 0; + let max_mem_ptr = 1 << (tester.address_bits() - 3); + let mut dst = rng.gen_range(0..max_mem_ptr) << 2; + let src = rng.gen_range(0..max_mem_ptr) << 2; for (input, expected_output, _) in &io { - let [a, b, c] = [0, 4, 8]; // space apart for register limbs - let [d, e] = [1, 2]; + let [a, b, c] = [ + gen_pointer(&mut rng, 4), + gen_pointer(&mut rng, 4), + gen_pointer(&mut rng, 4), + ]; // space apart for register limbs + let [d, e] = [RV32_REGISTER_AS as usize, RV32_MEMORY_AS as usize]; tester.write(d, a, (dst as u32).to_le_bytes().map(F::from_canonical_u8)); tester.write(d, b, (src as u32).to_le_bytes().map(F::from_canonical_u8)); @@ -55,9 +134,15 @@ fn build_keccak256_test( c, (input.len() as u32).to_le_bytes().map(F::from_canonical_u8), ); - for (i, byte) in input.iter().enumerate() { - tester.write_cell(e, src + i, F::from_canonical_u8(*byte)); - } + + input.chunks(4).enumerate().for_each(|(i, chunk)| { + let chunk: [&u8; 4] = array::from_fn(|i| chunk.get(i).unwrap_or(&0)); + tester.write( + 2, + src as usize + i * 4, + chunk.map(|&x| F::from_canonical_u8(x)), + ); + }); tester.execute( &mut chip, @@ -71,13 +156,15 @@ fn build_keccak256_test( ), ); if let Some(output) = expected_output { - for (i, byte) in output.iter().enumerate() { - assert_eq!(tester.read_cell(e, dst + i), F::from_canonical_u8(*byte)); - } + assert_eq!( + output.map(F::from_canonical_u8), + tester.read::<32>(e, dst as usize) + ); } // shift dst to not deal with timestamps for pranking dst += 32; } + let mut tester = tester.build().load(chip).load(bitwise_chip).finalize(); let keccak_trace = tester.air_proof_inputs[2] @@ -113,21 +200,34 @@ fn build_keccak256_test( tester } +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// #[test] -fn test_keccak256_negative() { +fn rand_keccak256_test() { let mut rng = create_seeded_rng(); - let mut hasher = tiny_keccak::Keccak::v256(); - let input: Vec<_> = vec![0; 137]; - hasher.update(&input); - let mut out = [0u8; 32]; - hasher.finalize(&mut out); - out[0] = rng.gen(); - let tester = build_keccak256_test(vec![(input, None, Some(out))]); - disable_debug_builder(); - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch) + let mut tester = VmChipTestBuilder::default(); + let (mut chip, bitwise_chip) = create_test_chips(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, KECCAK256, None, None); + } + + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + KECCAK256, + None, + Some(10000), ); + + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } // Keccak Known Answer Test (KAT) vectors from https://keccak.team/obsolete/KeccakKAT-3.zip. @@ -152,7 +252,98 @@ fn test_keccak256_positive_kat_vectors() { let output = Vec::from_hex(output).unwrap(); io.push((input, Some(output.try_into().unwrap()), None)); } - let tester = build_keccak256_test(io); + + tester.simple_test().expect("Verification failed"); +} + +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// +#[test] +fn test_keccak256_negative() { + let mut rng = create_seeded_rng(); + let mut hasher = tiny_keccak::Keccak::v256(); + let input: Vec<_> = vec![0; 137]; + hasher.update(&input); + let mut out = [0u8; 32]; + hasher.finalize(&mut out); + out[0] = rng.gen(); + let tester = build_keccak256_test(vec![(input, None, Some(out))]); + disable_debug_builder(); + assert_eq!( + tester.simple_test().err(), + Some(VerificationError::OodEvaluationMismatch) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// DENSE TESTS +/// +/// Ensure that the chip works as expected with dense records. +/// We first execute some instructions with a [DenseRecordArena] and transfer the records +/// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +/// all the constraints pass. +/////////////////////////////////////////////////////////////////////////////////////// +type KeccakVmChipDense = NewVmChipWrapper; + +fn create_test_chip_dense(tester: &mut VmChipTestBuilder) -> KeccakVmChipDense { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::<8>::new(bitwise_bus); + + let mut chip = KeccakVmChipDense::new( + KeccakVmAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + Rv32KeccakOpcode::CLASS_OFFSET, + ), + KeccakVmStep::new( + bitwise_chip.clone(), + Rv32KeccakOpcode::CLASS_OFFSET, + tester.address_bits(), + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + chip +} + +#[test] +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_chip, bitwise_chip) = create_test_chips(&mut tester); + + { + let mut dense_chip = create_test_chip_dense(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute( + &mut tester, + &mut dense_chip, + &mut rng, + KECCAK256, + None, + None, + ); + } + + let mut record_interpreter = dense_chip + .arena + .get_record_seeker::<_, KeccakVmRecordLayout>(); + record_interpreter.transfer_to_matrix_arena(&mut sparse_chip.arena); + } + + let tester = tester + .build() + .load(sparse_chip) + .load(bitwise_chip) + .finalize(); tester.simple_test().expect("Verification failed"); } diff --git a/extensions/keccak256/circuit/src/trace.rs b/extensions/keccak256/circuit/src/trace.rs index c314c38eac..036ad734ed 100644 --- a/extensions/keccak256/circuit/src/trace.rs +++ b/extensions/keccak256/circuit/src/trace.rs @@ -1,16 +1,33 @@ -use std::{array::from_fn, borrow::BorrowMut, sync::Arc}; +use std::{ + array::{self, from_fn}, + borrow::{Borrow, BorrowMut}, + cmp::min, +}; -use openvm_circuit::system::memory::RecordId; -use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use openvm_circuit::{ + arch::{ + get_record_from_slice, CustomBorrow, MultiRowLayout, MultiRowMetadata, RecordArena, Result, + SizedRecord, TraceFiller, TraceStep, VmStateMut, + }, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_keccak256_transpiler::Rv32KeccakOpcode; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_air::BaseAir, - p3_field::{FieldAlgebra, PrimeField32}, + p3_field::PrimeField32, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::get_air_name, - AirRef, Chip, ChipUsageGetter, }; use p3_keccak_air::{ generate_trace_rows, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, NUM_ROUNDS, U64_LIMBS, @@ -18,258 +35,538 @@ use p3_keccak_air::{ use tiny_keccak::keccakf; use super::{ - columns::{KeccakInstructionCols, KeccakVmCols}, - KeccakVmChip, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES, KECCAK_RATE_U16S, + columns::KeccakVmCols, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES, KECCAK_REGISTER_READS, NUM_ABSORB_ROUNDS, }; +use crate::{ + columns::NUM_KECCAK_VM_COLS, + utils::{keccak256, keccak_f, num_keccak_f}, + KeccakVmStep, KECCAK_DIGEST_BYTES, KECCAK_RATE_U16S, KECCAK_WORD_SIZE, +}; + +#[derive(Clone, Copy)] +pub struct KeccakVmMetadata { + pub len: usize, +} -impl Chip for KeccakVmChip> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air) +impl MultiRowMetadata for KeccakVmMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + num_keccak_f(self.len) * NUM_ROUNDS } +} + +pub(crate) type KeccakVmRecordLayout = MultiRowLayout; - fn generate_air_proof_input(self) -> AirProofInput { - let trace_width = self.trace_width(); - let records = self.records; - let total_num_blocks: usize = records.iter().map(|r| r.input_blocks.len()).sum(); - let mut states = Vec::with_capacity(total_num_blocks); - let mut instruction_blocks = Vec::with_capacity(total_num_blocks); - let memory = self.offline_memory.lock().unwrap(); - - #[derive(Clone)] - struct StateDiff { - /// hi-byte of pre-state - pre_hi: [u8; KECCAK_RATE_U16S], - /// hi-byte of post-state - post_hi: [u8; KECCAK_RATE_U16S], - /// if first block - register_reads: Option<[RecordId; KECCAK_REGISTER_READS]>, - /// if last block - digest_writes: Option<[RecordId; KECCAK_DIGEST_WRITES]>, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct KeccakVmRecordHeader { + pub from_pc: u32, + pub timestamp: u32, + pub rd_ptr: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub dst: u32, + pub src: u32, + pub len: u32, + + pub register_reads_aux: [MemoryReadAuxRecord; KECCAK_REGISTER_READS], + pub write_aux: [MemoryWriteBytesAuxRecord; KECCAK_DIGEST_WRITES], +} + +pub struct KeccakVmRecordMut<'a> { + pub inner: &'a mut KeccakVmRecordHeader, + // Having a continuous slice of the input is useful for fast hashing in `execute` + pub input: &'a mut [u8], + pub read_aux: &'a mut [MemoryReadAuxRecord], +} + +/// Custom borrowing that splits the buffer into a fixed `KeccakVmRecord` header +/// followed by a slice of `u8`'s of length `num_reads * KECCAK_WORD_SIZE` where `num_reads` is +/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length `num_reads`. +/// Uses `align_to_mut()` to make sure the slice is properly aligned to `MemoryReadAuxRecord`. +/// Has debug assertions that check the size and alignment of the slices. +impl<'a> CustomBorrow<'a, KeccakVmRecordMut<'a>, KeccakVmRecordLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: KeccakVmRecordLayout) -> KeccakVmRecordMut<'a> { + let (record_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + + let num_reads = layout.metadata.len.div_ceil(KECCAK_WORD_SIZE); + // Note: each read is `KECCAK_WORD_SIZE` bytes + let (input, rest) = unsafe { rest.split_at_mut_unchecked(num_reads * KECCAK_WORD_SIZE) }; + let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; + KeccakVmRecordMut { + inner: record_buf.borrow_mut(), + input, + read_aux: &mut read_aux_buf[..num_reads], } + } - impl Default for StateDiff { - fn default() -> Self { - Self { - pre_hi: [0; KECCAK_RATE_U16S], - post_hi: [0; KECCAK_RATE_U16S], - register_reads: None, - digest_writes: None, - } + unsafe fn extract_layout(&self) -> KeccakVmRecordLayout { + let header: &KeccakVmRecordHeader = self.borrow(); + KeccakVmRecordLayout { + metadata: KeccakVmMetadata { + len: header.len as usize, + }, + } + } +} + +impl SizedRecord for KeccakVmRecordMut<'_> { + fn size(layout: &KeccakVmRecordLayout) -> usize { + let num_reads = layout.metadata.len.div_ceil(KECCAK_WORD_SIZE); + let mut total_len = size_of::(); + total_len += num_reads * KECCAK_WORD_SIZE; + // Align the pointer to the alignment of `MemoryReadAuxRecord` + total_len = total_len.next_multiple_of(align_of::()); + total_len += num_reads * size_of::(); + total_len + } + + fn alignment(_layout: &KeccakVmRecordLayout) -> usize { + align_of::() + } +} + +impl TraceStep for KeccakVmStep { + type RecordLayout = KeccakVmRecordLayout; + type RecordMut<'a> = KeccakVmRecordMut<'a>; + + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", Rv32KeccakOpcode::KECCAK256) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + debug_assert_eq!(opcode, Rv32KeccakOpcode::KECCAK256.global_opcode()); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + // Reading the length first without tracing to allocate a record of correct size + let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()) as usize; + + let num_reads = len.div_ceil(KECCAK_WORD_SIZE); + let num_blocks = num_keccak_f(len); + let record = arena.alloc(KeccakVmRecordLayout::new(KeccakVmMetadata { len })); + + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp(); + record.inner.rd_ptr = a.as_canonical_u32(); + record.inner.rs1_ptr = b.as_canonical_u32(); + record.inner.rs2_ptr = c.as_canonical_u32(); + + record.inner.dst = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rd_ptr, + &mut record.inner.register_reads_aux[0].prev_timestamp, + )); + record.inner.src = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs1_ptr, + &mut record.inner.register_reads_aux[1].prev_timestamp, + )); + record.inner.len = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs2_ptr, + &mut record.inner.register_reads_aux[2].prev_timestamp, + )); + + debug_assert!(record.inner.src as usize + len <= (1 << self.pointer_max_bits)); + debug_assert!( + record.inner.dst as usize + KECCAK_DIGEST_BYTES <= (1 << self.pointer_max_bits) + ); + // We don't support messages longer than 2^[pointer_max_bits] bytes + debug_assert!(record.inner.len < (1 << self.pointer_max_bits)); + + for idx in 0..num_reads { + if idx % KECCAK_ABSORB_READS == 0 && idx != 0 { + // Need to increment the timestamp according at the start of each block due to the + // AIR constraints + state + .memory + .increment_timestamp_by(KECCAK_REGISTER_READS as u32); } + let read = tracing_read::<_, KECCAK_WORD_SIZE>( + state.memory, + RV32_MEMORY_AS, + record.inner.src + (idx * KECCAK_WORD_SIZE) as u32, + &mut record.read_aux[idx].prev_timestamp, + ); + record.input[idx * KECCAK_WORD_SIZE..(idx + 1) * KECCAK_WORD_SIZE] + .copy_from_slice(&read); } - // prepare the states - let mut state: [u64; 25]; - for record in records { - let dst_read = memory.record_by_id(record.dst_read); - let src_read = memory.record_by_id(record.src_read); - let len_read = memory.record_by_id(record.len_read); - - state = [0u64; 25]; - let src_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = src_read.data_slice() - [1..RV32_REGISTER_NUM_LIMBS] - .try_into() - .unwrap(); - let len_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = len_read.data_slice() - [1..RV32_REGISTER_NUM_LIMBS] - .try_into() - .unwrap(); - let mut instruction = KeccakInstructionCols { - pc: record.pc, - is_enabled: Val::::ONE, - is_enabled_first_round: Val::::ZERO, - start_timestamp: Val::::from_canonical_u32(dst_read.timestamp), - dst_ptr: dst_read.pointer, - src_ptr: src_read.pointer, - len_ptr: len_read.pointer, - dst: dst_read.data_slice().try_into().unwrap(), - src_limbs, - src: Val::::from_canonical_usize(record.input_blocks[0].src), - len_limbs, - remaining_len: Val::::from_canonical_usize( - record.input_blocks[0].remaining_len, - ), - }; - let num_blocks = record.input_blocks.len(); - for (idx, block) in record.input_blocks.into_iter().enumerate() { - // absorb - for (bytes, s) in block.padded_bytes.chunks_exact(8).zip(state.iter_mut()) { - // u64 <-> bytes conversion is little-endian - for (i, &byte) in bytes.iter().enumerate() { - let s_byte = (*s >> (i * 8)) as u8; - // Update bitwise lookup (i.e. xor) chip state: order matters! - if idx != 0 { - self.bitwise_lookup_chip - .request_xor(byte as u32, s_byte as u32); - } - *s ^= (byte as u64) << (i * 8); - } - } - let pre_hi: [u8; KECCAK_RATE_U16S] = - from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8); - states.push(state); - keccakf(&mut state); - let post_hi: [u8; KECCAK_RATE_U16S] = - from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8); - // Range check the final state - if idx == num_blocks - 1 { - for s in state.into_iter().take(NUM_ABSORB_ROUNDS) { - for s_byte in s.to_le_bytes() { - self.bitwise_lookup_chip.request_xor(0, s_byte as u32); - } - } - } - let register_reads = - (idx == 0).then_some([record.dst_read, record.src_read, record.len_read]); - let digest_writes = (idx == num_blocks - 1).then_some(record.digest_writes); - let diff = StateDiff { - pre_hi, - post_hi, - register_reads, - digest_writes, - }; - instruction_blocks.push((instruction, diff, block)); - instruction.remaining_len -= Val::::from_canonical_usize(KECCAK_RATE_BYTES); - instruction.src += Val::::from_canonical_usize(KECCAK_RATE_BYTES); - instruction.start_timestamp += - Val::::from_canonical_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS); + // Due to the AIR constraints, need to set the timestamp to the following: + state.memory.timestamp = record.inner.timestamp + + (num_blocks * (KECCAK_ABSORB_READS + KECCAK_REGISTER_READS)) as u32; + + let digest = keccak256(&record.input[..len]); + for (i, word) in digest.chunks_exact(KECCAK_WORD_SIZE).enumerate() { + tracing_write::<_, KECCAK_WORD_SIZE>( + state.memory, + RV32_MEMORY_AS, + record.inner.dst + (i * KECCAK_WORD_SIZE) as u32, + word.try_into().unwrap(), + &mut record.inner.write_aux[i].prev_timestamp, + &mut record.inner.write_aux[i].prev_data, + ); + } + + // Due to the AIR constraints, the final memory timestamp should be the following: + state.memory.timestamp = record.inner.timestamp + + (len + KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES) as u32; + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } +} + +impl TraceFiller for KeccakVmStep { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace_matrix: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + + let mut chunks = Vec::with_capacity(trace_matrix.height() / NUM_ROUNDS); + let mut sizes = Vec::with_capacity(trace_matrix.height() / NUM_ROUNDS); + let mut trace = &mut trace_matrix.values[..]; + let mut num_blocks_so_far = 0; + + // First pass over the trace to get the number of blocks for each instruction + // and divide the matrix into chunks of needed sizes + loop { + if num_blocks_so_far * NUM_ROUNDS >= rows_used { + // Push all the dummy rows as a single chunk and break + chunks.push(trace); + sizes.push((0, 0)); + break; + } else { + let record: &KeccakVmRecordHeader = + unsafe { get_record_from_slice(&mut trace, ()) }; + let num_blocks = num_keccak_f(record.len as usize); + let (chunk, rest) = + trace.split_at_mut(NUM_KECCAK_VM_COLS * NUM_ROUNDS * num_blocks); + chunks.push(chunk); + sizes.push((num_blocks, record.len as usize)); + num_blocks_so_far += num_blocks; + trace = rest; } } - // We need to transpose state matrices due to a plonky3 issue: https://github.com/Plonky3/Plonky3/issues/672 - // Note: the fix for this issue will be a commit after the major Field crate refactor PR https://github.com/Plonky3/Plonky3/pull/640 - // which will require a significant refactor to switch to. - let p3_states = states - .into_iter() - .map(|state| { - // transpose of 5x5 matrix - from_fn(|i| { - let x = i / 5; - let y = i % 5; - state[x + 5 * y] - }) - }) - .collect(); - let p3_keccak_trace: RowMajorMatrix> = generate_trace_rows(p3_states, 0); - let num_rows = p3_keccak_trace.height(); - // Every `NUM_ROUNDS` rows corresponds to one input block - let num_blocks = num_rows.div_ceil(NUM_ROUNDS); - // Resize with dummy `is_enabled = 0` - instruction_blocks.resize(num_blocks, Default::default()); - - let aux_cols_factory = memory.aux_cols_factory(); - - // Use unsafe alignment so we can parallelly write to the matrix - let mut trace = - RowMajorMatrix::new(Val::::zero_vec(num_rows * trace_width), trace_width); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.ptr_max_bits; - - trace - .values - .par_chunks_mut(trace_width * NUM_ROUNDS) - .zip( - p3_keccak_trace - .values - .par_chunks(NUM_KECCAK_PERM_COLS * NUM_ROUNDS), - ) - .zip(instruction_blocks.into_par_iter()) - .for_each(|((rows, p3_keccak_mat), (instruction, diff, block))| { - let height = rows.len() / trace_width; - for (row, p3_keccak_row) in rows - .chunks_exact_mut(trace_width) - .zip(p3_keccak_mat.chunks_exact(NUM_KECCAK_PERM_COLS)) - { - // Safety: `KeccakPermCols` **must** be the first field in `KeccakVmCols` - row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_keccak_row); - let row_mut: &mut KeccakVmCols> = row.borrow_mut(); - row_mut.instruction = instruction; - - row_mut.sponge.block_bytes = - block.padded_bytes.map(Val::::from_canonical_u8); - if let Some(partial_read_idx) = block.partial_read_idx { - let partial_read = memory.record_by_id(block.reads[partial_read_idx]); - row_mut - .mem_oc - .partial_block - .copy_from_slice(&partial_read.data_slice()[1..]); - } - for (i, is_padding) in row_mut.sponge.is_padding_byte.iter_mut().enumerate() { - *is_padding = Val::::from_bool(i >= block.remaining_len); - } - } - let first_row: &mut KeccakVmCols> = rows[..trace_width].borrow_mut(); - first_row.sponge.is_new_start = Val::::from_bool(block.is_new_start); - first_row.sponge.state_hi = diff.pre_hi.map(Val::::from_canonical_u8); - first_row.instruction.is_enabled_first_round = first_row.instruction.is_enabled; - // Make memory access aux columns. Any aux column not explicitly defined defaults to - // all 0s - if let Some(register_reads) = diff.register_reads { - let need_range_check = [ - ®ister_reads[0], // dst - ®ister_reads[1], // src - ®ister_reads[2], // len - ®ister_reads[2], - ] - .map(|r| { - memory - .record_by_id(*r) - .data_slice() - .last() - .unwrap() - .as_canonical_u32() - }); - for bytes in need_range_check.chunks(2) { - self.bitwise_lookup_chip.request_range( - bytes[0] << limb_shift_bits, - bytes[1] << limb_shift_bits, - ); - } - for (i, id) in register_reads.into_iter().enumerate() { - aux_cols_factory.generate_read_aux( - memory.record_by_id(id), - &mut first_row.mem_oc.register_aux[i], - ); - } - } - for (i, id) in block.reads.into_iter().enumerate() { - aux_cols_factory.generate_read_aux( - memory.record_by_id(id), - &mut first_row.mem_oc.absorb_reads[i], - ); + // First, parallelize over instruction chunks, every instruction can have multiple blocks + // Then, compute some additional values for each block and parallelize over blocks within an + // instruction Finally, compute some additional values for each row and parallelize + // over rows within a block + chunks + .par_iter_mut() + .zip(sizes.par_iter()) + .for_each(|(slice, (num_blocks, len))| { + if *num_blocks == 0 { + // Fill in the dummy rows in parallel + // Note: a 'block' of dummy rows is generated by `generate_trace_rows` from the + // zero state dummy rows are repeated every + // `NUM_ROUNDS` rows + let p3_trace: RowMajorMatrix = generate_trace_rows(vec![[0u64; 25]; 1], 0); + + slice + .par_chunks_exact_mut(NUM_KECCAK_VM_COLS) + .enumerate() + .for_each(|(row_idx, row)| { + let idx = row_idx % NUM_ROUNDS; + row[..NUM_KECCAK_PERM_COLS].copy_from_slice( + &p3_trace.values + [idx * NUM_KECCAK_PERM_COLS..(idx + 1) * NUM_KECCAK_PERM_COLS], + ); + + // Need to get rid of the accidental garbage data that might overflow + // the F's prime field. Unfortunately, there + // is no good way around this + unsafe { + std::ptr::write_bytes( + row.as_mut_ptr().add(NUM_KECCAK_PERM_COLS) as *mut u8, + 0, + (NUM_KECCAK_VM_COLS - NUM_KECCAK_PERM_COLS) * size_of::(), + ); + } + let cols: &mut KeccakVmCols = row.borrow_mut(); + // The first row of a `dummy` block should have `is_new_start = F::ONE` + cols.sponge.is_new_start = F::from_bool(idx == 0); + cols.sponge.block_bytes[0] = F::ONE; + cols.sponge.block_bytes[KECCAK_RATE_BYTES - 1] = + F::from_canonical_u32(0x80); + cols.sponge.is_padding_byte = [F::ONE; KECCAK_RATE_BYTES]; + }); + return; } - let last_row: &mut KeccakVmCols> = - rows[(height - 1) * trace_width..].borrow_mut(); - last_row.sponge.state_hi = diff.post_hi.map(Val::::from_canonical_u8); - last_row.inner.export = instruction.is_enabled - * Val::::from_bool(block.remaining_len < KECCAK_RATE_BYTES); - if let Some(digest_writes) = diff.digest_writes { - for (i, record_id) in digest_writes.into_iter().enumerate() { - let record = memory.record_by_id(record_id); - aux_cols_factory - .generate_write_aux(record, &mut last_row.mem_oc.digest_writes[i]); - } + let num_reads = len.div_ceil(KECCAK_WORD_SIZE); + let read_len = num_reads * KECCAK_WORD_SIZE; + + let record: KeccakVmRecordMut = unsafe { + get_record_from_slice( + slice, + KeccakVmRecordLayout::new(KeccakVmMetadata { len: *len }), + ) + }; + + // Copy the read aux records and inner record to another place + // to safely fill in the trace matrix without overwriting the record + let mut read_aux_records = Vec::with_capacity(num_reads); + read_aux_records.extend_from_slice(record.read_aux); + let vm_record = record.inner.clone(); + let partial_block = if read_len != *len { + record.input[read_len - KECCAK_WORD_SIZE + 1..] + .try_into() + .unwrap() + } else { + [0u8; KECCAK_WORD_SIZE - 1] } - }); + .map(F::from_canonical_u8); + let mut input = Vec::with_capacity(*num_blocks * KECCAK_RATE_BYTES); + input.extend_from_slice(&record.input[..*len]); + // Pad the input according to the Keccak spec + input.push(0x01); + input.resize(input.capacity(), 0); + *input.last_mut().unwrap() += 0x80; - AirProofInput::simple_no_pis(trace) - } -} + let mut states = Vec::with_capacity(*num_blocks); + let mut state = [0u64; 25]; -impl ChipUsageGetter for KeccakVmChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - let num_blocks: usize = self.records.iter().map(|r| r.input_blocks.len()).sum(); - num_blocks * NUM_ROUNDS - } + input + .chunks_exact(KECCAK_RATE_BYTES) + .enumerate() + .for_each(|(idx, chunk)| { + // absorb + for (bytes, s) in chunk.chunks_exact(8).zip(state.iter_mut()) { + // u64 <-> bytes conversion is little-endian + for (i, &byte) in bytes.iter().enumerate() { + let s_byte = (*s >> (i * 8)) as u8; + // Update bitwise lookup (i.e. xor) chip state: order matters! + if idx != 0 { + self.bitwise_lookup_chip + .request_xor(byte as u32, s_byte as u32); + } + *s ^= (byte as u64) << (i * 8); + } + } + states.push(state); + keccakf(&mut state); + }); + + slice + .par_chunks_exact_mut(NUM_ROUNDS * NUM_KECCAK_VM_COLS) + .enumerate() + .for_each(|(block_idx, block_slice)| { + // We need to transpose state matrices due to a plonky3 issue: https://github.com/Plonky3/Plonky3/issues/672 + // Note: the fix for this issue will be a commit after the major Field crate refactor PR https://github.com/Plonky3/Plonky3/pull/640 + // which will require a significant refactor to switch to. + let state = from_fn(|i| { + let x = i / 5; + let y = i % 5; + states[block_idx][x + 5 * y] + }); + + // Note: we can call `generate_trace_rows` for each block separately because + // its trace only depends on the current `state` + // `generate_trace_rows` will generate additional dummy rows to make the + // height into power of 2, but we can safely discard them + let p3_trace: RowMajorMatrix = generate_trace_rows(vec![state], 0); + let input_offset = block_idx * KECCAK_RATE_BYTES; + let start_timestamp = vm_record.timestamp + + (block_idx * (KECCAK_REGISTER_READS + KECCAK_ABSORB_READS)) as u32; + let rem_len = *len - input_offset; + + block_slice + .par_chunks_exact_mut(NUM_KECCAK_VM_COLS) + .enumerate() + .zip(p3_trace.values.par_chunks(NUM_KECCAK_PERM_COLS)) + .for_each(|((row_idx, row), p3_row)| { + // Fill the inner columns + // Safety: `KeccakPermCols` **must** be the first field in + // `KeccakVmCols` + row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_row); - fn trace_width(&self) -> usize { - BaseAir::::width(&self.air) + let cols: &mut KeccakVmCols = row.borrow_mut(); + // Fill the sponge columns + cols.sponge.is_new_start = + F::from_bool(block_idx == 0 && row_idx == 0); + if rem_len < KECCAK_RATE_BYTES { + cols.sponge.is_padding_byte[..rem_len].fill(F::ZERO); + cols.sponge.is_padding_byte[rem_len..].fill(F::ONE); + } else { + cols.sponge.is_padding_byte = [F::ZERO; KECCAK_RATE_BYTES]; + } + cols.sponge.block_bytes = array::from_fn(|i| { + F::from_canonical_u8(input[input_offset + i]) + }); + if row_idx == 0 { + cols.sponge.state_hi = from_fn(|i| { + F::from_canonical_u8( + (states[block_idx][i / U64_LIMBS] + >> ((i % U64_LIMBS) * 16 + 8)) + as u8, + ) + }); + } else if row_idx == NUM_ROUNDS - 1 { + let state = keccak_f(states[block_idx]); + cols.sponge.state_hi = from_fn(|i| { + F::from_canonical_u8( + (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) + as u8, + ) + }); + if block_idx == num_blocks - 1 { + cols.inner.export = F::ONE; + for s in state.into_iter().take(NUM_ABSORB_ROUNDS) { + for s_byte in s.to_le_bytes() { + self.bitwise_lookup_chip + .request_xor(0, s_byte as u32); + } + } + } + } else { + cols.sponge.state_hi = [F::ZERO; KECCAK_RATE_U16S]; + } + + // Fill the instruction columns + cols.instruction.pc = F::from_canonical_u32(vm_record.from_pc); + cols.instruction.is_enabled = F::ONE; + cols.instruction.is_enabled_first_round = + F::from_bool(row_idx == 0); + cols.instruction.start_timestamp = + F::from_canonical_u32(start_timestamp); + cols.instruction.dst_ptr = F::from_canonical_u32(vm_record.rd_ptr); + cols.instruction.src_ptr = F::from_canonical_u32(vm_record.rs1_ptr); + cols.instruction.len_ptr = F::from_canonical_u32(vm_record.rs2_ptr); + cols.instruction.dst = + vm_record.dst.to_le_bytes().map(F::from_canonical_u8); + + let src = vm_record.src + (block_idx * KECCAK_RATE_BYTES) as u32; + cols.instruction.src = F::from_canonical_u32(src); + cols.instruction.src_limbs.copy_from_slice( + &src.to_le_bytes().map(F::from_canonical_u8)[1..], + ); + cols.instruction.len_limbs.copy_from_slice( + &(rem_len as u32).to_le_bytes().map(F::from_canonical_u8)[1..], + ); + cols.instruction.remaining_len = + F::from_canonical_u32(rem_len as u32); + + // Fill the register reads + if row_idx == 0 && block_idx == 0 { + for ((i, cols), vm_record) in cols + .mem_oc + .register_aux + .iter_mut() + .enumerate() + .zip(vm_record.register_reads_aux.iter()) + { + mem_helper.fill( + vm_record.prev_timestamp, + start_timestamp + i as u32, + cols.as_mut(), + ); + } + + let msl_rshift = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + let msl_lshift = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS + - self.pointer_max_bits; + // Update the bitwise lookup chip + self.bitwise_lookup_chip.request_range( + (vm_record.dst >> msl_rshift) << msl_lshift, + (vm_record.src >> msl_rshift) << msl_lshift, + ); + self.bitwise_lookup_chip.request_range( + (vm_record.len >> msl_rshift) << msl_lshift, + (vm_record.len >> msl_rshift) << msl_lshift, + ); + } else { + cols.mem_oc.register_aux.par_iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); + } + + // Fill the absorb reads + if row_idx == 0 { + let reads_offs = block_idx * KECCAK_ABSORB_READS; + let num_reads = min( + rem_len.div_ceil(KECCAK_WORD_SIZE), + KECCAK_ABSORB_READS, + ); + let start_timestamp = + start_timestamp + KECCAK_REGISTER_READS as u32; + for i in 0..num_reads { + mem_helper.fill( + read_aux_records[i + reads_offs].prev_timestamp, + start_timestamp + i as u32, + cols.mem_oc.absorb_reads[i].as_mut(), + ); + } + for i in num_reads..KECCAK_ABSORB_READS { + mem_helper.fill_zero(cols.mem_oc.absorb_reads[i].as_mut()); + } + } else { + cols.mem_oc.absorb_reads.par_iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); + } + + if block_idx == num_blocks - 1 && row_idx == NUM_ROUNDS - 1 { + let timestamp = start_timestamp + + (KECCAK_ABSORB_READS + KECCAK_REGISTER_READS) as u32; + cols.mem_oc + .digest_writes + .par_iter_mut() + .enumerate() + .zip(vm_record.write_aux.par_iter()) + .for_each(|((i, cols), vm_record)| { + cols.set_prev_data( + vm_record.prev_data.map(F::from_canonical_u8), + ); + mem_helper.fill( + vm_record.prev_timestamp, + timestamp + i as u32, + cols.as_mut(), + ); + }); + } else { + cols.mem_oc.digest_writes.par_iter_mut().for_each(|aux| { + aux.set_prev_data([F::ZERO; KECCAK_WORD_SIZE]); + mem_helper.fill_zero(aux.as_mut()); + }); + } + + // Set the partial block only for the last block + if block_idx == num_blocks - 1 { + cols.mem_oc.partial_block = partial_block; + } else { + cols.mem_oc.partial_block = [F::ZERO; KECCAK_WORD_SIZE - 1]; + } + }); + }); + }); } } diff --git a/extensions/keccak256/guest/src/lib.rs b/extensions/keccak256/guest/src/lib.rs index 7e2bb3da54..acfeea785b 100644 --- a/extensions/keccak256/guest/src/lib.rs +++ b/extensions/keccak256/guest/src/lib.rs @@ -1,5 +1,10 @@ #![no_std] +#[cfg(target_os = "zkvm")] +extern crate alloc; +#[cfg(target_os = "zkvm")] +use openvm_platform::alloc::AlignedBuf; + /// This is custom-0 defined in RISC-V spec document pub const OPCODE: u8 = 0x0b; pub const KECCAK256_FUNCT3: u8 = 0b100; @@ -21,6 +26,43 @@ pub const KECCAK256_FUNCT7: u8 = 0; #[inline(always)] #[no_mangle] pub extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + const MIN_ALIGN: usize = 4; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, MIN_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, MIN_ALIGN); + __native_keccak256(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_keccak256(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, MIN_ALIGN); + __native_keccak256(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_keccak256(bytes, len, output); + } + }; + } +} + +/// keccak256 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { openvm_platform::custom_insn_r!( opcode = OPCODE, funct3 = KECCAK256_FUNCT3, diff --git a/extensions/native/circuit/Cargo.toml b/extensions/native/circuit/Cargo.toml index 5d5913b4be..67c3981ba7 100644 --- a/extensions/native/circuit/Cargo.toml +++ b/extensions/native/circuit/Cargo.toml @@ -17,23 +17,23 @@ openvm-circuit = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-instructions = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-rv32im-transpiler = { workspace = true } openvm-native-compiler = { workspace = true } strum.workspace = true itertools.workspace = true -tracing.workspace = true derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true eyre.workspace = true serde.workspace = true -serde-big-array.workspace = true static_assertions.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } +test-case.workspace = true [features] default = ["parallel"] diff --git a/extensions/native/circuit/examples/fibonacci.rs b/extensions/native/circuit/examples/fibonacci.rs index aca5e2d6c5..8dfb29a835 100644 --- a/extensions/native/circuit/examples/fibonacci.rs +++ b/extensions/native/circuit/examples/fibonacci.rs @@ -47,6 +47,6 @@ fn main() { builder.halt(); let program = builder.compile_isa(); - println!("{}", program); + println!("{program}"); execute_program(program, vec![]); } diff --git a/extensions/native/circuit/src/adapters/alu_native_adapter.rs b/extensions/native/circuit/src/adapters/alu_native_adapter.rs index e85797536f..0eb87bd2e3 100644 --- a/extensions/native/circuit/src/adapters/alu_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/alu_native_adapter.rs @@ -1,23 +1,26 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, + offline_checker::{ + MemoryBridge, MemoryReadAuxRecord, MemoryReadOrImmediateAuxCols, + MemoryWriteAuxCols, MemoryWriteAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - native_adapter::{NativeReadRecord, NativeWriteRecord}, - program::ProgramBus, + native_adapter::util::{tracing_read_or_imm_native, tracing_write_native}, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_native_compiler::conversion::AS; @@ -27,28 +30,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, }; -#[derive(Debug)] -pub struct AluNativeAdapterChip { - pub air: AluNativeAdapterAir, - _marker: PhantomData, -} - -impl AluNativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: AluNativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - #[repr(C)] #[derive(AlignedBorrow)] pub struct AluNativeAdapterCols { @@ -93,6 +74,8 @@ impl VmAdapterAir for AluNativeAdapterAir { let native_as = AB::Expr::from_canonical_u32(AS::Native as u32); + // TODO: we assume address space is either 0 or 4, should we add a + // constraint for that? self.memory_bridge .read_or_immediate( MemoryAddress::new(cols.e_as, cols.b_pointer), @@ -144,88 +127,136 @@ impl VmAdapterAir for AluNativeAdapterAir { } } -impl VmAdapterChip for AluNativeAdapterChip { - type ReadRecord = NativeReadRecord; - type WriteRecord = NativeWriteRecord; - type Air = AluNativeAdapterAir; - type Interface = BasicAdapterInterface, 2, 1, 1, 1>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct AluNativeAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, e, f, .. } = *instruction; - - let reads = vec![memory.read::<1>(e, b), memory.read::<1>(f, c)]; - let i_reads: [_; 2] = std::array::from_fn(|i| reads[i].1); - - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) + pub a_ptr: F, + pub b: F, + pub c: F, + + // Will set prev_timestamp to `u32::MAX` if the read is an immediate + pub reads_aux: [MemoryReadAuxRecord; 2], + pub write_aux: MemoryWriteAuxRecord, +} + +#[derive(derive_new::new)] +pub struct AluNativeAdapterStep; + +impl AdapterTraceStep for AluNativeAdapterStep { + const WIDTH: usize = size_of::>(); + type ReadData = [F; 2]; + type WriteData = [F; 1]; + type RecordMut<'a> = &'a mut AluNativeAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, .. } = *_instruction; - let writes = vec![memory.write( - F::from_canonical_u32(AS::Native as u32), - a, - output.writes[0], - )]; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: writes.try_into().unwrap(), - }, - )) + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, c, e, f, .. } = instruction; + + record.b = b; + let rs1 = tracing_read_or_imm_native( + memory, + e.as_canonical_u32(), + b, + &mut record.reads_aux[0].prev_timestamp, + ); + record.c = c; + let rs2 = tracing_read_or_imm_native( + memory, + f.as_canonical_u32(), + c, + &mut record.reads_aux[1].prev_timestamp, + ); + [rs1, rs2] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let row_slice: &mut AluNativeAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); + let &Instruction { a, .. } = instruction; + + record.a_ptr = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + data, + &mut record.write_aux.prev_timestamp, + &mut record.write_aux.prev_data, + ); + } +} - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); +impl AdapterTraceFiller for AluNativeAdapterStep { + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &AluNativeAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut AluNativeAdapterCols = adapter_row.borrow_mut(); - row_slice.a_pointer = memory.record_by_id(write_record.writes[0].0).pointer; - row_slice.b_pointer = memory.record_by_id(read_record.reads[0].0).pointer; - row_slice.c_pointer = memory.record_by_id(read_record.reads[1].0).pointer; - row_slice.e_as = memory.record_by_id(read_record.reads[0].0).address_space; - row_slice.f_as = memory.record_by_id(read_record.reads[1].0).address_space; + // Writing in reverse order to avoid overwriting the `record` + adapter_row + .write_aux + .set_prev_data(record.write_aux.prev_data); + mem_helper.fill( + record.write_aux.prev_timestamp, + record.from_timestamp + 2, + adapter_row.write_aux.as_mut(), + ); - for (i, x) in read_record.reads.iter().enumerate() { - let read = memory.record_by_id(x.0); - aux_cols_factory.generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i]); + let native_as = F::from_canonical_u32(AS::Native as u32); + for ((i, read_record), read_cols) in record + .reads_aux + .iter() + .enumerate() + .zip(adapter_row.reads_aux.iter_mut()) + .rev() + { + let as_col = if i == 0 { + &mut adapter_row.e_as + } else { + &mut adapter_row.f_as + }; + // previous timestamp is u32::MAX if the read is an immediate + if read_record.prev_timestamp == u32::MAX { + read_cols.is_zero_aux = F::ZERO; + read_cols.is_immediate = F::ONE; + mem_helper.fill(0, record.from_timestamp + i as u32, read_cols.as_mut()); + *as_col = F::ZERO; + } else { + read_cols.is_zero_aux = native_as.inverse(); + read_cols.is_immediate = F::ZERO; + mem_helper.fill( + read_record.prev_timestamp, + record.from_timestamp + i as u32, + read_cols.as_mut(), + ); + *as_col = native_as; + } } - let write = memory.record_by_id(write_record.writes[0].0); - aux_cols_factory.generate_write_aux(write, &mut row_slice.write_aux); - } + adapter_row.c_pointer = record.c; + adapter_row.b_pointer = record.b; + adapter_row.a_pointer = record.a_ptr; - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/native/circuit/src/adapters/branch_native_adapter.rs b/extensions/native/circuit/src/adapters/branch_native_adapter.rs index 7d3e97a6bf..4e7016694e 100644 --- a/extensions/native/circuit/src/adapters/branch_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/branch_native_adapter.rs @@ -1,23 +1,23 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, + offline_checker::{MemoryBridge, MemoryReadAuxRecord, MemoryReadOrImmediateAuxCols}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - native_adapter::NativeReadRecord, - program::ProgramBus, + native_adapter::util::tracing_read_or_imm_native, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_native_compiler::conversion::AS; @@ -27,37 +27,15 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, }; -#[derive(Debug)] -pub struct BranchNativeAdapterChip { - pub air: BranchNativeAdapterAir, - _marker: PhantomData, -} - -impl BranchNativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: BranchNativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct BranchNativeAdapterReadCols { pub address: MemoryAddress, pub read_aux: MemoryReadOrImmediateAuxCols, } #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct BranchNativeAdapterCols { pub from_state: ExecutionState, pub reads_aux: [BranchNativeAdapterReadCols; 2], @@ -145,71 +123,115 @@ impl VmAdapterAir for BranchNativeAdapterAir { } } -impl VmAdapterChip for BranchNativeAdapterChip { - type ReadRecord = NativeReadRecord; - type WriteRecord = ExecutionState; - type Air = BranchNativeAdapterAir; - type Interface = BasicAdapterInterface, 2, 0, 1, 1>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct BranchNativeAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub ptrs: [F; 2], + // Will set prev_timestamp to `u32::MAX` if the read is an immediate + pub reads_aux: [MemoryReadAuxRecord; 2], +} - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, d, e, .. } = *instruction; - - let reads = vec![memory.read::<1>(d, a), memory.read::<1>(e, b)]; - let i_reads: [_; 2] = std::array::from_fn(|i| reads[i].1); - - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) +#[derive(derive_new::new)] +pub struct BranchNativeAdapterStep; + +impl AdapterTraceStep for BranchNativeAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [F; 2]; + type WriteData = (); + type RecordMut<'a> = &'a mut BranchNativeAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - from_state, - )) + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { a, b, d, e, .. } = instruction; + + record.ptrs[0] = a; + let rs1 = tracing_read_or_imm_native( + memory, + d.as_canonical_u32(), + a, + &mut record.reads_aux[0].prev_timestamp, + ); + record.ptrs[1] = b; + let rs2 = tracing_read_or_imm_native( + memory, + e.as_canonical_u32(), + b, + &mut record.reads_aux[1].prev_timestamp, + ); + [rs1, rs2] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + _memory: &mut TracingMemory, + _instruction: &Instruction, + _data: Self::WriteData, + _record: &mut Self::RecordMut<'_>, ) { - let row_slice: &mut BranchNativeAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); - - row_slice.from_state = write_record.map(F::from_canonical_u32); - for (i, x) in read_record.reads.iter().enumerate() { - let read = memory.record_by_id(x.0); + // This adapter doesn't write anything + } +} - row_slice.reads_aux[i].address = MemoryAddress::new(read.address_space, read.pointer); - aux_cols_factory - .generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i].read_aux); +impl AdapterTraceFiller for BranchNativeAdapterStep { + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &BranchNativeAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut BranchNativeAdapterCols = adapter_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + + let native_as = F::from_canonical_u32(AS::Native as u32); + for ((i, read_record), read_cols) in record + .reads_aux + .iter() + .enumerate() + .zip(adapter_row.reads_aux.iter_mut()) + .rev() + { + // previous timestamp is u32::MAX if the read is an immediate + if read_record.prev_timestamp == u32::MAX { + read_cols.read_aux.is_zero_aux = F::ZERO; + read_cols.read_aux.is_immediate = F::ONE; + mem_helper.fill( + 0, + record.from_timestamp + i as u32, + read_cols.read_aux.as_mut(), + ); + read_cols.address.pointer = record.ptrs[i]; + read_cols.address.address_space = F::ZERO; + } else { + read_cols.read_aux.is_zero_aux = native_as.inverse(); + read_cols.read_aux.is_immediate = F::ZERO; + mem_helper.fill( + read_record.prev_timestamp, + record.from_timestamp + i as u32, + read_cols.read_aux.as_mut(), + ); + read_cols.address.pointer = record.ptrs[i]; + read_cols.address.address_space = native_as; + } } - } - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/native/circuit/src/adapters/convert_adapter.rs b/extensions/native/circuit/src/adapters/convert_adapter.rs index cac6d91bac..fa7f856181 100644 --- a/extensions/native/circuit/src/adapters/convert_adapter.rs +++ b/extensions/native/circuit/src/adapters/convert_adapter.rs @@ -1,71 +1,37 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::tracing_read_native, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_MEMORY_AS, +}; use openvm_native_compiler::conversion::AS; +use openvm_rv32im_circuit::adapters::tracing_write; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct VectorReadRecord { - #[serde(with = "BigArray")] - pub reads: [RecordId; NUM_READS], -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct VectorWriteRecord { - pub from_state: ExecutionState, - pub writes: [RecordId; 1], -} - -#[allow(dead_code)] -#[derive(Debug)] -pub struct ConvertAdapterChip { - pub air: ConvertAdapterAir, - _marker: PhantomData, -} - -impl - ConvertAdapterChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: ConvertAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} #[repr(C)] #[derive(AlignedBorrow)] @@ -155,74 +121,107 @@ impl Vm } } -impl VmAdapterChip - for ConvertAdapterChip -{ - type ReadRecord = VectorReadRecord<1, READ_SIZE>; - type WriteRecord = VectorWriteRecord; - type Air = ConvertAdapterAir; - type Interface = BasicAdapterInterface, 1, 1, READ_SIZE, WRITE_SIZE>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, e, .. } = *instruction; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct ConvertAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub a_ptr: F, + pub b_ptr: F, + + pub read_aux: MemoryReadAuxRecord, + pub write_aux: MemoryWriteBytesAuxRecord, +} - let y_val = memory.read::(e, b); +#[derive(derive_new::new)] +pub struct ConvertAdapterStep; - Ok(([y_val.1], Self::ReadRecord { reads: [y_val.0] })) +impl AdapterTraceStep + for ConvertAdapterStep +{ + const WIDTH: usize = size_of::>(); + type ReadData = [F; READ_SIZE]; + type WriteData = [u8; WRITE_SIZE]; + type RecordMut<'a> = &'a mut ConvertAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (write_id, _) = memory.write::(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: [write_id], - }, - )) + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, e, .. } = instruction; + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + + record.b_ptr = b; + + tracing_read_native( + memory, + b.as_canonical_u32(), + &mut record.read_aux.prev_timestamp, + ) } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut ConvertAdapterCols<_, READ_SIZE, WRITE_SIZE> = row_slice.borrow_mut(); - - let read = memory.record_by_id(read_record.reads[0]); - let write = memory.record_by_id(write_record.writes[0]); - - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.a_pointer = write.pointer; - row_slice.b_pointer = read.pointer; - - aux_cols_factory.generate_read_aux(read, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux[0]); + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_MEMORY_AS); + + record.a_ptr = a; + tracing_write( + memory, + RV32_MEMORY_AS, + a.as_canonical_u32(), + data, + &mut record.write_aux.prev_timestamp, + &mut record.write_aux.prev_data, + ); } +} - fn air(&self) -> &Self::Air { - &self.air +impl + AdapterTraceFiller for ConvertAdapterStep +{ + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut row_slice: &mut [F]) { + let record: &ConvertAdapterRecord = + unsafe { get_record_from_slice(&mut row_slice, ()) }; + let adapter_row: &mut ConvertAdapterCols = row_slice.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + mem_helper.fill( + record.read_aux.prev_timestamp, + record.from_timestamp, + adapter_row.reads_aux[0].as_mut(), + ); + + adapter_row.writes_aux[0] + .set_prev_data(record.write_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.write_aux.prev_timestamp, + record.from_timestamp + 1, + adapter_row.writes_aux[0].as_mut(), + ); + + adapter_row.b_pointer = record.b_ptr; + adapter_row.a_pointer = record.a_ptr; + + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs b/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs index 4bcf96d195..8aaf702519 100644 --- a/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs @@ -5,19 +5,24 @@ use std::{ use openvm_circuit::{ arch::{ - instructions::LocalOpcode, AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, - ExecutionBus, ExecutionState, Result, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::{tracing_read_native, tracing_write_native}, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, NativeLoadStoreOpcode::{self, *}, @@ -27,7 +32,6 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; pub struct NativeLoadStoreInstruction { pub is_valid: T, @@ -48,55 +52,6 @@ impl VmAdapterInterface type ProcessedInstruction = NativeLoadStoreInstruction; } -#[derive(Debug)] -pub struct NativeLoadStoreAdapterChip { - pub air: NativeLoadStoreAdapterAir, - offset: usize, - _marker: PhantomData, -} - -impl NativeLoadStoreAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - offset: usize, - ) -> Self { - Self { - air: NativeLoadStoreAdapterAir { - memory_bridge, - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - }, - offset, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeLoadStoreReadRecord { - pub pointer_read: RecordId, - pub data_read: Option, - pub write_as: F, - pub write_ptr: F, - - pub a: F, - pub b: F, - pub c: F, - pub d: F, - pub e: F, -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeLoadStoreWriteRecord { - pub from_state: ExecutionState, - pub write_id: RecordId, -} - #[repr(C)] #[derive(Clone, Debug, AlignedBorrow)] pub struct NativeLoadStoreAdapterCols { @@ -214,23 +169,49 @@ impl VmAdapterAir } } -impl VmAdapterChip - for NativeLoadStoreAdapterChip +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeLoadStoreAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + pub a: F, + pub b: F, + pub c: F, + pub write_ptr: F, + + pub ptr_read: MemoryReadAuxRecord, + // Will set `prev_timestamp` to u32::MAX if `HINT_STOREW` + pub data_read: MemoryReadAuxRecord, + pub data_write: MemoryWriteAuxRecord, +} + +#[derive(derive_new::new)] +pub struct NativeLoadStoreAdapterStep { + offset: usize, +} + +impl AdapterTraceStep + for NativeLoadStoreAdapterStep { - type ReadRecord = NativeLoadStoreReadRecord; - type WriteRecord = NativeLoadStoreWriteRecord; - type Air = NativeLoadStoreAdapterAir; - type Interface = NativeLoadStoreAdapterInterface; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = std::mem::size_of::>(); + type ReadData = (F, [F; NUM_CELLS]); + type WriteData = [F; NUM_CELLS]; + type RecordMut<'a> = &'a mut NativeLoadStoreAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp(); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { opcode, a, b, @@ -238,100 +219,114 @@ impl VmAdapterChip d, e, .. - } = *instruction; + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let read_as = d; - let read_ptr = c; - let read_cell = memory.read_cell(read_as, read_ptr); + record.a = a; + record.b = b; + record.c = c; + + // Read the pointer value from memory + let [read_cell] = tracing_read_native::( + memory, + c.as_canonical_u32(), + &mut record.ptr_read.prev_timestamp, + ); + + let data_read_ptr = match local_opcode { + LOADW => read_cell + record.b, + STOREW | HINT_STOREW => record.a, + } + .as_canonical_u32(); + + // It's easier to do this here than in `write` + match local_opcode { + LOADW => record.write_ptr = record.a, + STOREW | HINT_STOREW => record.write_ptr = read_cell + record.b, + } - let (data_read_as, data_write_as) = { - match local_opcode { - LOADW => (e, d), - STOREW | HINT_STOREW => (d, e), + // Read data based on opcode + let data_read: [F; NUM_CELLS] = match local_opcode { + HINT_STOREW => { + record.data_read.prev_timestamp = u32::MAX; + [F::ZERO; NUM_CELLS] } - }; - let (data_read_ptr, data_write_ptr) = { - match local_opcode { - LOADW => (read_cell.1 + b, a), - STOREW | HINT_STOREW => (a, read_cell.1 + b), + LOADW | STOREW => { + tracing_read_native(memory, data_read_ptr, &mut record.data_read.prev_timestamp) } }; - let data_read = match local_opcode { - HINT_STOREW => None, - LOADW | STOREW => Some(memory.read::(data_read_as, data_read_ptr)), - }; - let record = NativeLoadStoreReadRecord { - pointer_read: read_cell.0, - data_read: data_read.map(|x| x.0), - write_as: data_write_as, - write_ptr: data_write_ptr, - a, - b, - c, - d, - e, - }; - - Ok(( - (read_cell.1, data_read.map_or([F::ZERO; NUM_CELLS], |x| x.1)), - record, - )) + (read_cell, data_read) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let (write_id, _) = - memory.write::(read_record.write_as, read_record.write_ptr, output.writes); - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state: from_state.map(F::from_canonical_u32), - write_id, - }, - )) + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, + ) { + // Write data to memory + tracing_write_native( + memory, + record.write_ptr.as_canonical_u32(), + data, + &mut record.data_write.prev_timestamp, + &mut record.data_write.prev_data, + ); } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - let aux_cols_factory = memory.aux_cols_factory(); - let cols: &mut NativeLoadStoreAdapterCols<_, NUM_CELLS> = row_slice.borrow_mut(); - cols.from_state = write_record.from_state; - cols.a = read_record.a; - cols.b = read_record.b; - cols.c = read_record.c; - - let data_read = read_record.data_read.map(|read| memory.record_by_id(read)); - if let Some(data_read) = data_read { - aux_cols_factory.generate_read_aux(data_read, &mut cols.data_read_aux_cols); - } +impl AdapterTraceFiller + for NativeLoadStoreAdapterStep +{ + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &NativeLoadStoreAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut NativeLoadStoreAdapterCols = adapter_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + + let is_hint_storew = record.data_read.prev_timestamp == u32::MAX; + + adapter_row + .data_write_aux_cols + .set_prev_data(record.data_write.prev_data); + // Note, if `HINT_STOREW` we didn't do a data read and we didn't update the timestamp + mem_helper.fill( + record.data_write.prev_timestamp, + record.from_timestamp + 2 - is_hint_storew as u32, + adapter_row.data_write_aux_cols.as_mut(), + ); - let write = memory.record_by_id(write_record.write_id); - cols.data_write_pointer = write.pointer; + if !is_hint_storew { + mem_helper.fill( + record.data_read.prev_timestamp, + record.from_timestamp + 1, + adapter_row.data_read_aux_cols.as_mut(), + ); + } else { + mem_helper.fill_zero(adapter_row.data_read_aux_cols.as_mut()); + } - aux_cols_factory.generate_read_aux( - memory.record_by_id(read_record.pointer_read), - &mut cols.pointer_read_aux_cols, + mem_helper.fill( + record.ptr_read.prev_timestamp, + record.from_timestamp, + adapter_row.pointer_read_aux_cols.as_mut(), ); - aux_cols_factory.generate_write_aux(write, &mut cols.data_write_aux_cols); - } - fn air(&self) -> &Self::Air { - &self.air + adapter_row.data_write_pointer = record.write_ptr; + adapter_row.c = record.c; + adapter_row.b = record.b; + adapter_row.a = record.a; + + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); } } diff --git a/extensions/native/circuit/src/adapters/mod.rs b/extensions/native/circuit/src/adapters/mod.rs index c5cd3b9422..308a0705a3 100644 --- a/extensions/native/circuit/src/adapters/mod.rs +++ b/extensions/native/circuit/src/adapters/mod.rs @@ -6,3 +6,9 @@ pub mod convert_adapter; pub mod loadstore_native_adapter; // 2 reads, 1 write, read size = write size = N, no imm support, read/write to address space d pub mod native_vectorized_adapter; + +pub use alu_native_adapter::*; +pub use branch_native_adapter::*; +pub use convert_adapter::*; +pub use loadstore_native_adapter::*; +pub use native_vectorized_adapter::*; diff --git a/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs b/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs index c151197297..ae2c37391a 100644 --- a/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs +++ b/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs @@ -1,22 +1,26 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::{tracing_read_native, tracing_write_native}, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_native_compiler::conversion::AS; @@ -25,44 +29,6 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; - -#[allow(dead_code)] -#[derive(Debug)] -pub struct NativeVectorizedAdapterChip { - pub air: NativeVectorizedAdapterAir, - _marker: PhantomData, -} - -impl NativeVectorizedAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: NativeVectorizedAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeVectorizedReadRecord { - pub b: RecordId, - pub c: RecordId, -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeVectorizedWriteRecord { - pub from_state: ExecutionState, - pub a: RecordId, -} #[repr(C)] #[derive(AlignedBorrow)] @@ -156,80 +122,121 @@ impl VmAdapterAir for NativeVectoriz } } -impl VmAdapterChip for NativeVectorizedAdapterChip { - type ReadRecord = NativeVectorizedReadRecord; - type WriteRecord = NativeVectorizedWriteRecord; - type Air = NativeVectorizedAdapterAir; - type Interface = BasicAdapterInterface, 2, 1, N, N>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeVectorizedAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + pub a_ptr: F, + pub b_ptr: F, + pub c_ptr: F, + pub reads_aux: [MemoryReadAuxRecord; 2], + pub write_aux: MemoryWriteAuxRecord, +} - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; - - let y_val = memory.read::(d, b); - let z_val = memory.read::(e, c); - - Ok(( - [y_val.1, z_val.1], - Self::ReadRecord { - b: y_val.0, - c: z_val.0, - }, - )) +#[derive(derive_new::new)] +pub struct NativeVectorizedAdapterStep; + +impl AdapterTraceStep + for NativeVectorizedAdapterStep +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[F; N]; 2]; + type WriteData = [F; N]; + type RecordMut<'a> = &'a mut NativeVectorizedAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp(); } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (a_val, _) = memory.write(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - a: a_val, - }, - )) + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, c, d, e, .. } = instruction; + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + + record.b_ptr = b; + let b_val = tracing_read_native( + memory, + b.as_canonical_u32(), + &mut record.reads_aux[0].prev_timestamp, + ); + record.c_ptr = c; + let c_val = tracing_read_native( + memory, + c.as_canonical_u32(), + &mut record.reads_aux[1].prev_timestamp, + ); + + [b_val, c_val] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut NativeVectorizedAdapterCols<_, N> = row_slice.borrow_mut(); - - let b_record = memory.record_by_id(read_record.b); - let c_record = memory.record_by_id(read_record.c); - let a_record = memory.record_by_id(write_record.a); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.a_pointer = a_record.pointer; - row_slice.b_pointer = b_record.pointer; - row_slice.c_pointer = c_record.pointer; - aux_cols_factory.generate_read_aux(b_record, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(c_record, &mut row_slice.reads_aux[1]); - aux_cols_factory.generate_write_aux(a_record, &mut row_slice.writes_aux[0]); + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + + record.a_ptr = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + data, + &mut record.write_aux.prev_timestamp, + &mut record.write_aux.prev_data, + ); } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterTraceFiller + for NativeVectorizedAdapterStep +{ + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &NativeVectorizedAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut NativeVectorizedAdapterCols = adapter_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + adapter_row.writes_aux[0].set_prev_data(record.write_aux.prev_data); + mem_helper.fill( + record.write_aux.prev_timestamp, + record.from_timestamp + 2, + adapter_row.writes_aux[0].as_mut(), + ); + + adapter_row + .reads_aux + .iter_mut() + .enumerate() + .zip(record.reads_aux.iter()) + .rev() + .for_each(|((i, read_cols), read_record)| { + mem_helper.fill( + read_record.prev_timestamp, + record.from_timestamp + i as u32, + read_cols.as_mut(), + ); + }); + + adapter_row.c_pointer = record.c_ptr; + adapter_row.b_pointer = record.b_ptr; + adapter_row.a_pointer = record.a_ptr; + + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/native/circuit/src/branch_eq/core.rs b/extensions/native/circuit/src/branch_eq/core.rs new file mode 100644 index 0000000000..5e93eb5407 --- /dev/null +++ b/extensions/native/circuit/src/branch_eq/core.rs @@ -0,0 +1,320 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterTraceFiller, AdapterTraceStep, E2PreCompute, + EmptyAdapterCoreLayout, ExecuteFunc, RecordArena, Result, StepExecutorE1, StepExecutorE2, + TraceFiller, TraceStep, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, + utils::{transmute_field_to_u32, transmute_u32_to_field}, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_IMM_AS, LocalOpcode, NATIVE_AS, +}; +use openvm_native_compiler::NativeBranchEqualOpcode; +use openvm_rv32im_circuit::BranchEqualCoreCols; +use openvm_rv32im_transpiler::BranchEqualOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeBranchEqualCoreRecord { + pub a: F, + pub b: F, + pub imm: F, + pub is_beq: bool, +} + +#[derive(derive_new::new)] + +pub struct NativeBranchEqualStep { + adapter: A, + pub offset: usize, + pub pc_step: u32, +} + +impl TraceStep for NativeBranchEqualStep +where + F: PrimeField32, + A: 'static + AdapterTraceStep, WriteData = ()>, +{ + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut NativeBranchEqualCoreRecord); + + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + NativeBranchEqualOpcode::from_usize(opcode - self.offset) + ) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { opcode, c: imm, .. } = instruction; + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + [core_record.a, core_record.b] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let cmp_result = core_record.a == core_record.b; + + core_record.imm = imm; + core_record.is_beq = + opcode.local_opcode_idx(self.offset) == BranchEqualOpcode::BEQ as usize; + + if cmp_result == core_record.is_beq { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(self.pc_step); + } + + Ok(()) + } +} + +impl TraceFiller for NativeBranchEqualStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &NativeBranchEqualCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut BranchEqualCoreCols = core_row.borrow_mut(); + let (cmp_result, diff_inv_val) = run_eq(record.is_beq, record.a, record.b); + + // Writing in reverse order to avoid overwriting the `record` + core_row.diff_inv_marker[0] = diff_inv_val; + + core_row.opcode_bne_flag = F::from_bool(!record.is_beq); + core_row.opcode_beq_flag = F::from_bool(record.is_beq); + + core_row.imm = record.imm; + core_row.cmp_result = F::from_bool(cmp_result); + + core_row.b = [record.b]; + core_row.a = [record.a]; + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct NativeBranchEqualPreCompute { + imm: isize, + a_or_imm: u32, + b_or_imm: u32, +} + +impl NativeBranchEqualStep { + #[inline(always)] + fn pre_compute_impl( + &self, + _pc: u32, + inst: &Instruction, + data: &mut NativeBranchEqualPreCompute, + ) -> Result<(bool, bool, bool)> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + + let a_is_imm = d == RV32_IMM_AS; + let b_is_imm = e == RV32_IMM_AS; + + let a_or_imm = if a_is_imm { + transmute_field_to_u32(&a) + } else { + a.as_canonical_u32() + }; + let b_or_imm = if b_is_imm { + transmute_field_to_u32(&b) + } else { + b.as_canonical_u32() + }; + + *data = NativeBranchEqualPreCompute { + imm, + a_or_imm, + b_or_imm, + }; + + let is_bne = local_opcode == BranchEqualOpcode::BNE; + + Ok((a_is_imm, b_is_imm, is_bne)) + } +} + +impl StepExecutorE1 for NativeBranchEqualStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut NativeBranchEqualPreCompute = data.borrow_mut(); + + let (a_is_imm, b_is_imm, is_bne) = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match (a_is_imm, b_is_imm, is_bne) { + (true, true, true) => execute_e1_impl::<_, _, true, true, true>, + (true, true, false) => execute_e1_impl::<_, _, true, true, false>, + (true, false, true) => execute_e1_impl::<_, _, true, false, true>, + (true, false, false) => execute_e1_impl::<_, _, true, false, false>, + (false, true, true) => execute_e1_impl::<_, _, false, true, true>, + (false, true, false) => execute_e1_impl::<_, _, false, true, false>, + (false, false, true) => execute_e1_impl::<_, _, false, false, true>, + (false, false, false) => execute_e1_impl::<_, _, false, false, false>, + }; + + Ok(fn_ptr) + } +} + +impl StepExecutorE2 for NativeBranchEqualStep +where + F: PrimeField32, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let (a_is_imm, b_is_imm, is_bne) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = match (a_is_imm, b_is_imm, is_bne) { + (true, true, true) => execute_e2_impl::<_, _, true, true, true>, + (true, true, false) => execute_e2_impl::<_, _, true, true, false>, + (true, false, true) => execute_e2_impl::<_, _, true, false, true>, + (true, false, false) => execute_e2_impl::<_, _, true, false, false>, + (false, true, true) => execute_e2_impl::<_, _, false, true, true>, + (false, true, false) => execute_e2_impl::<_, _, false, true, false>, + (false, false, true) => execute_e2_impl::<_, _, false, false, true>, + (false, false, false) => execute_e2_impl::<_, _, false, false, false>, + }; + + Ok(fn_ptr) + } +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const IS_NE: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &NativeBranchEqualPreCompute = pre_compute.borrow(); + execute_e12_impl::<_, _, A_IS_IMM, B_IS_IMM, IS_NE>(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const IS_NE: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::<_, _, A_IS_IMM, B_IS_IMM, IS_NE>(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const IS_NE: bool, +>( + pre_compute: &NativeBranchEqualPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1 = if A_IS_IMM { + transmute_u32_to_field(&pre_compute.a_or_imm) + } else { + vm_state.vm_read::(NATIVE_AS, pre_compute.a_or_imm)[0] + }; + let rs2 = if B_IS_IMM { + transmute_u32_to_field(&pre_compute.b_or_imm) + } else { + vm_state.vm_read::(NATIVE_AS, pre_compute.b_or_imm)[0] + }; + if (rs1 == rs2) ^ IS_NE { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + } + vm_state.instret += 1; +} + +// Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) +#[inline(always)] +pub(super) fn run_eq(is_beq: bool, x: F, y: F) -> (bool, F) +where + F: PrimeField32, +{ + if x != y { + return (!is_beq, (x - y).inverse()); + } + (is_beq, F::ZERO) +} diff --git a/extensions/native/circuit/src/branch_eq/mod.rs b/extensions/native/circuit/src/branch_eq/mod.rs index e1b566bb7f..214dc50005 100644 --- a/extensions/native/circuit/src/branch_eq/mod.rs +++ b/extensions/native/circuit/src/branch_eq/mod.rs @@ -1,8 +1,15 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use openvm_rv32im_circuit::{BranchEqualCoreAir, BranchEqualCoreChip}; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; +use openvm_rv32im_circuit::BranchEqualCoreAir; -use super::adapters::branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterChip}; +mod core; +pub use core::*; + +use crate::adapters::{BranchNativeAdapterAir, BranchNativeAdapterStep}; + +#[cfg(test)] +mod tests; pub type NativeBranchEqAir = VmAirWrapper>; +pub type NativeBranchEqStep = NativeBranchEqualStep; pub type NativeBranchEqChip = - VmChipWrapper, BranchEqualCoreChip<1>>; + NewVmChipWrapper>; diff --git a/extensions/native/circuit/src/branch_eq/tests.rs b/extensions/native/circuit/src/branch_eq/tests.rs new file mode 100644 index 0000000000..866f2016a4 --- /dev/null +++ b/extensions/native/circuit/src/branch_eq/tests.rs @@ -0,0 +1,324 @@ +use std::borrow::BorrowMut; + +use openvm_circuit::arch::testing::VmChipTestBuilder; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + utils::isize_to_field, + LocalOpcode, +}; +use openvm_native_compiler::NativeBranchEqualOpcode; +use openvm_rv32im_circuit::{ + adapters::RV_B_TYPE_IMM_BITS, BranchEqualCoreAir, BranchEqualCoreCols, +}; +use openvm_rv32im_transpiler::BranchEqualOpcode; +use openvm_stark_backend::{ + p3_air::BaseAir, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; + +use crate::{ + adapters::{BranchNativeAdapterAir, BranchNativeAdapterStep}, + branch_eq::{run_eq, NativeBranchEqAir, NativeBranchEqChip, NativeBranchEqStep}, + test_utils::write_native_or_imm, +}; + +type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_IMM: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); + +fn create_test_chip(tester: &mut VmChipTestBuilder) -> NativeBranchEqChip { + let mut chip = NativeBranchEqChip::::new( + NativeBranchEqAir::new( + BranchNativeAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + BranchEqualCoreAir::new(NativeBranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ), + NativeBranchEqStep::new( + BranchNativeAdapterStep::new(), + NativeBranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + chip +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute( + tester: &mut VmChipTestBuilder, + chip: &mut NativeBranchEqChip, + rng: &mut StdRng, + opcode: NativeBranchEqualOpcode, + a: Option, + b: Option, + imm: Option, +) { + let a_val = a.unwrap_or(rng.gen()); + let b_val = b.unwrap_or(if rng.gen_bool(0.5) { a_val } else { rng.gen() }); + let imm = imm.unwrap_or(rng.gen_range((-ABS_MAX_IMM)..ABS_MAX_IMM)); + let (a, a_as) = write_native_or_imm(tester, rng, a_val, None); + let (b, b_as) = write_native_or_imm(tester, rng, b_val, None); + let initial_pc = rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1)) - imm.unsigned_abs()); + + tester.execute_with_pc( + chip, + &Instruction::new( + opcode.global_opcode(), + a, + b, + isize_to_field::(imm as isize), + F::from_canonical_usize(a_as), + F::from_canonical_usize(b_as), + F::ZERO, + F::ZERO, + ), + initial_pc, + ); + + let cmp_result = run_eq(opcode.0 == BranchEqualOpcode::BEQ, a_val, b_val).0; + let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; + let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; + let pc_inc = if cmp_result { + imm + } else { + DEFAULT_PC_STEP as i32 + }; + + assert_eq!(to_pc, from_pc + pc_inc); +} + +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(BranchEqualOpcode::BEQ, 100)] +#[test_case(BranchEqualOpcode::BNE, 100)] +fn rand_rv32_branch_eq_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&mut tester); + let opcode = NativeBranchEqualOpcode(opcode); + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); + } + + let tester = tester.build().load(chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[allow(clippy::too_many_arguments)] +fn run_negative_branch_eq_test( + opcode: BranchEqualOpcode, + a: F, + b: F, + prank_cmp_result: Option, + prank_diff_inv_marker: Option, + error: VerificationError, +) { + let imm = 16i32; + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&mut tester); + + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + NativeBranchEqualOpcode(opcode), + Some(a), + Some(b), + Some(imm), + ); + + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut BranchEqualCoreCols = + values.split_at_mut(adapter_width).1.borrow_mut(); + if let Some(cmp_result) = prank_cmp_result { + cols.cmp_result = F::from_bool(cmp_result); + } + if let Some(diff_inv_marker) = prank_diff_inv_marker { + cols.diff_inv_marker = [diff_inv_marker]; + } + *trace = RowMajorMatrix::new(values, trace.width()); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); +} + +#[test] +fn rv32_beq_wrong_cmp_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BEQ, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(true), + None, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_branch_eq_test( + BranchEqualOpcode::BEQ, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 16), + Some(false), + None, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_beq_zero_inv_marker_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BEQ, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(true), + Some(F::ZERO), + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_beq_invalid_inv_marker_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BEQ, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(false), + Some(F::from_canonical_u32(1 << 16)), + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_bne_wrong_cmp_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BNE, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(false), + None, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_branch_eq_test( + BranchEqualOpcode::BNE, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 16), + Some(true), + None, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_bne_zero_inv_marker_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BNE, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(false), + Some(F::ZERO), + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn rv32_bne_invalid_inv_marker_negative_test() { + run_negative_branch_eq_test( + BranchEqualOpcode::BNE, + F::from_canonical_u32(7 << 16), + F::from_canonical_u32(7 << 24), + Some(true), + Some(F::from_canonical_u32(1 << 16)), + VerificationError::OodEvaluationMismatch, + ); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&mut tester); + + let x = F::from_canonical_u32(u32::from_le_bytes([19, 4, 179, 60])); + let y = F::from_canonical_u32(u32::from_le_bytes([19, 32, 180, 60])); + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + NativeBranchEqualOpcode(BranchEqualOpcode::BEQ), + Some(x), + Some(y), + Some(8), + ); + + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + NativeBranchEqualOpcode(BranchEqualOpcode::BNE), + Some(x), + Some(y), + Some(8), + ); +} + +#[test] +fn run_eq_sanity_test() { + let x = F::from_canonical_u32(u32::from_le_bytes([19, 4, 17, 60])); + let (cmp_result, diff_val) = run_eq(true, x, x); + assert!(cmp_result); + assert_eq!(diff_val, F::ZERO); + + let (cmp_result, diff_val) = run_eq(false, x, x); + assert!(!cmp_result); + assert_eq!(diff_val, F::ZERO); +} + +#[test] +fn run_ne_sanity_test() { + let x = F::from_canonical_u32(u32::from_le_bytes([19, 4, 17, 60])); + let y = F::from_canonical_u32(u32::from_le_bytes([19, 32, 18, 60])); + let (cmp_result, diff_val) = run_eq(true, x, y); + assert!(!cmp_result); + assert_eq!(diff_val * (x - y), F::ONE); + + let (cmp_result, diff_val) = run_eq(false, x, y); + assert!(cmp_result); + assert_eq!(diff_val * (x - y), F::ONE); +} diff --git a/extensions/native/circuit/src/castf/core.rs b/extensions/native/circuit/src/castf/core.rs index 664767e35e..426d7a2016 100644 --- a/extensions/native/circuit/src/castf/core.rs +++ b/extensions/native/circuit/src/castf/core.rs @@ -1,15 +1,25 @@ use std::borrow::{Borrow, BorrowMut}; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + MinimalInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit_primitives::{ + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::CastfOpcode; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_MEMORY_AS, LocalOpcode, +}; +use openvm_native_compiler::{conversion::AS, CastfOpcode}; use openvm_rv32im_circuit::adapters::RV32_REGISTER_NUM_LIMBS; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -17,7 +27,8 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; + +use crate::CASTF_MAX_BITS; // LIMB_BITS is the size of the limbs in bits. pub(crate) const LIMB_BITS: usize = 8; @@ -32,7 +43,7 @@ pub struct CastFCoreCols { pub is_valid: T, } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct CastFCoreAir { pub bus: VariableRangeCheckerBus, /* to communicate with the range checker that checks that * all limbs are < 2^LIMB_BITS */ @@ -105,97 +116,213 @@ where } #[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct CastFRecord { - pub in_val: F, - pub out_val: [u32; RV32_REGISTER_NUM_LIMBS], +#[derive(AlignedBytesBorrow, Debug)] +pub struct CastFCoreRecord { + pub val: u32, } -pub struct CastFCoreChip { - pub air: CastFCoreAir, +#[derive(derive_new::new)] +pub struct CastFCoreStep { + adapter: A, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl CastFCoreChip { - pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self { - Self { - air: CastFCoreAir { - bus: range_checker_chip.bus(), - }, - range_checker_chip, - } - } -} - -impl> VmCoreChip for CastFCoreChip +impl TraceStep for CastFCoreStep where - I::Reads: Into<[[F; 1]; 1]>, - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + AdapterTraceStep, { - type Record = CastFRecord; - type Air = CastFCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut CastFCoreRecord); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, _opcode: usize) -> String { + format!("{:?}", CastfOpcode::CASTF) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = instruction; + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); - assert_eq!( - opcode.local_opcode_idx(CastfOpcode::CLASS_OFFSET), - CastfOpcode::CASTF as usize - ); + A::start(*state.pc, state.memory, &mut adapter_record); - let y = reads.into()[0][0]; - let x = CastF::solve(y.as_canonical_u32()); + core_record.val = self + .adapter + .read(state.memory, instruction, &mut adapter_record)[0] + .as_canonical_u32(); - let output = AdapterRuntimeContext { - to_pc: None, - writes: [x.map(F::from_canonical_u32)].into(), - }; + let x = run_castf(core_record.val); - let record = CastFRecord { - in_val: y, - out_val: x, - }; + self.adapter + .write(state.memory, instruction, x, &mut adapter_record); - Ok((output, record)) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn get_opcode_name(&self, _opcode: usize) -> String { - format!("{:?}", CastfOpcode::CASTF) + Ok(()) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - for (i, limb) in record.out_val.iter().enumerate() { - if i == 3 { - self.range_checker_chip.add_count(*limb, FINAL_LIMB_BITS); +impl TraceFiller for CastFCoreStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &CastFCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut CastFCoreCols<_> = core_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + let out = run_castf(record.val); + for (i, &limb) in out.iter().enumerate() { + let limb_bits = if i == out.len() - 1 { + FINAL_LIMB_BITS } else { - self.range_checker_chip.add_count(*limb, LIMB_BITS); - } + LIMB_BITS + }; + self.range_checker_chip.add_count(limb as u32, limb_bits); } + core_row.is_valid = F::ONE; + core_row.out_val = out.map(F::from_canonical_u8); + core_row.in_val = F::from_canonical_u32(record.val); + } +} - let cols: &mut CastFCoreCols = row_slice.borrow_mut(); - cols.in_val = record.in_val; - cols.out_val = record.out_val.map(F::from_canonical_u32); - cols.is_valid = F::ONE; +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct CastFPreCompute { + a: u32, + b: u32, +} + +impl CastFCoreStep { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut CastFPreCompute, + ) -> Result<()> { + let Instruction { + a, b, d, e, opcode, .. + } = inst; + + if opcode.local_opcode_idx(CastfOpcode::CLASS_OFFSET) != CastfOpcode::CASTF as usize { + return Err(InvalidInstruction(pc)); + } + if d.as_canonical_u32() != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + if e.as_canonical_u32() != AS::Native as u32 { + return Err(InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + *data = CastFPreCompute { a, b }; + + Ok(()) } +} - fn air(&self) -> &Self::Air { - &self.air +impl StepExecutorE1 for CastFCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut CastFPreCompute = data.borrow_mut(); + + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_impl::<_, _>; + + Ok(fn_ptr) } } -pub struct CastF; -impl CastF { - pub(super) fn solve(y: u32) -> [u32; RV32_REGISTER_NUM_LIMBS] { - let mut x = [0; 4]; - for (i, limb) in x.iter_mut().enumerate() { - *limb = (y >> (8 * i)) & 0xFF; - } - x +impl StepExecutorE2 for CastFCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = execute_e2_impl::<_, _>; + + Ok(fn_ptr) } } + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &CastFPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &CastFPreCompute, + vm_state: &mut VmSegmentState, +) { + let y = vm_state.vm_read::(AS::Native as u32, pre_compute.b)[0]; + let x = run_castf(y.as_canonical_u32()); + + vm_state.vm_write::(RV32_MEMORY_AS, pre_compute.a, &x); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +#[inline(always)] +pub(super) fn run_castf(y: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] { + debug_assert!(y < 1 << CASTF_MAX_BITS); + y.to_le_bytes() +} diff --git a/extensions/native/circuit/src/castf/mod.rs b/extensions/native/circuit/src/castf/mod.rs index 9fbd77f245..b7b3141d39 100644 --- a/extensions/native/circuit/src/castf/mod.rs +++ b/extensions/native/circuit/src/castf/mod.rs @@ -1,12 +1,13 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::convert_adapter::{ConvertAdapterAir, ConvertAdapterChip}; - -#[cfg(test)] -mod tests; +use crate::adapters::{ConvertAdapterAir, ConvertAdapterStep}; mod core; pub use core::*; +#[cfg(test)] +mod tests; + pub type CastFAir = VmAirWrapper, CastFCoreAir>; -pub type CastFChip = VmChipWrapper, CastFCoreChip>; +pub type CastFStep = CastFCoreStep>; +pub type CastFChip = NewVmChipWrapper>; diff --git a/extensions/native/circuit/src/castf/tests.rs b/extensions/native/circuit/src/castf/tests.rs index 9758e6b956..d99750acec 100644 --- a/extensions/native/circuit/src/castf/tests.rs +++ b/extensions/native/circuit/src/castf/tests.rs @@ -1,254 +1,222 @@ use std::borrow::BorrowMut; -use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::CastfOpcode; -use openvm_stark_backend::{ - p3_field::FieldAlgebra, utils::disable_debug_builder, verifier::VerificationError, Chip, +use openvm_circuit::arch::{ + testing::{memory::gen_pointer, VmChipTestBuilder}, + MemoryConfig, +}; +use openvm_instructions::{ + instruction::Instruction, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, }; -use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Engine, engine::StarkFriEngine, - p3_baby_bear::BabyBear, utils::create_seeded_rng, +use openvm_native_compiler::{conversion::AS, CastfOpcode}; +use openvm_stark_backend::{ + p3_air::BaseAir, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, }; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{ - super::adapters::convert_adapter::{ConvertAdapterChip, ConvertAdapterCols}, - CastF, CastFChip, CastFCoreChip, CastFCoreCols, FINAL_LIMB_BITS, LIMB_BITS, +use super::{CastFChip, CastFCoreAir, CastFCoreCols, CastFStep, LIMB_BITS}; +use crate::{ + adapters::{ConvertAdapterAir, ConvertAdapterCols, ConvertAdapterStep}, + castf::run_castf, + test_utils::write_native_array, + CastFAir, CASTF_MAX_BITS, }; + +const MAX_INS_CAPACITY: usize = 128; +const READ_SIZE: usize = 1; +const WRITE_SIZE: usize = 4; type F = BabyBear; -fn generate_uint_number(rng: &mut StdRng) -> u32 { - rng.gen_range(0..(1 << 30) - 1) +fn create_test_chip(tester: &VmChipTestBuilder) -> CastFChip { + let mut chip = CastFChip::::new( + CastFAir::new( + ConvertAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + CastFCoreAir::new(tester.range_checker().bus()), + ), + CastFStep::new( + ConvertAdapterStep::::new(), + tester.range_checker().clone(), + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + chip } -fn prepare_castf_rand_write_execute( +fn set_and_execute( tester: &mut VmChipTestBuilder, chip: &mut CastFChip, - y: u32, rng: &mut StdRng, + b: Option, ) { - let operand1 = y; - - let as_x = 2usize; // d - let as_y = 4usize; // e - let address_x = gen_pointer(rng, 32); // a - let address_y = gen_pointer(rng, 32); // b - - let operand1_f = F::from_canonical_u32(y); - - tester.write_cell(as_y, address_y, operand1_f); - let x = CastF::solve(operand1); + let b_val = b.unwrap_or(F::from_canonical_u32(rng.gen_range(0..1 << CASTF_MAX_BITS))); + let b_ptr = write_native_array(tester, rng, Some([b_val])).1; + let a = gen_pointer(rng, RV32_REGISTER_NUM_LIMBS); tester.execute( chip, &Instruction::from_usize( CastfOpcode::CASTF.global_opcode(), - [address_x, address_y, 0, as_x, as_y], + [a, b_ptr, 0, RV32_MEMORY_AS as usize, AS::Native as usize], ), ); - assert_eq!( - x.map(F::from_canonical_u32), - tester.read::<4>(as_x, address_x) - ); + let expected = run_castf(b_val.as_canonical_u32()); + let result = tester.read::(RV32_MEMORY_AS as usize, a); + assert_eq!(result.map(|x| x.as_canonical_u32() as u8), expected); } +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// + #[test] fn castf_rand_test() { let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(tester.range_checker()), - tester.offline_memory_mutex_arc(), - ); - let num_tests: usize = 3; + let mut tester = VmChipTestBuilder::volatile(MemoryConfig::default()); + let mut chip = create_test_chip(&tester); + let num_ops = 100; - for _ in 0..num_tests { - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, None); } + set_and_execute(&mut tester, &mut chip, &mut rng, Some(F::ZERO)); + let tester = tester.build().load(chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn negative_castf_overflow_test() { - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.range_checker(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), - ); +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[derive(Clone, Copy, Default)] +struct CastFPrankValues { + pub in_val: Option, + pub out_val: Option<[u32; 4]>, + pub a_pointer: Option, + pub b_pointer: Option, +} +fn run_negative_castf_test(prank_vals: CastFPrankValues, b: Option, error: VerificationError) { let mut rng = create_seeded_rng(); - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let trace = chip_input.raw.common_main.as_mut().unwrap(); - let row = trace.row_mut(0); - let cols: &mut CastFCoreCols = row - .split_at_mut(ConvertAdapterCols::::width()) - .1 - .borrow_mut(); - cols.out_val[3] = F::from_canonical_u32(rng.gen_range(1 << FINAL_LIMB_BITS..1 << LIMB_BITS)); - - let rc_air = range_checker_chip.air(); - let rc_p_input = range_checker_chip.generate_air_proof_input(); + let mut tester = VmChipTestBuilder::volatile(MemoryConfig::default()); + + let mut chip = create_test_chip(&tester); + set_and_execute(&mut tester, &mut chip, &mut rng, b); + + let adapter_width = BaseAir::::width(&chip.air.adapter); + + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let (adapter_row, core_row) = values.split_at_mut(adapter_width); + let core_cols: &mut CastFCoreCols = core_row.borrow_mut(); + let adapter_cols: &mut ConvertAdapterCols = + adapter_row.borrow_mut(); + + if let Some(in_val) = prank_vals.in_val { + // TODO: in_val is actually never used in the AIR, should remove it + core_cols.in_val = F::from_canonical_u32(in_val); + } + if let Some(out_val) = prank_vals.out_val { + core_cols.out_val = out_val.map(F::from_canonical_u32); + } + if let Some(a_pointer) = prank_vals.a_pointer { + adapter_cols.a_pointer = F::from_canonical_u32(a_pointer); + } + if let Some(b_pointer) = prank_vals.b_pointer { + adapter_cols.b_pointer = F::from_canonical_u32(b_pointer); + } + *trace = RowMajorMatrix::new(values, trace.width()); + }; disable_debug_builder(); - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast( - vec![chip_air, rc_air], - vec![chip_input, rc_p_input] - ) - .err(), - Some(VerificationError::ChallengePhaseError), - "Expected verification to fail, but it didn't" - ); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); } #[test] -fn negative_castf_memread_test() { - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), +fn casf_invalid_out_val_test() { + run_negative_castf_test( + CastFPrankValues { + out_val: Some([2 << LIMB_BITS, 0, 0, 0]), + ..Default::default() + }, + Some(F::from_canonical_u32(2 << LIMB_BITS)), + VerificationError::ChallengePhaseError, ); - let mut rng = create_seeded_rng(); - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let trace = chip_input.raw.common_main.as_mut().unwrap(); - let row = trace.row_mut(0); - let cols: &mut ConvertAdapterCols = row - .split_at_mut(ConvertAdapterCols::::width()) - .0 - .borrow_mut(); - cols.b_pointer += F::ONE; - - let rc_air = range_checker_chip.air(); - let rc_p_input = range_checker_chip.generate_air_proof_input(); - - disable_debug_builder(); - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast( - vec![chip_air, rc_air], - vec![chip_input, rc_p_input] - ) - .err(), - Some(VerificationError::ChallengePhaseError), - "Expected verification to fail, but it didn't" + let prime = F::NEG_ONE.as_canonical_u32() + 1; + run_negative_castf_test( + CastFPrankValues { + out_val: Some(prime.to_le_bytes().map(|x| x as u32)), + ..Default::default() + }, + Some(F::ZERO), + VerificationError::ChallengePhaseError, ); } #[test] -fn negative_castf_memwrite_test() { - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), +fn negative_convert_adapter_test() { + // overflowing the memory pointer + run_negative_castf_test( + CastFPrankValues { + b_pointer: Some(1 << 30), + ..Default::default() + }, + None, + VerificationError::ChallengePhaseError, ); - let mut rng = create_seeded_rng(); - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let trace = chip_input.raw.common_main.as_mut().unwrap(); - let row = trace.row_mut(0); - let cols: &mut ConvertAdapterCols = row - .split_at_mut(ConvertAdapterCols::::width()) - .0 - .borrow_mut(); - cols.a_pointer += F::ONE; - - let rc_air = range_checker_chip.air(); - let rc_p_input = range_checker_chip.generate_air_proof_input(); - - disable_debug_builder(); - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast( - vec![chip_air, rc_air], - vec![chip_input, rc_p_input] - ) - .err(), - Some(VerificationError::ChallengePhaseError), - "Expected verification to fail, but it didn't" + // Memory address space pointer has to be 4-byte aligned + run_negative_castf_test( + CastFPrankValues { + a_pointer: Some(1), + ..Default::default() + }, + None, + VerificationError::ChallengePhaseError, ); } +#[should_panic] #[test] -fn negative_castf_as_test() { - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), - ); - +fn castf_overflow_in_val_test() { let mut rng = create_seeded_rng(); - let y = generate_uint_number(&mut rng); - prepare_castf_rand_write_execute(&mut tester, &mut chip, y, &mut rng); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let trace = chip_input.raw.common_main.as_mut().unwrap(); - let row = trace.row_mut(0); - let cols: &mut ConvertAdapterCols = row - .split_at_mut(ConvertAdapterCols::::width()) - .0 - .borrow_mut(); - cols.a_pointer += F::ONE; - - let rc_air = range_checker_chip.air(); - let rc_p_input = range_checker_chip.generate_air_proof_input(); + let mut tester = VmChipTestBuilder::volatile(MemoryConfig::default()); + let mut chip = create_test_chip(&tester); + set_and_execute(&mut tester, &mut chip, &mut rng, Some(F::NEG_ONE)); +} - disable_debug_builder(); - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast( - vec![chip_air, rc_air], - vec![chip_input, rc_p_input] - ) - .err(), - Some(VerificationError::ChallengePhaseError), - "Expected verification to fail, but it didn't" - ); +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn castf_sanity_test() { + let b = 160558167; + let expected = [87, 236, 145, 9]; + assert_eq!(run_castf(b), expected); } diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index 385c9392ac..3f7e1c451b 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -1,17 +1,19 @@ -use air::VerifyBatchBus; -use alu_native_adapter::AluNativeAdapterChip; -use branch_native_adapter::BranchNativeAdapterChip; +use alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterStep}; +use branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterStep}; +use convert_adapter::{ConvertAdapterAir, ConvertAdapterStep}; use derive_more::derive::From; -use loadstore_native_adapter::NativeLoadStoreAdapterChip; -use native_vectorized_adapter::NativeVectorizedAdapterChip; +use fri::{FriReducedOpeningAir, FriReducedOpeningChip, FriReducedOpeningStep}; +use jal_rangecheck::{JalRangeCheckAir, JalRangeCheckChip, JalRangeCheckStep}; +use loadstore_native_adapter::{NativeLoadStoreAdapterAir, NativeLoadStoreAdapterStep}; +use native_vectorized_adapter::{NativeVectorizedAdapterAir, NativeVectorizedAdapterStep}; use openvm_circuit::{ arch::{ - ExecutionBridge, InitFileGenerator, MemoryConfig, SystemConfig, SystemPort, VmExtension, - VmInventory, VmInventoryBuilder, VmInventoryError, + ExecutionBridge, InitFileGenerator, MemoryConfig, SystemConfig, SystemPort, VmAirWrapper, + VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor, VmConfig}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant}; use openvm_native_compiler::{ @@ -21,19 +23,14 @@ use openvm_native_compiler::{ }; use openvm_poseidon2_air::Poseidon2Config; use openvm_rv32im_circuit::{ - BranchEqualCoreChip, Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, + BranchEqualCoreAir, Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, }; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; -use crate::{ - adapters::{convert_adapter::ConvertAdapterChip, *}, - chip::NativePoseidon2Chip, - phantom::*, - *, -}; +use crate::{adapters::*, air::VerifyBatchBus, phantom::*, *}; #[derive(Clone, Debug, Serialize, Deserialize, VmConfig, derive_new::new)] pub struct NativeConfig { @@ -48,10 +45,7 @@ impl NativeConfig { Self { system: SystemConfig::new( max_constraint_degree, - MemoryConfig { - max_access_adapter_n: 8, - ..Default::default() - }, + MemoryConfig::aggregation(), num_public_values, ) .with_max_segment_len((1 << 24) - 100), @@ -66,7 +60,9 @@ impl InitFileGenerator for NativeConfig {} #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Native; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, InsExecutorE2, From, AnyEnum, +)] pub enum NativeExecutor { LoadStore(NativeLoadStoreChip), BlockLoadStore(NativeLoadStoreChip), @@ -97,58 +93,75 @@ impl VmExtension for Native { program_bus, memory_bridge, } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); - let mut load_store_chip = NativeLoadStoreChip::::new( - NativeLoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, + let range_checker = &builder.system_base().range_checker_chip; + + let load_store_chip = NativeLoadStoreChip::::new( + VmAirWrapper::new( + NativeLoadStoreAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + ), + NativeLoadStoreCoreAir::new(NativeLoadStoreOpcode::CLASS_OFFSET), + ), + NativeLoadStoreCoreStep::new( + NativeLoadStoreAdapterStep::new(NativeLoadStoreOpcode::CLASS_OFFSET), NativeLoadStoreOpcode::CLASS_OFFSET, ), - NativeLoadStoreCoreChip::new(NativeLoadStoreOpcode::CLASS_OFFSET), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); - load_store_chip.core.set_streams(builder.streams().clone()); - inventory.add_executor( load_store_chip, NativeLoadStoreOpcode::iter().map(|x| x.global_opcode()), )?; - let mut block_load_store_chip = NativeLoadStoreChip::::new( - NativeLoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, + let block_load_store_chip = NativeLoadStoreChip::::new( + VmAirWrapper::new( + NativeLoadStoreAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + ), + NativeLoadStoreCoreAir::new(NativeLoadStore4Opcode::CLASS_OFFSET), + ), + NativeLoadStoreCoreStep::new( + NativeLoadStoreAdapterStep::new(NativeLoadStore4Opcode::CLASS_OFFSET), NativeLoadStore4Opcode::CLASS_OFFSET, ), - NativeLoadStoreCoreChip::new(NativeLoadStore4Opcode::CLASS_OFFSET), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); - block_load_store_chip - .core - .set_streams(builder.streams().clone()); - inventory.add_executor( block_load_store_chip, NativeLoadStore4Opcode::iter().map(|x| x.global_opcode()), )?; let branch_equal_chip = NativeBranchEqChip::new( - BranchNativeAdapterChip::<_>::new(execution_bus, program_bus, memory_bridge), - BranchEqualCoreChip::new(NativeBranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + NativeBranchEqAir::new( + BranchNativeAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + BranchEqualCoreAir::new(NativeBranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ), + NativeBranchEqStep::new( + BranchNativeAdapterStep::new(), + NativeBranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( branch_equal_chip, NativeBranchEqualOpcode::iter().map(|x| x.global_opcode()), )?; - let jal_chip = JalRangeCheckChip::new( - ExecutionBridge::new(execution_bus, program_bus), - offline_memory.clone(), - builder.system_base().range_checker_chip.clone(), + let jal_chip = JalRangeCheckChip::::new( + JalRangeCheckAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + range_checker.bus(), + ), + JalRangeCheckStep::new(range_checker.clone()), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( jal_chip, @@ -158,44 +171,57 @@ impl VmExtension for Native { ], )?; - let field_arithmetic_chip = FieldArithmeticChip::new( - AluNativeAdapterChip::::new(execution_bus, program_bus, memory_bridge), - FieldArithmeticCoreChip::new(), - offline_memory.clone(), + let field_arithmetic_chip = FieldArithmeticChip::::new( + VmAirWrapper::new( + AluNativeAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + FieldArithmeticCoreAir::new(), + ), + FieldArithmeticStep::new(AluNativeAdapterStep::new()), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( field_arithmetic_chip, FieldArithmeticOpcode::iter().map(|x| x.global_opcode()), )?; - let field_extension_chip = FieldExtensionChip::new( - NativeVectorizedAdapterChip::new(execution_bus, program_bus, memory_bridge), - FieldExtensionCoreChip::new(), - offline_memory.clone(), + let field_extension_chip = FieldExtensionChip::::new( + VmAirWrapper::new( + NativeVectorizedAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + FieldExtensionCoreAir::new(), + ), + FieldExtensionStep::new(NativeVectorizedAdapterStep::new()), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( field_extension_chip, FieldExtensionOpcode::iter().map(|x| x.global_opcode()), )?; - let fri_reduced_opening_chip = FriReducedOpeningChip::new( - execution_bus, - program_bus, - memory_bridge, - offline_memory.clone(), - builder.streams().clone(), + let fri_reduced_opening_chip = FriReducedOpeningChip::::new( + FriReducedOpeningAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + FriReducedOpeningStep::new(), + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( fri_reduced_opening_chip, FriOpcode::iter().map(|x| x.global_opcode()), )?; - let poseidon2_chip = NativePoseidon2Chip::new( + let poseidon2_chip = new_native_poseidon2_chip( builder.system_port(), - offline_memory.clone(), Poseidon2Config::default(), VerifyBatchBus::new(builder.new_bus_idx()), - builder.streams().clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( poseidon2_chip, @@ -239,10 +265,11 @@ pub(crate) mod phantom { use eyre::bail; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_instructions::PhantomDiscriminant; use openvm_stark_backend::p3_field::{Field, PrimeField32}; + use rand::rngs::StdRng; pub struct NativeHintInputSubEx; pub struct NativeHintSliceSubEx; @@ -252,12 +279,13 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintInputSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let hint = match streams.input_stream.pop_front() { @@ -277,12 +305,13 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintSliceSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let hint = match streams.input_stream.pop_front() { @@ -300,36 +329,35 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativePrintSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, _: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - _: F, + a: u32, + _: u32, c_upper: u16, ) -> eyre::Result<()> { - let addr_space = F::from_canonical_u16(c_upper); - let value = memory.unsafe_read_cell(addr_space, a); - println!("{}", value); + let [value] = unsafe { memory.read::(c_upper as u32, a) }; + println!("{value}"); Ok(()) } } impl PhantomSubExecutor for NativeHintBitsSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + len: u32, c_upper: u16, ) -> eyre::Result<()> { - let addr_space = F::from_canonical_u16(c_upper); - let val = memory.unsafe_read_cell(addr_space, a); + let [val] = unsafe { memory.read::(c_upper as u32, a) }; let mut val = val.as_canonical_u32(); - let len = b.as_canonical_u32(); assert!(streams.hint_stream.is_empty()); for _ in 0..len { streams @@ -343,12 +371,13 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintLoadSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let payload = match streams.input_stream.pop_front() { @@ -370,7 +399,9 @@ pub(crate) mod phantom { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct CastFExtension; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, InsExecutorE2, From, AnyEnum, +)] pub enum CastFExtensionExecutor { CastF(CastFChip), } @@ -394,13 +425,18 @@ impl VmExtension for CastFExtension { program_bus, memory_bridge, } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - - let castf_chip = CastFChip::new( - ConvertAdapterChip::new(execution_bus, program_bus, memory_bridge), - CastFCoreChip::new(range_checker.clone()), - offline_memory.clone(), + let range_checker = &builder.system_base().range_checker_chip; + + let castf_chip = CastFChip::::new( + VmAirWrapper::new( + ConvertAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + CastFCoreAir::new(range_checker.bus()), + ), + CastFStep::new(ConvertAdapterStep::<1, 4>::new(), range_checker.clone()), + builder.system_base().memory_controller.helper(), ); inventory.add_executor(castf_chip, [CastfOpcode::CASTF.global_opcode()])?; @@ -439,3 +475,11 @@ impl Default for Rv32WithKernelsConfig { // Default implementation uses no init file impl InitFileGenerator for Rv32WithKernelsConfig {} + +// Pre-computed maximum trace heights for NativeConfig. Found by doubling +// the actual trace heights of kitchen-sink leaf verification (except for +// VariableRangeChecker, which has a fixed height). +pub const NATIVE_MAX_TRACE_HEIGHTS: &[u32] = &[ + 4194304, 4, 128, 2097152, 8388608, 4194304, 262144, 2097152, 16777216, 2097152, 8388608, + 262144, 2097152, 1048576, 4194304, 65536, 262144, +]; diff --git a/extensions/native/circuit/src/field_arithmetic/core.rs b/extensions/native/circuit/src/field_arithmetic/core.rs index c813f6a066..ca1302584d 100644 --- a/extensions/native/circuit/src/field_arithmetic/core.rs +++ b/extensions/native/circuit/src/field_arithmetic/core.rs @@ -1,20 +1,32 @@ use std::borrow::{Borrow, BorrowMut}; use itertools::izip; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, ExecutionError, MinimalInstruction, + RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, + VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, + utils::{transmute_field_to_u32, transmute_u32_to_field}, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FieldArithmeticOpcode::{self, *}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_IMM_AS, LocalOpcode, +}; +use openvm_native_compiler::{ + conversion::AS, + FieldArithmeticOpcode::{self, *}, +}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; #[repr(C)] #[derive(AlignedBorrow)] @@ -31,7 +43,7 @@ pub struct FieldArithmeticCoreCols { pub divisor_inv: T, } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct FieldArithmeticCoreAir {} impl BaseAir for FieldArithmeticCoreAir { @@ -106,120 +118,396 @@ where } #[repr(C)] -#[derive(Debug, Serialize, Deserialize)] +#[derive(AlignedBytesBorrow, Debug)] pub struct FieldArithmeticRecord { - pub opcode: FieldArithmeticOpcode, - pub a: F, pub b: F, pub c: F, + pub local_opcode: u8, } -pub struct FieldArithmeticCoreChip { - pub air: FieldArithmeticCoreAir, +#[derive(derive_new::new)] +pub struct FieldArithmeticCoreStep { + adapter: A, } -impl FieldArithmeticCoreChip { - pub fn new() -> Self { - Self { - air: FieldArithmeticCoreAir {}, - } +impl TraceStep for FieldArithmeticCoreStep +where + F: PrimeField32, + A: 'static + AdapterTraceStep, +{ + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut FieldArithmeticRecord); + + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + FieldArithmeticOpcode::from_usize(opcode - FieldArithmeticOpcode::CLASS_OFFSET) + ) } -} -impl Default for FieldArithmeticCoreChip { - fn default() -> Self { - Self::new() + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { opcode, .. } = instruction; + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + [core_record.b, core_record.c] = + self.adapter + .read(state.memory, instruction, &mut adapter_record); + + core_record.local_opcode = + opcode.local_opcode_idx(FieldArithmeticOpcode::CLASS_OFFSET) as u8; + + let opcode = FieldArithmeticOpcode::from_usize(core_record.local_opcode as usize); + let a_val = run_field_arithmetic(opcode, core_record.b, core_record.c); + + self.adapter + .write(state.memory, instruction, [a_val], &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } } -impl> VmCoreChip for FieldArithmeticCoreChip +impl TraceFiller for FieldArithmeticCoreStep where - I::Reads: Into<[[F; 1]; 2]>, - I::Writes: From<[[F; 1]; 1]>, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = FieldArithmeticRecord; - type Air = FieldArithmeticCoreAir; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &FieldArithmeticRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut FieldArithmeticCoreCols<_> = core_row.borrow_mut(); + + let opcode = FieldArithmeticOpcode::from_usize(record.local_opcode as usize); + let result = run_field_arithmetic(opcode, record.b, record.c); - #[allow(clippy::type_complexity)] - fn execute_instruction( + // Writing in reverse order to avoid overwriting the `record` + core_row.divisor_inv = if opcode == FieldArithmeticOpcode::DIV { + record.c.inverse() + } else { + F::ZERO + }; + + core_row.is_div = F::from_bool(opcode == FieldArithmeticOpcode::DIV); + core_row.is_mul = F::from_bool(opcode == FieldArithmeticOpcode::MUL); + core_row.is_sub = F::from_bool(opcode == FieldArithmeticOpcode::SUB); + core_row.is_add = F::from_bool(opcode == FieldArithmeticOpcode::ADD); + + core_row.c = record.c; + core_row.b = record.b; + core_row.a = result; + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct FieldArithmeticPreCompute { + a: u32, + b_or_imm: u32, + c_or_imm: u32, + e: u32, + f: u32, +} + +impl FieldArithmeticCoreStep { + #[inline(always)] + fn pre_compute_impl( &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = instruction; + _pc: u32, + inst: &Instruction, + data: &mut FieldArithmeticPreCompute, + ) -> Result<(bool, bool, FieldArithmeticOpcode)> { + let &Instruction { + opcode, + a, + b, + c, + e, + f, + .. + } = inst; + let local_opcode = FieldArithmeticOpcode::from_usize( opcode.local_opcode_idx(FieldArithmeticOpcode::CLASS_OFFSET), ); - let data: [[F; 1]; 2] = reads.into(); - let b = data[0][0]; - let c = data[1][0]; - let a = FieldArithmetic::run_field_arithmetic(local_opcode, b, c).unwrap(); + let a = a.as_canonical_u32(); + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); - let output: AdapterRuntimeContext = AdapterRuntimeContext { - to_pc: None, - writes: [[a]].into(), + let a_is_imm = e == RV32_IMM_AS; + let b_is_imm = f == RV32_IMM_AS; + + let b_or_imm = if a_is_imm { + transmute_field_to_u32(&b) + } else { + b.as_canonical_u32() + }; + let c_or_imm = if b_is_imm { + transmute_field_to_u32(&c) + } else { + c.as_canonical_u32() }; - let record = Self::Record { - opcode: local_opcode, + *data = FieldArithmeticPreCompute { a, - b, - c, + b_or_imm, + c_or_imm, + e, + f, }; - Ok((output, record)) + Ok((a_is_imm, b_is_imm, local_opcode)) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - FieldArithmeticOpcode::from_usize(opcode - FieldArithmeticOpcode::CLASS_OFFSET) - ) +impl StepExecutorE1 for FieldArithmeticCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let FieldArithmeticRecord { opcode, a, b, c } = record; - let row_slice: &mut FieldArithmeticCoreCols<_> = row_slice.borrow_mut(); - row_slice.a = a; - row_slice.b = b; - row_slice.c = c; - - row_slice.is_add = F::from_bool(opcode == FieldArithmeticOpcode::ADD); - row_slice.is_sub = F::from_bool(opcode == FieldArithmeticOpcode::SUB); - row_slice.is_mul = F::from_bool(opcode == FieldArithmeticOpcode::MUL); - row_slice.is_div = F::from_bool(opcode == FieldArithmeticOpcode::DIV); - row_slice.divisor_inv = if opcode == FieldArithmeticOpcode::DIV { - c.inverse() - } else { - F::ZERO + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut FieldArithmeticPreCompute = data.borrow_mut(); + + let (a_is_imm, b_is_imm, local_opcode) = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match (local_opcode, a_is_imm, b_is_imm) { + (FieldArithmeticOpcode::ADD, true, true) => { + execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, true, false) => { + execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, false, true) => { + execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, false, false) => { + execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::SUB, true, true) => { + execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, true, false) => { + execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, false, true) => { + execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, false, false) => { + execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::MUL, true, true) => { + execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, true, false) => { + execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, false, true) => { + execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, false, false) => { + execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::DIV, true, true) => { + execute_e1_impl::<_, _, true, true, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, true, false) => { + execute_e1_impl::<_, _, true, false, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, false, true) => { + execute_e1_impl::<_, _, false, true, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, false, false) => { + execute_e1_impl::<_, _, false, false, { FieldArithmeticOpcode::DIV as u8 }> + } }; + + Ok(fn_ptr) } +} - fn air(&self) -> &Self::Air { - &self.air +impl StepExecutorE2 for FieldArithmeticCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } + + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let (a_is_imm, b_is_imm, local_opcode) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = match (local_opcode, a_is_imm, b_is_imm) { + (FieldArithmeticOpcode::ADD, true, true) => { + execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, true, false) => { + execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, false, true) => { + execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::ADD, false, false) => { + execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::ADD as u8 }> + } + (FieldArithmeticOpcode::SUB, true, true) => { + execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, true, false) => { + execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, false, true) => { + execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::SUB, false, false) => { + execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::SUB as u8 }> + } + (FieldArithmeticOpcode::MUL, true, true) => { + execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, true, false) => { + execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, false, true) => { + execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::MUL, false, false) => { + execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::MUL as u8 }> + } + (FieldArithmeticOpcode::DIV, true, true) => { + execute_e2_impl::<_, _, true, true, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, true, false) => { + execute_e2_impl::<_, _, true, false, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, false, true) => { + execute_e2_impl::<_, _, false, true, { FieldArithmeticOpcode::DIV as u8 }> + } + (FieldArithmeticOpcode::DIV, false, false) => { + execute_e2_impl::<_, _, false, false, { FieldArithmeticOpcode::DIV as u8 }> + } + }; + + Ok(fn_ptr) + } +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const OPCODE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &FieldArithmeticPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const OPCODE: u8, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); } -pub struct FieldArithmetic; -impl FieldArithmetic { - pub(super) fn run_field_arithmetic( - opcode: FieldArithmeticOpcode, - b: F, - c: F, - ) -> Option { - match opcode { - FieldArithmeticOpcode::ADD => Some(b + c), - FieldArithmeticOpcode::SUB => Some(b - c), - FieldArithmeticOpcode::MUL => Some(b * c), - FieldArithmeticOpcode::DIV => { - if c.is_zero() { - None - } else { - Some(b * c.inverse()) - } +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const A_IS_IMM: bool, + const B_IS_IMM: bool, + const OPCODE: u8, +>( + pre_compute: &FieldArithmeticPreCompute, + vm_state: &mut VmSegmentState, +) { + // Read values based on the adapter logic + let b_val = if A_IS_IMM { + transmute_u32_to_field(&pre_compute.b_or_imm) + } else { + vm_state.vm_read::(pre_compute.e, pre_compute.b_or_imm)[0] + }; + let c_val = if B_IS_IMM { + transmute_u32_to_field(&pre_compute.c_or_imm) + } else { + vm_state.vm_read::(pre_compute.f, pre_compute.c_or_imm)[0] + }; + + let a_val = match OPCODE { + 0 => b_val + c_val, // ADD + 1 => b_val - c_val, // SUB + 2 => b_val * c_val, // MUL + 3 => { + // DIV + if c_val.is_zero() { + vm_state.exit_code = Err(ExecutionError::Fail { pc: vm_state.pc }); + return; } + b_val * c_val.inverse() + } + _ => panic!("Invalid field arithmetic opcode: {OPCODE}"), + }; + + vm_state.vm_write::(AS::Native as u32, pre_compute.a, &[a_val]); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +pub(super) fn run_field_arithmetic(opcode: FieldArithmeticOpcode, b: F, c: F) -> F { + match opcode { + FieldArithmeticOpcode::ADD => b + c, + FieldArithmeticOpcode::SUB => b - c, + FieldArithmeticOpcode::MUL => b * c, + FieldArithmeticOpcode::DIV => { + assert!(!c.is_zero(), "Division by zero"); + b * c.inverse() } } } diff --git a/extensions/native/circuit/src/field_arithmetic/mod.rs b/extensions/native/circuit/src/field_arithmetic/mod.rs index 865434cb37..421f93887a 100644 --- a/extensions/native/circuit/src/field_arithmetic/mod.rs +++ b/extensions/native/circuit/src/field_arithmetic/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use crate::adapters::alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterChip}; +use crate::adapters::{AluNativeAdapterAir, AluNativeAdapterStep}; #[cfg(test)] mod tests; @@ -9,5 +9,6 @@ mod core; pub use core::*; pub type FieldArithmeticAir = VmAirWrapper; +pub type FieldArithmeticStep = FieldArithmeticCoreStep; pub type FieldArithmeticChip = - VmChipWrapper, FieldArithmeticCoreChip>; + NewVmChipWrapper>; diff --git a/extensions/native/circuit/src/field_arithmetic/tests.rs b/extensions/native/circuit/src/field_arithmetic/tests.rs index 8e69f8c44b..d2d597a46b 100644 --- a/extensions/native/circuit/src/field_arithmetic/tests.rs +++ b/extensions/native/circuit/src/field_arithmetic/tests.rs @@ -2,180 +2,240 @@ use std::borrow::BorrowMut; use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FieldArithmeticOpcode; +use openvm_native_compiler::{conversion::AS, FieldArithmeticOpcode}; use openvm_stark_backend::{ + p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, verifier::VerificationError, - Chip, }; -use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Engine, engine::StarkFriEngine, - p3_baby_bear::BabyBear, utils::create_seeded_rng, -}; -use rand::Rng; -use strum::EnumCount; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::{ - core::FieldArithmeticCoreChip, FieldArithmetic, FieldArithmeticChip, FieldArithmeticCoreCols, + FieldArithmeticChip, FieldArithmeticCoreAir, FieldArithmeticCoreCols, FieldArithmeticStep, +}; +use crate::{ + adapters::{AluNativeAdapterAir, AluNativeAdapterStep}, + field_arithmetic::{run_field_arithmetic, FieldArithmeticAir}, + test_utils::write_native_or_imm, }; -use crate::adapters::alu_native_adapter::{AluNativeAdapterChip, AluNativeAdapterCols}; -#[test] -fn new_field_arithmetic_air_test() { - let num_ops = 3; // non-power-of-2 to also test padding - let elem_range = || 1..=100; - let xy_address_space_range = || 0usize..=1; - - let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; + +fn create_test_chip(tester: &VmChipTestBuilder) -> FieldArithmeticChip { + let mut chip = FieldArithmeticChip::::new( + FieldArithmeticAir::new( + AluNativeAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + FieldArithmeticCoreAir::new(), ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), + FieldArithmeticStep::new(AluNativeAdapterStep::new()), + tester.memory_helper(), ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); - let mut rng = create_seeded_rng(); + chip +} - for _ in 0..num_ops { - let opcode = - FieldArithmeticOpcode::from_usize(rng.gen_range(0..FieldArithmeticOpcode::COUNT)); +#[allow(clippy::too_many_arguments)] +fn set_and_execute( + tester: &mut VmChipTestBuilder, + chip: &mut FieldArithmeticChip, + rng: &mut StdRng, + opcode: FieldArithmeticOpcode, + b: Option, + c: Option, +) { + let b_val = b.unwrap_or(rng.gen()); + let c_val = c.unwrap_or(if opcode == FieldArithmeticOpcode::DIV { + // If division, make sure c is not zero + F::from_canonical_u32(rng.gen_range(0..F::NEG_ONE.as_canonical_u32()) + 1) + } else { + rng.gen() + }); + assert!(!c_val.is_zero(), "Division by zero"); + let (b, b_as) = write_native_or_imm(tester, rng, b_val, None); + let (c, c_as) = write_native_or_imm(tester, rng, c_val, None); + let a = gen_pointer(rng, 1); - let operand1 = BabyBear::from_canonical_u32(rng.gen_range(elem_range())); - let operand2 = BabyBear::from_canonical_u32(rng.gen_range(elem_range())); + tester.execute( + chip, + &Instruction::new( + opcode.global_opcode(), + F::from_canonical_usize(a), + b, + c, + F::from_canonical_usize(AS::Native as usize), + F::from_canonical_usize(b_as), + F::from_canonical_usize(c_as), + F::ZERO, + ), + ); - if opcode == FieldArithmeticOpcode::DIV && operand2.is_zero() { - continue; - } + let expected = run_field_arithmetic(opcode, b_val, c_val); + let result = tester.read::<1>(AS::Native as usize, a)[0]; + assert_eq!(result, expected); +} - let result_as = 4usize; - let as1 = rng.gen_range(xy_address_space_range()) * 4; - let as2 = rng.gen_range(xy_address_space_range()) * 4; - let address1 = if as1 == 0 { - operand1.as_canonical_u32() as usize - } else { - gen_pointer(&mut rng, 1) - }; - let address2 = if as2 == 0 { - operand2.as_canonical_u32() as usize - } else { - gen_pointer(&mut rng, 1) - }; - assert_ne!(address1, address2); - let result_address = gen_pointer(&mut rng, 1); - - let result = FieldArithmetic::run_field_arithmetic(opcode, operand1, operand2).unwrap(); - tracing::debug!( - "{opcode:?} d = {}, e = {}, f = {}, result_addr = {}, addr1 = {}, addr2 = {}, z = {}, x = {}, y = {}", - result_as, as1, as2, result_address, address1, address2, result, operand1, operand2, - ); - - if as1 != 0 { - tester.write_cell(as1, address1, operand1); - } - if as2 != 0 { - tester.write_cell(as2, address2, operand2); - } - tester.execute( - &mut chip, - &Instruction::from_usize( - opcode.global_opcode(), - [result_address, address1, address2, result_as, as1, as2], - ), - ); - assert_eq!(result, tester.read_cell(result_as, result_address)); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(FieldArithmeticOpcode::ADD, 100)] +#[test_case(FieldArithmeticOpcode::SUB, 100)] +#[test_case(FieldArithmeticOpcode::MUL, 100)] +#[test_case(FieldArithmeticOpcode::DIV, 100)] +fn new_field_arithmetic_air_test(opcode: FieldArithmeticOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); + + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None); } - let mut tester = tester.build().load(chip).finalize(); + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + opcode, + Some(F::ZERO), + None, + ); + + let tester = tester.build().load(chip).finalize(); tester.simple_test().expect("Verification failed"); +} - disable_debug_builder(); - // negative test pranking each IO value - for height in 0..num_ops { - // TODO: better way to modify existing traces in tester - let arith_trace = tester.air_proof_inputs[2] - .1 - .raw - .common_main - .as_mut() - .unwrap(); - let old_trace = arith_trace.clone(); - for width in 0..FieldArithmeticCoreCols::::width() { - let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); - arith_trace.row_mut(height)[width] = prank_value; +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[derive(Default)] +struct FieldExpressionPrankVals { + a: Option, + b: Option, + c: Option, + opcode_flags: Option<[bool; 4]>, + divisor_inv: Option, +} +#[allow(clippy::too_many_arguments)] +fn run_negative_field_arithmetic_test( + opcode: FieldArithmeticOpcode, + b: F, + c: F, + prank_vals: FieldExpressionPrankVals, + error: VerificationError, +) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); + + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, Some(b), Some(c)); + + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut FieldArithmeticCoreCols = + values.split_at_mut(adapter_width).1.borrow_mut(); + if let Some(a) = prank_vals.a { + cols.a = a; } + if let Some(b) = prank_vals.b { + cols.b = b; + } + if let Some(c) = prank_vals.c { + cols.c = c; + } + if let Some(opcode_flags) = prank_vals.opcode_flags { + [cols.is_add, cols.is_sub, cols.is_mul, cols.is_div] = opcode_flags.map(F::from_bool); + } + if let Some(divisor_inv) = prank_vals.divisor_inv { + cols.divisor_inv = divisor_inv; + } + *trace = RowMajorMatrix::new(values, trace.width()); + }; - // Run a test after pranking each row - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" - ); - - tester.air_proof_inputs[2].1.raw.common_main = Some(old_trace); - } + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); } #[test] -fn new_field_arithmetic_air_zero_div_zero() { - let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), +fn field_arithmetic_negative_zero_div_test() { + run_negative_field_arithmetic_test( + FieldArithmeticOpcode::DIV, + F::from_canonical_u32(111), + F::from_canonical_u32(222), + FieldExpressionPrankVals { + b: Some(F::ZERO), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, ); - tester.write_cell(4, 6, BabyBear::from_canonical_u32(111)); - tester.write_cell(4, 7, BabyBear::from_canonical_u32(222)); - tester.execute( - &mut chip, - &Instruction::from_usize( - FieldArithmeticOpcode::DIV.global_opcode(), - [5, 6, 7, 4, 4, 4], - ), + run_negative_field_arithmetic_test( + FieldArithmeticOpcode::DIV, + F::ZERO, + F::TWO, + FieldExpressionPrankVals { + c: Some(F::ZERO), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, ); - tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - // set the value of [c]_f to zero, necessary to bypass trace gen checks - let row = chip_input.raw.common_main.as_mut().unwrap().row_mut(0); - let cols: &mut FieldArithmeticCoreCols = row - .split_at_mut(AluNativeAdapterCols::::width()) - .1 - .borrow_mut(); - cols.b = BabyBear::ZERO; - disable_debug_builder(); + run_negative_field_arithmetic_test( + FieldArithmeticOpcode::DIV, + F::ZERO, + F::TWO, + FieldExpressionPrankVals { + c: Some(F::ZERO), + opcode_flags: Some([false, false, true, false]), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); +} - assert_eq!( - BabyBearPoseidon2Engine::run_test_fast(vec![chip_air], vec![chip_input]).err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" +#[test] +fn field_arithmetic_negative_rand() { + let mut rng = create_seeded_rng(); + run_negative_field_arithmetic_test( + FieldArithmeticOpcode::DIV, + F::from_canonical_u32(111), + F::from_canonical_u32(222), + FieldExpressionPrankVals { + a: Some(rng.gen()), + b: Some(rng.gen()), + c: Some(rng.gen()), + opcode_flags: Some([rng.gen(), rng.gen(), rng.gen(), rng.gen()]), + divisor_inv: Some(rng.gen()), + }, + VerificationError::OodEvaluationMismatch, ); } #[should_panic] #[test] fn new_field_arithmetic_air_test_panic() { - let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), - ); - tester.write_cell(4, 0, BabyBear::ZERO); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); + tester.write(4, 0, [BabyBear::ZERO]); // should panic tester.execute( &mut chip, diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index d8c83fabdd..7ff729097d 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -5,20 +5,30 @@ use std::{ }; use itertools::izip; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + MinimalInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FieldExtensionOpcode::{self, *}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_native_compiler::{ + conversion::AS, + FieldExtensionOpcode::{self, *}, +}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; pub const BETA: usize = 11; pub const EXT_DEG: usize = 4; @@ -34,11 +44,11 @@ pub struct FieldExtensionCoreCols { pub is_sub: T, pub is_mul: T, pub is_div: T, - /// `divisor_inv` is y.inverse() when opcode is FDIV and zero otherwise. + /// `divisor_inv` is z.inverse() when opcode is FDIV and zero otherwise. pub divisor_inv: [T; EXT_DEG], } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct FieldExtensionCoreAir {} impl BaseAir for FieldExtensionCoreAir { @@ -78,8 +88,8 @@ where // - Each flag in `flags` is a boolean. // - Exactly one flag in `flags` is true. // - The inner product of the `flags` and `opcodes` equals `io.opcode`. - // - The inner product of the `flags` and `results[:,j]` equals `io.z[j]` for each `j`. - // - If `is_div` is true, then `aux.divisor_inv` correctly represents the inverse of `io.y`. + // - The inner product of the `flags` and `results[:,j]` equals `io.x[j]` for each `j`. + // - If `is_div` is true, then `aux.divisor_inv` correctly represents the inverse of `io.z`. let mut is_valid = AB::Expr::ZERO; let mut expected_opcode = AB::Expr::ZERO; @@ -133,116 +143,275 @@ where } #[repr(C)] -#[derive(Debug, Serialize, Deserialize)] +#[derive(AlignedBytesBorrow, Debug)] pub struct FieldExtensionRecord { - pub opcode: FieldExtensionOpcode, - pub x: [F; EXT_DEG], pub y: [F; EXT_DEG], pub z: [F; EXT_DEG], + pub local_opcode: u8, } -pub struct FieldExtensionCoreChip { - pub air: FieldExtensionCoreAir, +#[derive(derive_new::new)] +pub struct FieldExtensionCoreStep { + adapter: A, } -impl FieldExtensionCoreChip { - pub fn new() -> Self { - Self { - air: FieldExtensionCoreAir {}, - } +impl TraceStep for FieldExtensionCoreStep +where + F: PrimeField32, + A: 'static + AdapterTraceStep, +{ + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut FieldExtensionRecord); + + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + FieldExtensionOpcode::from_usize(opcode - FieldExtensionOpcode::CLASS_OFFSET) + ) } -} -impl Default for FieldExtensionCoreChip { - fn default() -> Self { - Self::new() + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { opcode, .. } = instruction; + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + core_record.local_opcode = + opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET) as u8; + + [core_record.y, core_record.z] = + self.adapter + .read(state.memory, instruction, &mut adapter_record); + + let x = run_field_extension( + FieldExtensionOpcode::from_usize(core_record.local_opcode as usize), + core_record.y, + core_record.z, + ); + + self.adapter + .write(state.memory, instruction, x, &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } } -impl> VmCoreChip for FieldExtensionCoreChip +impl TraceFiller for FieldExtensionCoreStep where - I::Reads: Into<[[F; EXT_DEG]; 2]>, - I::Writes: From<[[F; EXT_DEG]; 1]>, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = FieldExtensionRecord; - type Air = FieldExtensionCoreAir; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &FieldExtensionRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut FieldExtensionCoreCols<_> = core_row.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + let opcode = FieldExtensionOpcode::from_usize(record.local_opcode as usize); + if opcode == FieldExtensionOpcode::BBE4DIV { + core_row.divisor_inv = FieldExtension::invert(record.z); + } else { + core_row.divisor_inv = [F::ZERO; EXT_DEG]; + } + + core_row.is_div = F::from_bool(opcode == FieldExtensionOpcode::BBE4DIV); + core_row.is_mul = F::from_bool(opcode == FieldExtensionOpcode::BBE4MUL); + core_row.is_sub = F::from_bool(opcode == FieldExtensionOpcode::FE4SUB); + core_row.is_add = F::from_bool(opcode == FieldExtensionOpcode::FE4ADD); + + core_row.z = record.z; + core_row.y = record.y; + core_row.x = run_field_extension(opcode, core_row.y, core_row.z); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct FieldExtensionPreCompute { + a: u32, + b: u32, + c: u32, +} - #[allow(clippy::type_complexity)] - fn execute_instruction( +impl FieldExtensionCoreStep { + #[inline(always)] + fn pre_compute_impl( &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = instruction; - let local_opcode_idx = opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET); + pc: u32, + inst: &Instruction, + data: &mut FieldExtensionPreCompute, + ) -> Result { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + let local_opcode = FieldExtensionOpcode::from_usize( + opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET), + ); + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + + if d != AS::Native as u32 { + return Err(InvalidInstruction(pc)); + } + if e != AS::Native as u32 { + return Err(InvalidInstruction(pc)); + } - let data: [[F; EXT_DEG]; 2] = reads.into(); - let y: [F; EXT_DEG] = data[0]; - let z: [F; EXT_DEG] = data[1]; + *data = FieldExtensionPreCompute { a, b, c }; - let x = FieldExtension::solve(FieldExtensionOpcode::from_usize(local_opcode_idx), y, z) - .unwrap(); + Ok(local_opcode as u8) + } +} - let output = AdapterRuntimeContext { - to_pc: None, - writes: [x].into(), - }; +impl StepExecutorE1 for FieldExtensionCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } - let record = Self::Record { - opcode: FieldExtensionOpcode::from_usize(local_opcode_idx), - x, - y, - z, + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut FieldExtensionPreCompute = data.borrow_mut(); + + let opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match opcode { + 0 => execute_e1_impl::<_, _, 0>, // FE4ADD + 1 => execute_e1_impl::<_, _, 1>, // FE4SUB + 2 => execute_e1_impl::<_, _, 2>, // BBE4MUL + 3 => execute_e1_impl::<_, _, 3>, // BBE4DIV + _ => panic!("Invalid field extension opcode: {opcode}"), }; - Ok((output, record)) + Ok(fn_ptr) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - FieldExtensionOpcode::from_usize(opcode - FieldExtensionOpcode::CLASS_OFFSET) - ) +impl StepExecutorE2 for FieldExtensionCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let FieldExtensionRecord { opcode, x, y, z } = record; - let cols: &mut FieldExtensionCoreCols<_> = row_slice.borrow_mut(); - cols.x = x; - cols.y = y; - cols.z = z; - cols.is_add = F::from_bool(opcode == FieldExtensionOpcode::FE4ADD); - cols.is_sub = F::from_bool(opcode == FieldExtensionOpcode::FE4SUB); - cols.is_mul = F::from_bool(opcode == FieldExtensionOpcode::BBE4MUL); - cols.is_div = F::from_bool(opcode == FieldExtensionOpcode::BBE4DIV); - cols.divisor_inv = if opcode == FieldExtensionOpcode::BBE4DIV { - FieldExtension::invert(z) - } else { - [F::ZERO; EXT_DEG] + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + + let fn_ptr = match opcode { + 0 => execute_e2_impl::<_, _, 0>, // FE4ADD + 1 => execute_e2_impl::<_, _, 1>, // FE4SUB + 2 => execute_e2_impl::<_, _, 2>, // BBE4MUL + 3 => execute_e2_impl::<_, _, 3>, // BBE4DIV + _ => panic!("Invalid field extension opcode: {opcode}"), }; - } - fn air(&self) -> &Self::Air { - &self.air + Ok(fn_ptr) } } -pub struct FieldExtension; -impl FieldExtension { - pub(super) fn solve( - opcode: FieldExtensionOpcode, - x: [F; EXT_DEG], - y: [F; EXT_DEG], - ) -> Option<[F; EXT_DEG]> { - match opcode { - FieldExtensionOpcode::FE4ADD => Some(Self::add(x, y)), - FieldExtensionOpcode::FE4SUB => Some(Self::subtract(x, y)), - FieldExtensionOpcode::BBE4MUL => Some(Self::multiply(x, y)), - FieldExtensionOpcode::BBE4DIV => Some(Self::divide(x, y)), - } +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &FieldExtensionPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &FieldExtensionPreCompute, + vm_state: &mut VmSegmentState, +) { + let y: [F; EXT_DEG] = vm_state.vm_read::(AS::Native as u32, pre_compute.b); + let z: [F; EXT_DEG] = vm_state.vm_read::(AS::Native as u32, pre_compute.c); + + let x = match OPCODE { + 0 => FieldExtension::add(y, z), // FE4ADD + 1 => FieldExtension::subtract(y, z), // FE4SUB + 2 => FieldExtension::multiply(y, z), // BBE4MUL + 3 => FieldExtension::divide(y, z), // BBE4DIV + _ => panic!("Invalid field extension opcode: {OPCODE}"), + }; + + vm_state.vm_write(AS::Native as u32, pre_compute.a, &x); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +// Returns the result of the field extension operation. +// Will panic if divide by zero. +pub(super) fn run_field_extension( + opcode: FieldExtensionOpcode, + y: [F; EXT_DEG], + z: [F; EXT_DEG], +) -> [F; EXT_DEG] { + match opcode { + FieldExtensionOpcode::FE4ADD => FieldExtension::add(y, z), + FieldExtensionOpcode::FE4SUB => FieldExtension::subtract(y, z), + FieldExtensionOpcode::BBE4MUL => FieldExtension::multiply(y, z), + FieldExtensionOpcode::BBE4DIV => FieldExtension::divide(y, z), } +} +pub(crate) struct FieldExtension; + +impl FieldExtension { pub(crate) fn add(x: [V; EXT_DEG], y: [V; EXT_DEG]) -> [E; EXT_DEG] where V: Copy, diff --git a/extensions/native/circuit/src/field_extension/mod.rs b/extensions/native/circuit/src/field_extension/mod.rs index d109deb528..1399731079 100644 --- a/extensions/native/circuit/src/field_extension/mod.rs +++ b/extensions/native/circuit/src/field_extension/mod.rs @@ -1,16 +1,15 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::native_vectorized_adapter::{ - NativeVectorizedAdapterAir, NativeVectorizedAdapterChip, -}; - -#[cfg(test)] -mod tests; +use crate::adapters::{NativeVectorizedAdapterAir, NativeVectorizedAdapterStep}; mod core; pub use core::*; +#[cfg(test)] +mod tests; + pub type FieldExtensionAir = VmAirWrapper, FieldExtensionCoreAir>; +pub type FieldExtensionStep = FieldExtensionCoreStep>; pub type FieldExtensionChip = - VmChipWrapper, FieldExtensionCoreChip>; + NewVmChipWrapper>; diff --git a/extensions/native/circuit/src/field_extension/tests.rs b/extensions/native/circuit/src/field_extension/tests.rs index 66d6c94004..058842cff6 100644 --- a/extensions/native/circuit/src/field_extension/tests.rs +++ b/extensions/native/circuit/src/field_extension/tests.rs @@ -1,102 +1,225 @@ use std::{ array, + borrow::BorrowMut, ops::{Add, Div, Mul, Sub}, }; use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FieldExtensionOpcode; +use openvm_native_compiler::{conversion::AS, FieldExtensionOpcode}; use openvm_stark_backend::{ + p3_air::BaseAir, p3_field::{extension::BinomialExtensionField, FieldAlgebra, FieldExtensionAlgebra}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; -use strum::EnumCount; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{ - super::adapters::native_vectorized_adapter::NativeVectorizedAdapterChip, FieldExtension, - FieldExtensionChip, FieldExtensionCoreChip, +use crate::{ + adapters::{NativeVectorizedAdapterAir, NativeVectorizedAdapterStep}, + field_extension::run_field_extension, + test_utils::write_native_array, + FieldExtension, FieldExtensionAir, FieldExtensionChip, FieldExtensionCoreAir, + FieldExtensionCoreCols, FieldExtensionStep, EXT_DEG, }; -#[test] -fn new_field_extension_air_test() { - type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; - let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldExtensionChip::new( - NativeVectorizedAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), +fn create_test_chip(tester: &VmChipTestBuilder) -> FieldExtensionChip { + let mut chip = FieldExtensionChip::::new( + FieldExtensionAir::new( + NativeVectorizedAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + FieldExtensionCoreAir::new(), ), - FieldExtensionCoreChip::new(), - tester.offline_memory_mutex_arc(), + FieldExtensionStep::new(NativeVectorizedAdapterStep::new()), + tester.memory_helper(), ); - let trace_width = chip.trace_width(); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); - let mut rng = create_seeded_rng(); - let num_ops: usize = 7; // test padding with dummy row + chip +} - for _ in 0..num_ops { - let opcode = - FieldExtensionOpcode::from_usize(rng.gen_range(0..FieldExtensionOpcode::COUNT)); +fn set_and_execute( + tester: &mut VmChipTestBuilder, + chip: &mut FieldExtensionChip, + rng: &mut StdRng, + opcode: FieldExtensionOpcode, + y: Option<[F; EXT_DEG]>, + z: Option<[F; EXT_DEG]>, +) { + let (y_val, y_ptr) = write_native_array(tester, rng, y); + let (z_val, z_ptr) = write_native_array(tester, rng, z); - let as_d = 4usize; - let as_e = 4usize; - let address1 = gen_pointer(&mut rng, 4); - let address2 = gen_pointer(&mut rng, 4); - let result_address = gen_pointer(&mut rng, 4); + let x_ptr = gen_pointer(rng, EXT_DEG); - let operand1 = array::from_fn(|_| rng.gen::()); - let operand2 = array::from_fn(|_| rng.gen::()); + tester.execute( + chip, + &Instruction::from_usize( + opcode.global_opcode(), + [ + x_ptr, + y_ptr, + z_ptr, + AS::Native as usize, + AS::Native as usize, + ], + ), + ); - assert!(address1.abs_diff(address2) >= 4); + let result = tester.read::(AS::Native as usize, x_ptr); + let expected = run_field_extension(opcode, y_val, z_val); + assert_eq!(result, expected); +} - tester.write(as_d, address1, operand1); - tester.write(as_e, address2, operand2); +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// - let result = FieldExtension::solve(opcode, operand1, operand2).unwrap(); +#[test_case(FieldExtensionOpcode::FE4ADD, 100)] +#[test_case(FieldExtensionOpcode::FE4SUB, 100)] +#[test_case(FieldExtensionOpcode::BBE4MUL, 100)] +#[test_case(FieldExtensionOpcode::BBE4DIV, 100)] +fn rand_field_extension_test(opcode: FieldExtensionOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); - tester.execute( - &mut chip, - &Instruction::from_usize( - opcode.global_opcode(), - [result_address, address1, address2, as_d, as_e], - ), - ); - assert_eq!(result, tester.read(as_d, result_address)); + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None); } - // positive test - let mut tester = tester.build().load(chip).finalize(); + let tester = tester.build().load(chip).finalize(); tester.simple_test().expect("Verification failed"); +} - disable_debug_builder(); - // negative test pranking each IO value - for height in [0, num_ops - 1] { - // TODO: better way to modify existing traces in tester - let extension_trace = tester.air_proof_inputs[2] - .1 - .raw - .common_main - .as_mut() - .unwrap(); - let original_trace = extension_trace.clone(); - for width in 0..trace_width { - let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); - extension_trace.row_mut(height)[width] = prank_value; +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[derive(Clone, Copy, Default)] +struct FieldExtensionPrankValues { + pub x: Option<[F; EXT_DEG]>, + pub y: Option<[F; EXT_DEG]>, + pub z: Option<[F; EXT_DEG]>, + pub opcode_flags: Option<[bool; 4]>, + pub divisor_inv: Option<[F; EXT_DEG]>, +} + +fn run_negative_field_extension_test( + opcode: FieldExtensionOpcode, + y: Option<[F; EXT_DEG]>, + z: Option<[F; EXT_DEG]>, + prank_vals: FieldExtensionPrankValues, + error: VerificationError, +) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, y, z); + + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let core_cols: &mut FieldExtensionCoreCols = + values.split_at_mut(adapter_width).1.borrow_mut(); + + if let Some(x) = prank_vals.x { + core_cols.x = x; + } + if let Some(y) = prank_vals.y { + core_cols.y = y; + } + if let Some(z) = prank_vals.z { + core_cols.z = z; + } + if let Some(opcode_flags) = prank_vals.opcode_flags { + [ + core_cols.is_add, + core_cols.is_sub, + core_cols.is_mul, + core_cols.is_div, + ] = opcode_flags.map(F::from_bool); + } + if let Some(divisor_inv) = prank_vals.divisor_inv { + core_cols.divisor_inv = divisor_inv; } - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" - ); - tester.air_proof_inputs[2].1.raw.common_main = Some(original_trace); - } + *trace = RowMajorMatrix::new(values, trace.width()); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); +} + +#[test] +fn rand_negative_field_extension_test() { + let mut rng = create_seeded_rng(); + run_negative_field_extension_test( + FieldExtensionOpcode::FE4ADD, + None, + None, + FieldExtensionPrankValues { + x: Some(array::from_fn(|_| rng.gen::())), + y: Some(array::from_fn(|_| rng.gen::())), + z: Some(array::from_fn(|_| rng.gen::())), + opcode_flags: Some(array::from_fn(|_| rng.gen_bool(0.5))), + divisor_inv: Some(array::from_fn(|_| rng.gen::())), + }, + VerificationError::OodEvaluationMismatch, + ); +} + +#[test] +fn field_extension_negative_tests() { + run_negative_field_extension_test( + FieldExtensionOpcode::BBE4DIV, + None, + None, + FieldExtensionPrankValues { + z: Some([F::ZERO; EXT_DEG]), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_field_extension_test( + FieldExtensionOpcode::BBE4DIV, + None, + None, + FieldExtensionPrankValues { + divisor_inv: Some([F::ZERO; EXT_DEG]), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_field_extension_test( + FieldExtensionOpcode::BBE4MUL, + Some([F::ZERO; EXT_DEG]), + None, + FieldExtensionPrankValues { + z: Some([F::ZERO; EXT_DEG]), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); } #[test] diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 7dbc3fd851..60e999cde3 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -2,38 +2,41 @@ use core::ops::Deref; use std::{ borrow::{Borrow, BorrowMut}, mem::offset_of, - sync::{Arc, Mutex}, }; -use itertools::{zip_eq, Itertools}; +use itertools::zip_eq; use openvm_circuit::{ arch::{ - ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, Streams, + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, CustomBorrow, E2PreCompute, ExecuteFunc, ExecutionBridge, + ExecutionState, MatrixRecordArena, MultiRowLayout, MultiRowMetadata, NewVmChipWrapper, + RecordArena, Result, SizedRecord, StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, + VmSegmentState, VmStateMut, }, system::{ memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteAuxRecord, + }, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::{memory_read_native, tracing_read_native, tracing_write_native}, }, }; -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{conversion::AS, FriOpcode::FRI_REDUCED_OPENING}; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, }; -use serde::{Deserialize, Serialize}; use static_assertions::const_assert_eq; use crate::{ @@ -219,8 +222,8 @@ const INSTRUCTION_READS: usize = 5; /// it starts with a Workload row (T1) and ends with either a Disabled or Instruction2 row (T7). /// The other transition constraints then ensure the proper state transitions from Workload to /// Instruction2. -#[derive(Copy, Clone, Debug)] -struct FriReducedOpeningAir { +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct FriReducedOpeningAir { execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge, } @@ -544,355 +547,742 @@ fn elem_to_ext(elem: F) -> [F; EXT_DEG] { ret } -#[derive(Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct FriReducedOpeningRecord { - pub pc: F, - pub start_timestamp: F, - pub instruction: Instruction, - pub alpha_read: RecordId, - pub length_read: RecordId, - pub a_ptr_read: RecordId, - pub is_init_read: RecordId, - pub b_ptr_read: RecordId, - pub a_rws: Vec, - pub b_reads: Vec, - pub result_write: RecordId, -} - -impl FriReducedOpeningRecord { - pub fn get_height(&self) -> usize { - // 2 for instruction rows - self.a_rws.len() + 2 +#[derive(Copy, Clone, Debug)] +pub struct FriReducedOpeningMetadata { + length: usize, + is_init: bool, +} + +impl MultiRowMetadata for FriReducedOpeningMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + // Allocates `length` workload rows + 1 Instruction1 row + 1 Instruction2 row + self.length + 2 } } -pub struct FriReducedOpeningChip { - air: FriReducedOpeningAir, - pub records: Vec>, - pub height: usize, - offline_memory: Arc>>, - streams: Arc>>, -} -impl FriReducedOpeningChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - offline_memory: Arc>>, - streams: Arc>>, - ) -> Self { - let air = FriReducedOpeningAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, +type FriReducedOpeningLayout = MultiRowLayout; + +// Header of record that is common for all trace rows for an instruction +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct FriReducedOpeningHeaderRecord { + pub length: u32, + pub is_init: bool, +} + +// Part of record that is common for all trace rows for an instruction +// NOTE: Order for fields is important here to prevent overwriting. +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct FriReducedOpeningCommonRecord { + pub timestamp: u32, + + pub a_ptr: u32, + + pub b_ptr: u32, + + pub alpha: [F; EXT_DEG], + + pub from_pc: u32, + + pub a_ptr_ptr: F, + pub a_ptr_aux: MemoryReadAuxRecord, + + pub b_ptr_ptr: F, + pub b_ptr_aux: MemoryReadAuxRecord, + + pub length_ptr: F, + pub length_aux: MemoryReadAuxRecord, + + pub alpha_ptr: F, + pub alpha_aux: MemoryReadAuxRecord, + + pub result_ptr: F, + pub result_aux: MemoryWriteAuxRecord, + + pub hint_id_ptr: F, + + pub is_init_ptr: F, + pub is_init_aux: MemoryReadAuxRecord, +} + +// Part of record for each workload row that calculates the partial `result` +// NOTE: Order for fields is important here to prevent overwriting. +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct FriReducedOpeningWorkloadRowRecord { + pub a: F, + pub a_aux: MemoryReadAuxRecord, + // The result of this workload row + // b can be computed from a, alpha, result, and previous result: + // b = result + a - prev_result * alpha + pub result: [F; EXT_DEG], + pub b_aux: MemoryReadAuxRecord, +} + +// NOTE: Order for fields is important here to prevent overwriting. +#[derive(Debug)] +pub struct FriReducedOpeningRecordMut<'a, F> { + pub header: &'a mut FriReducedOpeningHeaderRecord, + pub workload: &'a mut [FriReducedOpeningWorkloadRowRecord], + // if is_init this will be an empty slice, otherwise it will be the previous data of writing + // `a`s + pub a_write_prev_data: &'a mut [F], + pub common: &'a mut FriReducedOpeningCommonRecord, +} + +impl<'a, F> CustomBorrow<'a, FriReducedOpeningRecordMut<'a, F>, FriReducedOpeningLayout> + for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: FriReducedOpeningLayout, + ) -> FriReducedOpeningRecordMut<'a, F> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let header: &mut FriReducedOpeningHeaderRecord = header_buf.borrow_mut(); + + let workload_size = + layout.metadata.length * size_of::>(); + + let (workload_buf, rest) = unsafe { rest.split_at_mut_unchecked(workload_size) }; + let a_prev_size = if layout.metadata.is_init { + 0 + } else { + layout.metadata.length * size_of::() }; + + let (a_prev_buf, common_buf) = unsafe { rest.split_at_mut_unchecked(a_prev_size) }; + + let (_, a_prev_records, _) = unsafe { a_prev_buf.align_to_mut::() }; + let (_, workload_records, _) = + unsafe { workload_buf.align_to_mut::>() }; + + let common: &mut FriReducedOpeningCommonRecord = common_buf.borrow_mut(); + + FriReducedOpeningRecordMut { + header, + workload: &mut workload_records[..layout.metadata.length], + a_write_prev_data: &mut a_prev_records[..], + common, + } + } + + unsafe fn extract_layout(&self) -> FriReducedOpeningLayout { + let header: &FriReducedOpeningHeaderRecord = self.borrow(); + FriReducedOpeningLayout::new(FriReducedOpeningMetadata { + length: header.length as usize, + is_init: header.is_init, + }) + } +} + +impl SizedRecord for FriReducedOpeningRecordMut<'_, F> { + fn size(layout: &FriReducedOpeningLayout) -> usize { + let mut total_len = size_of::(); + total_len += layout.metadata.length * size_of::>(); + total_len += size_of::>(); + total_len + } + + fn alignment(_layout: &FriReducedOpeningLayout) -> usize { + align_of::() + } +} + +pub struct FriReducedOpeningStep { + phantom: std::marker::PhantomData, +} + +impl Default for FriReducedOpeningStep { + fn default() -> Self { + Self::new() + } +} + +impl FriReducedOpeningStep { + pub fn new() -> Self { Self { - records: vec![], - air, - height: 0, - offline_memory, - streams, + phantom: std::marker::PhantomData, } } } -impl InstructionExecutor for FriReducedOpeningChip { - fn execute( + +impl TraceStep for FriReducedOpeningStep +where + F: PrimeField32, +{ + type RecordLayout = FriReducedOpeningLayout; + type RecordMut<'a> = FriReducedOpeningRecordMut<'a, F>; + + fn get_opcode_name(&self, opcode: usize) -> String { + assert_eq!(opcode, FRI_REDUCED_OPENING.global_opcode().as_usize()); + String::from("FRI_REDUCED_OPENING") + } + + fn execute<'buf, RA>( &mut self, - memory: &mut MemoryController, + state: VmStateMut, CTX>, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { let &Instruction { - a: a_ptr_ptr, - b: b_ptr_ptr, - c: length_ptr, - d: alpha_ptr, - e: result_ptr, - f: hint_id_ptr, - g: is_init_ptr, + a, + b, + c, + d, + e, + f, + g, .. } = instruction; - let addr_space = F::from_canonical_u32(AS::Native as u32); - let alpha_read = memory.read(addr_space, alpha_ptr); - let length_read = memory.read_cell(addr_space, length_ptr); - let a_ptr_read = memory.read_cell(addr_space, a_ptr_ptr); - let b_ptr_read = memory.read_cell(addr_space, b_ptr_ptr); - let is_init_read = memory.read_cell(addr_space, is_init_ptr); - let is_init = is_init_read.1.as_canonical_u32(); + let timestamp_start = state.memory.timestamp; + + // Read length from memory to allocate record + let length_ptr = c.as_canonical_u32(); + let [length]: [F; 1] = memory_read_native(&state.memory.data, length_ptr); + let length = length.as_canonical_u32(); + let is_init_ptr = g.as_canonical_u32(); + let [is_init]: [F; 1] = memory_read_native(&state.memory.data, is_init_ptr); + let is_init = is_init != F::ZERO; - let hint_id_f = memory.unsafe_read_cell(addr_space, hint_id_ptr); - let hint_id = hint_id_f.as_canonical_u32() as usize; + let metadata = FriReducedOpeningMetadata { + length: length as usize, + is_init, + }; + let record = arena.alloc(MultiRowLayout::new(metadata)); - let alpha = alpha_read.1; - let length = length_read.1.as_canonical_u32() as usize; - let a_ptr = a_ptr_read.1; - let b_ptr = b_ptr_read.1; + record.common.from_pc = *state.pc; + record.common.timestamp = timestamp_start; - let mut a_rws = Vec::with_capacity(length); - let mut b_reads = Vec::with_capacity(length); - let mut result = [F::ZERO; EXT_DEG]; + let alpha_ptr = d.as_canonical_u32(); + let alpha = tracing_read_native( + state.memory, + alpha_ptr, + &mut record.common.alpha_aux.prev_timestamp, + ); + record.common.alpha_ptr = d; + record.common.alpha = alpha; + + tracing_read_native::( + state.memory, + length_ptr, + &mut record.common.length_aux.prev_timestamp, + ); + record.common.length_ptr = c; + record.header.length = length; + + let a_ptr_ptr = a.as_canonical_u32(); + let [a_ptr]: [F; 1] = tracing_read_native( + state.memory, + a_ptr_ptr, + &mut record.common.a_ptr_aux.prev_timestamp, + ); + record.common.a_ptr_ptr = a; + record.common.a_ptr = a_ptr.as_canonical_u32(); + + let b_ptr_ptr = b.as_canonical_u32(); + let [b_ptr]: [F; 1] = tracing_read_native( + state.memory, + b_ptr_ptr, + &mut record.common.b_ptr_aux.prev_timestamp, + ); + record.common.b_ptr_ptr = b; + record.common.b_ptr = b_ptr.as_canonical_u32(); - let data = if is_init == 0 { - let mut streams = self.streams.lock().unwrap(); - let hint_steam = &mut streams.hint_space[hint_id]; + tracing_read_native::( + state.memory, + is_init_ptr, + &mut record.common.is_init_aux.prev_timestamp, + ); + record.common.is_init_ptr = g; + record.header.is_init = is_init; + + let hint_id_ptr = f.as_canonical_u32(); + let [hint_id]: [F; 1] = memory_read_native(state.memory.data(), hint_id_ptr); + let hint_id = hint_id.as_canonical_u32() as usize; + record.common.hint_id_ptr = f; + + let length = length as usize; + + let data = if !is_init { + let hint_steam = &mut state.streams.hint_space[hint_id]; hint_steam.drain(0..length).collect() } else { vec![] }; + + let mut as_and_bs = Vec::with_capacity(length); #[allow(clippy::needless_range_loop)] for i in 0..length { - let a_rw = if is_init == 0 { - let (record_id, _) = - memory.write_cell(addr_space, a_ptr + F::from_canonical_usize(i), data[i]); - (record_id, data[i]) + let workload_row = &mut record.workload[length - i - 1]; + + let a_ptr_i = record.common.a_ptr + i as u32; + let [a]: [F; 1] = if !is_init { + let mut prev = [F::ZERO; 1]; + tracing_write_native( + state.memory, + a_ptr_i, + [data[i]], + &mut workload_row.a_aux.prev_timestamp, + &mut prev, + ); + record.a_write_prev_data[length - i - 1] = prev[0]; + [data[i]] } else { - memory.read_cell(addr_space, a_ptr + F::from_canonical_usize(i)) + tracing_read_native( + state.memory, + a_ptr_i, + &mut workload_row.a_aux.prev_timestamp, + ) }; - let b_read = - memory.read::(addr_space, b_ptr + F::from_canonical_usize(EXT_DEG * i)); - a_rws.push(a_rw); - b_reads.push(b_read); + let b_ptr_i = record.common.b_ptr + (EXT_DEG * i) as u32; + let b = tracing_read_native::( + state.memory, + b_ptr_i, + &mut workload_row.b_aux.prev_timestamp, + ); + + as_and_bs.push((a, b)); } - for (a_rw, b_read) in a_rws.iter().rev().zip_eq(b_reads.iter().rev()) { - let a = a_rw.1; - let b = b_read.1; + let mut result = [F::ZERO; EXT_DEG]; + for (i, (a, b)) in as_and_bs.into_iter().rev().enumerate() { + let workload_row = &mut record.workload[i]; + // result = result * alpha + (b - a) result = FieldExtension::add( FieldExtension::multiply(result, alpha), FieldExtension::subtract(b, elem_to_ext(a)), ); + workload_row.a = a; + workload_row.result = result; } - let (result_write, _) = memory.write(addr_space, result_ptr, result); - - let record = FriReducedOpeningRecord { - pc: F::from_canonical_u32(from_state.pc), - start_timestamp: F::from_canonical_u32(from_state.timestamp), - instruction: instruction.clone(), - alpha_read: alpha_read.0, - length_read: length_read.0, - a_ptr_read: a_ptr_read.0, - is_init_read: is_init_read.0, - b_ptr_read: b_ptr_read.0, - a_rws: a_rws.into_iter().map(|r| r.0).collect(), - b_reads: b_reads.into_iter().map(|r| r.0).collect(), - result_write, - }; - self.height += record.get_height(); - self.records.push(record); + let result_ptr = e.as_canonical_u32(); + tracing_write_native( + state.memory, + result_ptr, + result, + &mut record.common.result_aux.prev_timestamp, + &mut record.common.result_aux.prev_data, + ); + record.common.result_ptr = e; - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn get_opcode_name(&self, opcode: usize) -> String { - assert_eq!(opcode, FRI_REDUCED_OPENING.global_opcode().as_usize()); - String::from("FRI_REDUCED_OPENING") + Ok(()) } } -fn record_to_rows( - record: FriReducedOpeningRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, -) { - let Instruction { - a: a_ptr_ptr, - b: b_ptr_ptr, - c: length_ptr, - d: alpha_ptr, - e: result_ptr, - f: hint_id_ptr, - g: is_init_ptr, - .. - } = record.instruction; - - let length_read = memory.record_by_id(record.length_read); - let alpha_read = memory.record_by_id(record.alpha_read); - let a_ptr_read = memory.record_by_id(record.a_ptr_read); - let b_ptr_read = memory.record_by_id(record.b_ptr_read); - let is_init_read = memory.record_by_id(record.is_init_read); - let is_init = is_init_read.data_at(0); - let write_a = F::ONE - is_init; - - let length = length_read.data_at(0).as_canonical_u32() as usize; - let alpha: [F; EXT_DEG] = alpha_read.data_slice().try_into().unwrap(); - let a_ptr = a_ptr_read.data_at(0); - let b_ptr = b_ptr_read.data_at(0); +impl TraceFiller for FriReducedOpeningStep +where + F: PrimeField32, +{ + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + debug_assert_eq!(trace.width, OVERALL_WIDTH); - let mut result = [F::ZERO; EXT_DEG]; + let mut remaining_trace = &mut trace.values[..OVERALL_WIDTH * rows_used]; + let mut chunks = Vec::with_capacity(rows_used); + while !remaining_trace.is_empty() { + let header: &FriReducedOpeningHeaderRecord = + unsafe { get_record_from_slice(&mut remaining_trace, ()) }; + let num_rows = header.length as usize + 2; + let chunk_size = OVERALL_WIDTH * num_rows; + let (chunk, rest) = remaining_trace.split_at_mut(chunk_size); + chunks.push((chunk, header.is_init)); + remaining_trace = rest; + } - let alpha_aux = aux_cols_factory.make_read_aux_cols(alpha_read); - let length_aux = aux_cols_factory.make_read_aux_cols(length_read); - let a_ptr_aux = aux_cols_factory.make_read_aux_cols(a_ptr_read); - let b_ptr_aux = aux_cols_factory.make_read_aux_cols(b_ptr_read); - let is_init_aux = aux_cols_factory.make_read_aux_cols(is_init_read); - - let result_aux = aux_cols_factory.make_write_aux_cols(memory.record_by_id(record.result_write)); - - // WorkloadCols - for (i, (&a_record_id, &b_record_id)) in record - .a_rws - .iter() - .rev() - .zip_eq(record.b_reads.iter().rev()) - .enumerate() - { - let a_rw = memory.record_by_id(a_record_id); - let b_read = memory.record_by_id(b_record_id); - let a = a_rw.data_at(0); - let b: [F; EXT_DEG] = b_read.data_slice().try_into().unwrap(); - - let start = i * OVERALL_WIDTH; - let cols: &mut WorkloadCols = slice[start..start + WL_WIDTH].borrow_mut(); - *cols = WorkloadCols { - prefix: PrefixCols { - general: GeneralCols { - is_workload_row: F::ONE, - is_ins_row: F::ZERO, - timestamp: record.start_timestamp + F::from_canonical_usize((length - i) * 2), - }, - a_or_is_first: a, - data: DataCols { - a_ptr: a_ptr + F::from_canonical_usize(length - i), - write_a, - b_ptr: b_ptr + F::from_canonical_usize((length - i) * EXT_DEG), - idx: F::from_canonical_usize(i), - result, - alpha, - }, - }, - // Generate write aux columns no matter `a` is read or written. When `a` is written, - // `prev_data` is not constrained. - a_aux: if a_rw.prev_data_slice().is_some() { - aux_cols_factory.make_write_aux_cols(a_rw) + chunks.into_par_iter().for_each(|(mut chunk, is_init)| { + let num_rows = chunk.len() / OVERALL_WIDTH; + let metadata = FriReducedOpeningMetadata { + length: num_rows - 2, + is_init, + }; + let record: FriReducedOpeningRecordMut = + unsafe { get_record_from_slice(&mut chunk, MultiRowLayout::new(metadata)) }; + + let timestamp = record.common.timestamp; + let length = record.header.length as usize; + let alpha = record.common.alpha; + let is_init = record.header.is_init; + let write_a = F::from_bool(!is_init); + + let a_ptr = record.common.a_ptr; + let b_ptr = record.common.b_ptr; + + let (workload_chunk, rest) = chunk.split_at_mut(length * OVERALL_WIDTH); + let (ins1_chunk, ins2_chunk) = rest.split_at_mut(OVERALL_WIDTH); + + { + // ins2 row + let cols: &mut Instruction2Cols = ins2_chunk[..INS_2_WIDTH].borrow_mut(); + + cols.write_a_x_is_first = F::ZERO; + + mem_helper.fill( + record.common.is_init_aux.prev_timestamp, + timestamp + 4, + cols.is_init_aux.as_mut(), + ); + cols.is_init_ptr = record.common.is_init_ptr; + + cols.hint_id_ptr = record.common.hint_id_ptr; + + cols.result_aux + .set_prev_data(record.common.result_aux.prev_data); + mem_helper.fill( + record.common.result_aux.prev_timestamp, + timestamp + 5 + 2 * length as u32, + cols.result_aux.as_mut(), + ); + cols.result_ptr = record.common.result_ptr; + + mem_helper.fill( + record.common.alpha_aux.prev_timestamp, + timestamp, + cols.alpha_aux.as_mut(), + ); + cols.alpha_ptr = record.common.alpha_ptr; + + mem_helper.fill( + record.common.length_aux.prev_timestamp, + timestamp + 1, + cols.length_aux.as_mut(), + ); + cols.length_ptr = record.common.length_ptr; + + cols.is_first = F::ZERO; + + cols.general.timestamp = F::from_canonical_u32(timestamp); + cols.general.is_ins_row = F::ONE; + cols.general.is_workload_row = F::ZERO; + + ins2_chunk[INS_2_WIDTH..OVERALL_WIDTH].fill(F::ZERO); + } + + { + // ins 1 row + let cols: &mut Instruction1Cols = ins1_chunk[..INS_1_WIDTH].borrow_mut(); + + cols.write_a_x_is_first = write_a; + + mem_helper.fill( + record.common.b_ptr_aux.prev_timestamp, + timestamp + 3, + cols.b_ptr_aux.as_mut(), + ); + cols.b_ptr_ptr = record.common.b_ptr_ptr; + + mem_helper.fill( + record.common.a_ptr_aux.prev_timestamp, + timestamp + 2, + cols.a_ptr_aux.as_mut(), + ); + cols.a_ptr_ptr = record.common.a_ptr_ptr; + + cols.pc = F::from_canonical_u32(record.common.from_pc); + + cols.prefix.data.alpha = alpha; + cols.prefix.data.result = record.workload.last().unwrap().result; + cols.prefix.data.idx = F::from_canonical_usize(length); + cols.prefix.data.b_ptr = F::from_canonical_u32(b_ptr); + cols.prefix.data.write_a = write_a; + cols.prefix.data.a_ptr = F::from_canonical_u32(a_ptr); + + cols.prefix.a_or_is_first = F::ONE; + + cols.prefix.general.timestamp = F::from_canonical_u32(timestamp); + cols.prefix.general.is_ins_row = F::ONE; + cols.prefix.general.is_workload_row = F::ZERO; + ins1_chunk[INS_1_WIDTH..OVERALL_WIDTH].fill(F::ZERO); + } + + // To fill the WorkloadRows we do 2 passes: + // - First, a serial pass to fill some of the records into the trace + // - Then, a parallel pass to fill the rest of the records into the trace + // Note, the first pass is done to avoid overwriting the records + + // Copy of `a_write_prev_data` to avoid overwriting it and to use it in the parallel + // pass + let a_prev_data = if !is_init { + let mut tmp = Vec::with_capacity(length); + tmp.extend_from_slice(record.a_write_prev_data); + tmp } else { - let read_aux = aux_cols_factory.make_read_aux_cols(a_rw); - MemoryWriteAuxCols::from_base(read_aux.get_base(), [F::ZERO]) - }, - b, - b_aux: aux_cols_factory.make_read_aux_cols(b_read), - }; - // result = result * alpha + (b - a) - result = FieldExtension::add( - FieldExtension::multiply(result, alpha), - FieldExtension::subtract(b, elem_to_ext(a)), - ); + vec![] + }; + + for (i, (workload_row, row_chunk)) in record + .workload + .iter() + .zip(workload_chunk.chunks_exact_mut(OVERALL_WIDTH)) + .enumerate() + .rev() + { + let cols: &mut WorkloadCols = row_chunk[..WL_WIDTH].borrow_mut(); + + let timestamp = timestamp + ((length - i) * 2) as u32; + + // fill in reverse order + mem_helper.fill( + workload_row.b_aux.prev_timestamp, + timestamp + 4, + cols.b_aux.as_mut(), + ); + + // We temporarily store the result here + // the correct value of b is computed during the serial pass below + cols.b = record.workload[i].result; + + mem_helper.fill( + workload_row.a_aux.prev_timestamp, + timestamp + 3, + cols.a_aux.as_mut(), + ); + cols.prefix.a_or_is_first = workload_row.a; + + if i > 0 { + cols.prefix.data.result = record.workload[i - 1].result; + } + } + + workload_chunk + .par_chunks_exact_mut(OVERALL_WIDTH) + .enumerate() + .for_each(|(i, row_chunk)| { + let cols: &mut WorkloadCols = row_chunk[..WL_WIDTH].borrow_mut(); + let timestamp = timestamp + ((length - i) * 2) as u32; + if is_init { + cols.a_aux.set_prev_data([F::ZERO; 1]); + } else { + cols.a_aux.set_prev_data([a_prev_data[i]]); + } + + // DataCols + cols.prefix.data.a_ptr = F::from_canonical_u32(a_ptr + (length - i) as u32); + cols.prefix.data.write_a = write_a; + cols.prefix.data.b_ptr = + F::from_canonical_u32(b_ptr + ((length - i) * EXT_DEG) as u32); + cols.prefix.data.idx = F::from_canonical_usize(i); + if i == 0 { + cols.prefix.data.result = [F::ZERO; EXT_DEG]; + } + cols.prefix.data.alpha = alpha; + + // GeneralCols + cols.prefix.general.is_workload_row = F::ONE; + cols.prefix.general.is_ins_row = F::ZERO; + + // WorkloadCols + cols.prefix.general.timestamp = F::from_canonical_u32(timestamp); + + cols.b = FieldExtension::subtract( + FieldExtension::add(cols.b, elem_to_ext(cols.prefix.a_or_is_first)), + FieldExtension::multiply(cols.prefix.data.result, alpha), + ); + row_chunk[WL_WIDTH..OVERALL_WIDTH].fill(F::ZERO); + }); + }); } - // Instruction1Cols - { - let start = length * OVERALL_WIDTH; - let cols: &mut Instruction1Cols = slice[start..start + INS_1_WIDTH].borrow_mut(); - *cols = Instruction1Cols { - prefix: PrefixCols { - general: GeneralCols { - is_workload_row: F::ZERO, - is_ins_row: F::ONE, - timestamp: record.start_timestamp, - }, - a_or_is_first: F::ONE, - data: DataCols { - a_ptr, - write_a, - b_ptr, - idx: F::from_canonical_usize(length), - result, - alpha, - }, - }, - pc: record.pc, +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct FriReducedOpeningPreCompute { + a_ptr_ptr: u32, + b_ptr_ptr: u32, + length_ptr: u32, + alpha_ptr: u32, + result_ptr: u32, + hint_id_ptr: u32, + is_init_ptr: u32, +} + +impl FriReducedOpeningStep { + #[inline(always)] + fn pre_compute_impl( + &self, + _pc: u32, + inst: &Instruction, + data: &mut FriReducedOpeningPreCompute, + ) -> Result<()> { + let &Instruction { + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + let a_ptr_ptr = a.as_canonical_u32(); + let b_ptr_ptr = b.as_canonical_u32(); + let length_ptr = c.as_canonical_u32(); + let alpha_ptr = d.as_canonical_u32(); + let result_ptr = e.as_canonical_u32(); + let hint_id_ptr = f.as_canonical_u32(); + let is_init_ptr = g.as_canonical_u32(); + + *data = FriReducedOpeningPreCompute { a_ptr_ptr, - a_ptr_aux, b_ptr_ptr, - b_ptr_aux, - write_a_x_is_first: write_a, - }; - } - // Instruction2Cols - { - let start = (length + 1) * OVERALL_WIDTH; - let cols: &mut Instruction2Cols = slice[start..start + INS_2_WIDTH].borrow_mut(); - *cols = Instruction2Cols { - general: GeneralCols { - is_workload_row: F::ZERO, - is_ins_row: F::ONE, - timestamp: record.start_timestamp, - }, - is_first: F::ZERO, length_ptr, - length_aux, alpha_ptr, - alpha_aux, result_ptr, - result_aux, hint_id_ptr, is_init_ptr, - is_init_aux, - write_a_x_is_first: F::ZERO, }; + + Ok(()) } } -impl ChipUsageGetter for FriReducedOpeningChip { - fn air_name(&self) -> String { - "FriReducedOpeningAir".to_string() +impl StepExecutorE1 for FriReducedOpeningStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn current_trace_height(&self) -> usize { - self.height - } + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut FriReducedOpeningPreCompute = data.borrow_mut(); - fn trace_width(&self) -> usize { - OVERALL_WIDTH + self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = execute_e1_impl; + Ok(fn_ptr) } } -impl Chip for FriReducedOpeningChip> +impl StepExecutorE2 for FriReducedOpeningStep where - Val: PrimeField32, + F: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air) + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_air_proof_input(self) -> AirProofInput { - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = Val::::zero_vec(OVERALL_WIDTH * height); - let chunked_trace = { - let sizes: Vec<_> = self - .records - .par_iter() - .map(|record| OVERALL_WIDTH * record.get_height()) - .collect(); - variable_chunks_mut(&mut flat_trace, &sizes) - }; - let memory = self.offline_memory.lock().unwrap(); - let aux_cols_factory = memory.aux_cols_factory(); + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; - self.records - .into_par_iter() - .zip_eq(chunked_trace.into_par_iter()) - .for_each(|(record, slice)| { - record_to_rows(record, &aux_cols_factory, slice, &memory); - }); + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH); - AirProofInput::simple_no_pis(matrix) + let fn_ptr = execute_e2_impl; + Ok(fn_ptr) } } -fn variable_chunks_mut<'a, T>(mut slice: &'a mut [T], sizes: &[usize]) -> Vec<&'a mut [T]> { - let mut result = Vec::with_capacity(sizes.len()); - for &size in sizes { - // split_at_mut guarantees disjoint slices - let (left, right) = slice.split_at_mut(size); - result.push(left); - slice = right; // move forward for the next chunk +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &FriReducedOpeningPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &FriReducedOpeningPreCompute, + vm_state: &mut VmSegmentState, +) -> u32 { + let alpha = vm_state.vm_read(AS::Native as u32, pre_compute.alpha_ptr); + + let [length]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.length_ptr); + let length = length.as_canonical_u32() as usize; + + let [a_ptr]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.a_ptr_ptr); + let [b_ptr]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.b_ptr_ptr); + + let [is_init_read]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.is_init_ptr); + let is_init = is_init_read.as_canonical_u32(); + + let [hint_id_f]: [F; 1] = vm_state.host_read(AS::Native as u32, pre_compute.hint_id_ptr); + let hint_id = hint_id_f.as_canonical_u32() as usize; + + let data = if is_init == 0 { + let hint_steam = &mut vm_state.streams.hint_space[hint_id]; + hint_steam.drain(0..length).collect() + } else { + vec![] + }; + + let mut as_and_bs = Vec::with_capacity(length); + #[allow(clippy::needless_range_loop)] + for i in 0..length { + let a_ptr_i = (a_ptr + F::from_canonical_usize(i)).as_canonical_u32(); + let [a]: [F; 1] = if is_init == 0 { + vm_state.vm_write(AS::Native as u32, a_ptr_i, &[data[i]]); + [data[i]] + } else { + vm_state.vm_read(AS::Native as u32, a_ptr_i) + }; + let b_ptr_i = (b_ptr + F::from_canonical_usize(EXT_DEG * i)).as_canonical_u32(); + let b = vm_state.vm_read(AS::Native as u32, b_ptr_i); + + as_and_bs.push((a, b)); } - result + + let mut result = [F::ZERO; EXT_DEG]; + for (a, b) in as_and_bs.into_iter().rev() { + // result = result * alpha + (b - a) + result = FieldExtension::add( + FieldExtension::multiply(result, alpha), + FieldExtension::subtract(b, elem_to_ext(a)), + ); + } + + vm_state.vm_write(AS::Native as u32, pre_compute.result_ptr, &result); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + length as u32 + 2 } + +pub type FriReducedOpeningChip = + NewVmChipWrapper, MatrixRecordArena>; diff --git a/extensions/native/circuit/src/fri/tests.rs b/extensions/native/circuit/src/fri/tests.rs index 97dcdbc532..905bbe6af2 100644 --- a/extensions/native/circuit/src/fri/tests.rs +++ b/extensions/native/circuit/src/fri/tests.rs @@ -1,22 +1,43 @@ -use std::sync::{Arc, Mutex}; +use std::borrow::BorrowMut; use itertools::Itertools; -use openvm_circuit::arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder}, - Streams, -}; +use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::FriOpcode::FRI_REDUCED_OPENING; +use openvm_native_compiler::{conversion::AS, FriOpcode::FRI_REDUCED_OPENING}; use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, verifier::VerificationError, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; + +use super::{ + super::field_extension::FieldExtension, elem_to_ext, FriReducedOpeningAir, + FriReducedOpeningChip, FriReducedOpeningStep, EXT_DEG, +}; +use crate::{ + fri::{WorkloadCols, OVERALL_WIDTH, WL_WIDTH}, + write_native_array, +}; + +const MAX_INS_CAPACITY: usize = 1024; +type F = BabyBear; -use super::{super::field_extension::FieldExtension, elem_to_ext, FriReducedOpeningChip, EXT_DEG}; -use crate::OVERALL_WIDTH; +fn create_test_chip(tester: &VmChipTestBuilder) -> FriReducedOpeningChip { + let mut chip = FriReducedOpeningChip::::new( + FriReducedOpeningAir::new(tester.execution_bridge(), tester.memory_bridge()), + FriReducedOpeningStep::new(), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + chip +} fn compute_fri_mat_opening( alpha: [F; EXT_DEG], @@ -35,146 +56,115 @@ fn compute_fri_mat_opening( result } -#[test] -fn fri_mat_opening_air_test() { - let num_ops = 14; // non-power-of-2 to also test padding - let elem_range = || 1..=100; - let length_range = || 1..=49; - - let mut tester = VmChipTestBuilder::default(); - - let streams = Arc::new(Mutex::new(Streams::default())); - let mut chip = FriReducedOpeningChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - streams.clone(), +fn set_and_execute( + tester: &mut VmChipTestBuilder, + chip: &mut FriReducedOpeningChip, + rng: &mut StdRng, +) { + let len = rng.gen_range(1..=28); + let a_ptr = gen_pointer(rng, len); + let b_ptr = gen_pointer(rng, len); + let a_ptr_ptr = + write_native_array::(tester, rng, Some([F::from_canonical_usize(a_ptr)])).1; + let b_ptr_ptr = + write_native_array::(tester, rng, Some([F::from_canonical_usize(b_ptr)])).1; + + let len_ptr = write_native_array::(tester, rng, Some([F::from_canonical_usize(len)])).1; + let (alpha, alpha_ptr) = write_native_array::(tester, rng, None); + let out_ptr = gen_pointer(rng, EXT_DEG); + let is_init = true; + let is_init_ptr = write_native_array::(tester, rng, Some([F::from_bool(is_init)])).1; + + let mut vec_a = Vec::with_capacity(len); + let mut vec_b = Vec::with_capacity(len); + for i in 0..len { + let a = rng.gen(); + let b: [F; EXT_DEG] = std::array::from_fn(|_| rng.gen()); + vec_a.push(a); + vec_b.push(b); + if !is_init { + tester.streams.hint_space[0].push(a); + } else { + tester.write(AS::Native as usize, a_ptr + i, [a]); + } + tester.write(AS::Native as usize, b_ptr + (EXT_DEG * i), b); + } + + tester.execute( + chip, + &Instruction::from_usize( + FRI_REDUCED_OPENING.global_opcode(), + [ + a_ptr_ptr, + b_ptr_ptr, + len_ptr, + alpha_ptr, + out_ptr, + 0, // hint id, will just use 0 for testing + is_init_ptr, + ], + ), ); - let mut rng = create_seeded_rng(); + let expected_result = compute_fri_mat_opening(alpha, &vec_a, &vec_b); + assert_eq!(expected_result, tester.read(AS::Native as usize, out_ptr)); - macro_rules! gen_ext { - () => { - std::array::from_fn::<_, EXT_DEG, _>(|_| { - BabyBear::from_canonical_u32(rng.gen_range(elem_range())) - }) - }; + for (i, ai) in vec_a.iter().enumerate() { + let [found] = tester.read(AS::Native as usize, a_ptr + i); + assert_eq!(*ai, found); } +} - streams.lock().unwrap().hint_space = vec![vec![]]; - - for _ in 0..num_ops { - let alpha = gen_ext!(); - let length = rng.gen_range(length_range()); - let a = (0..length) - .map(|_| BabyBear::from_canonical_u32(rng.gen_range(elem_range()))) - .collect_vec(); - let b = (0..length).map(|_| gen_ext!()).collect_vec(); - - let result = compute_fri_mat_opening(alpha, &a, &b); - - let alpha_pointer = gen_pointer(&mut rng, 4); - let length_pointer = gen_pointer(&mut rng, 1); - let a_pointer_pointer = gen_pointer(&mut rng, 1); - let b_pointer_pointer = gen_pointer(&mut rng, 1); - let result_pointer = gen_pointer(&mut rng, 4); - let a_pointer = gen_pointer(&mut rng, 1); - let b_pointer = gen_pointer(&mut rng, 4); - let is_init_ptr = gen_pointer(&mut rng, 1); - - let address_space = 4usize; - - /*tracing::debug!( - "{opcode:?} d = {}, e = {}, f = {}, result_addr = {}, addr1 = {}, addr2 = {}, z = {}, x = {}, y = {}", - result_as, as1, as2, result_pointer, address1, address2, result, operand1, operand2, - );*/ - - tester.write(address_space, alpha_pointer, alpha); - tester.write_cell( - address_space, - length_pointer, - BabyBear::from_canonical_usize(length), - ); - tester.write_cell( - address_space, - a_pointer_pointer, - BabyBear::from_canonical_usize(a_pointer), - ); - tester.write_cell( - address_space, - b_pointer_pointer, - BabyBear::from_canonical_usize(b_pointer), - ); - let is_init = rng.gen_range(0..2); - tester.write_cell( - address_space, - is_init_ptr, - BabyBear::from_canonical_u32(is_init), - ); +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// - if is_init == 0 { - streams.lock().unwrap().hint_space[0].extend_from_slice(&a); - } else { - for (i, ai) in a.iter().enumerate() { - tester.write_cell(address_space, a_pointer + i, *ai); - } - } - for (i, bi) in b.iter().enumerate() { - tester.write(address_space, b_pointer + (4 * i), *bi); - } +#[test] +fn fri_mat_opening_air_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); - tester.execute( - &mut chip, - &Instruction::from_usize( - FRI_REDUCED_OPENING.global_opcode(), - [ - a_pointer_pointer, - b_pointer_pointer, - length_pointer, - alpha_pointer, - result_pointer, - 0, // hint id - is_init_ptr, - ], - ), - ); - assert_eq!(result, tester.read(address_space, result_pointer)); - // Check that `a` was populated. - for (i, ai) in a.iter().enumerate() { - let found = tester.read_cell(address_space, a_pointer + i); - assert_eq!(*ai, found); - } + let num_ops = 28; // non-power-of-2 to also test padding + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng); } - let mut tester = tester.build().load(chip).finalize(); + let tester = tester.build().load(chip).finalize(); tester.simple_test().expect("Verification failed"); +} - disable_debug_builder(); - // negative test pranking each value - for height in 0..num_ops { - // TODO: better way to modify existing traces in tester - let trace = tester.air_proof_inputs[2] - .1 - .raw - .common_main - .as_mut() - .unwrap(); - let old_trace = trace.clone(); - for width in 0..OVERALL_WIDTH - /* num operands */ - { - let prank_value = BabyBear::from_canonical_u32(rng.gen_range(1..=100)); - trace.row_mut(height)[width] = prank_value; - } +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// - // Run a test after pranking each row - assert_eq!( - tester.simple_test().err(), - Some(VerificationError::OodEvaluationMismatch), - "Expected constraint to fail" - ); +#[test] +fn run_negative_fri_mat_opening_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); - tester.air_proof_inputs[2].1.raw.common_main = Some(old_trace); - } + set_and_execute(&mut tester, &mut chip, &mut rng); + + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut WorkloadCols = values[..WL_WIDTH].borrow_mut(); + + cols.prefix.a_or_is_first = F::from_canonical_u32(42); + + *trace = RowMajorMatrix::new(values, OVERALL_WIDTH); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); } diff --git a/extensions/native/circuit/src/jal/mod.rs b/extensions/native/circuit/src/jal/mod.rs deleted file mode 100644 index 28322834a2..0000000000 --- a/extensions/native/circuit/src/jal/mod.rs +++ /dev/null @@ -1,342 +0,0 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - ops::Deref, - sync::{Arc, Mutex}, -}; - -use openvm_circuit::{ - arch::{ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, PcIncOrSet}, - system::memory::{ - offline_checker::{MemoryBridge, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, - }, -}; -use openvm_circuit_primitives::{ - utils::next_power_of_two_or_zero, - var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip, - }, -}; -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; -use openvm_native_compiler::{conversion::AS, NativeJalOpcode, NativeRangeCheckOpcode}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - interaction::InteractionBuilder, - p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, -}; -use serde::{Deserialize, Serialize}; -use static_assertions::const_assert_eq; -use AS::Native; - -#[cfg(test)] -mod tests; - -#[repr(C)] -#[derive(AlignedBorrow)] -struct JalRangeCheckCols { - is_jal: T, - is_range_check: T, - a_pointer: T, - state: ExecutionState, - // Write when is_jal, read when is_range_check. - writes_aux: MemoryWriteAuxCols, - b: T, - // Only used by range check. - c: T, - // Only used by range check. - y: T, -} - -const OVERALL_WIDTH: usize = JalRangeCheckCols::::width(); -const_assert_eq!(OVERALL_WIDTH, 12); - -#[derive(Copy, Clone, Debug)] -pub struct JalRangeCheckAir { - execution_bridge: ExecutionBridge, - memory_bridge: MemoryBridge, - range_bus: VariableRangeCheckerBus, -} - -impl BaseAir for JalRangeCheckAir { - fn width(&self) -> usize { - OVERALL_WIDTH - } -} - -impl BaseAirWithPublicValues for JalRangeCheckAir {} -impl PartitionedBaseAir for JalRangeCheckAir {} -impl Air for JalRangeCheckAir -where - AB::F: PrimeField32, -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_slice = local.deref(); - let local: &JalRangeCheckCols = local_slice.borrow(); - builder.assert_bool(local.is_jal); - builder.assert_bool(local.is_range_check); - let is_valid = local.is_jal + local.is_range_check; - builder.assert_bool(is_valid.clone()); - - let d = AB::Expr::from_canonical_u32(Native as u32); - let a_val = local.writes_aux.prev_data()[0]; - // if is_jal, write pc + DEFAULT_PC_STEP, else if is_range_check, read a_val. - let write_val = local.is_jal - * (local.state.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_STEP)) - + local.is_range_check * a_val; - self.memory_bridge - .write( - MemoryAddress::new(d.clone(), local.a_pointer), - [write_val], - local.state.timestamp, - &local.writes_aux, - ) - .eval(builder, is_valid.clone()); - - let opcode = local.is_jal - * AB::F::from_canonical_usize(NativeJalOpcode::JAL.global_opcode().as_usize()) - + local.is_range_check - * AB::F::from_canonical_usize( - NativeRangeCheckOpcode::RANGE_CHECK - .global_opcode() - .as_usize(), - ); - // Increment pc by b if is_jal, else by DEFAULT_PC_STEP if is_range_check. - let pc_inc = local.is_jal * local.b - + local.is_range_check * AB::F::from_canonical_u32(DEFAULT_PC_STEP); - builder.when(local.is_jal).assert_zero(local.c); - self.execution_bridge - .execute_and_increment_or_set_pc( - opcode, - [local.a_pointer.into(), local.b.into(), local.c.into(), d], - local.state, - AB::F::ONE, - PcIncOrSet::Inc(pc_inc), - ) - .eval(builder, is_valid); - - // Range check specific: - // a_val = x + y * (1 << 16) - let x = a_val - local.y * AB::Expr::from_canonical_u32(1 << 16); - self.range_bus - .send(x.clone(), local.b) - .eval(builder, local.is_range_check); - // Assert y < (1 << c), where c <= 14. - self.range_bus - .send(local.y, local.c) - .eval(builder, local.is_range_check); - } -} - -impl JalRangeCheckAir { - fn new( - execution_bridge: ExecutionBridge, - memory_bridge: MemoryBridge, - range_bus: VariableRangeCheckerBus, - ) -> Self { - Self { - execution_bridge, - memory_bridge, - range_bus, - } - } -} - -#[repr(C)] -#[derive(Serialize, Deserialize)] -pub struct JalRangeCheckRecord { - pub state: ExecutionState, - pub a_rw: RecordId, - pub b: u32, - pub c: u8, - pub is_jal: bool, -} - -/// Chip for JAL and RANGE_CHECK. These opcodes are logically irrelevant. Putting these opcodes into -/// the same chip is just to save columns. -pub struct JalRangeCheckChip { - air: JalRangeCheckAir, - pub records: Vec, - offline_memory: Arc>>, - range_checker_chip: SharedVariableRangeCheckerChip, - /// If true, ignore execution errors. - debug: bool, -} - -impl JalRangeCheckChip { - pub fn new( - execution_bridge: ExecutionBridge, - offline_memory: Arc>>, - range_checker_chip: SharedVariableRangeCheckerChip, - ) -> Self { - let memory_bridge = offline_memory.lock().unwrap().memory_bridge(); - let air = JalRangeCheckAir::new(execution_bridge, memory_bridge, range_checker_chip.bus()); - Self { - air, - records: vec![], - offline_memory, - range_checker_chip, - debug: false, - } - } - pub fn with_debug(mut self) -> Self { - self.debug = true; - self - } -} - -impl InstructionExecutor for JalRangeCheckChip { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - if instruction.opcode == NativeJalOpcode::JAL.global_opcode() { - let (record_id, _) = memory.write( - F::from_canonical_u32(AS::Native as u32), - instruction.a, - [F::from_canonical_u32(from_state.pc + DEFAULT_PC_STEP)], - ); - let b = instruction.b.as_canonical_u32(); - self.records.push(JalRangeCheckRecord { - state: from_state, - a_rw: record_id, - b, - c: 0, - is_jal: true, - }); - return Ok(ExecutionState { - pc: (F::from_canonical_u32(from_state.pc) + instruction.b).as_canonical_u32(), - timestamp: memory.timestamp(), - }); - } else if instruction.opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { - let d = F::from_canonical_u32(AS::Native as u32); - // This is a read, but we make the record have prev_data - let a_val = memory.unsafe_read_cell(d, instruction.a); - let (record_id, _) = memory.write(d, instruction.a, [a_val]); - let a_val = a_val.as_canonical_u32(); - let b = instruction.b.as_canonical_u32(); - let c = instruction.c.as_canonical_u32(); - debug_assert!(!self.debug || b <= 16); - debug_assert!(!self.debug || c <= 14); - let x = a_val & ((1 << 16) - 1); - if !self.debug && x >= 1 << b { - return Err(ExecutionError::Fail { pc: from_state.pc }); - } - let y = a_val >> 16; - if !self.debug && y >= 1 << c { - return Err(ExecutionError::Fail { pc: from_state.pc }); - } - self.records.push(JalRangeCheckRecord { - state: from_state, - a_rw: record_id, - b, - c: c as u8, - is_jal: false, - }); - return Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }); - } - panic!("Unknown opcode {}", instruction.opcode); - } - - fn get_opcode_name(&self, opcode: usize) -> String { - let jal_opcode = NativeJalOpcode::JAL.global_opcode().as_usize(); - let range_check_opcode = NativeRangeCheckOpcode::RANGE_CHECK - .global_opcode() - .as_usize(); - if opcode == jal_opcode { - return String::from("JAL"); - } - if opcode == range_check_opcode { - return String::from("RANGE_CHECK"); - } - panic!("Unknown opcode {}", opcode); - } -} - -impl ChipUsageGetter for JalRangeCheckChip { - fn air_name(&self) -> String { - "JalRangeCheck".to_string() - } - - fn current_trace_height(&self) -> usize { - self.records.len() - } - - fn trace_width(&self) -> usize { - OVERALL_WIDTH - } -} - -impl Chip for JalRangeCheckChip> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air) - } - fn generate_air_proof_input(self) -> AirProofInput { - let height = next_power_of_two_or_zero(self.records.len()); - let mut flat_trace = Val::::zero_vec(OVERALL_WIDTH * height); - let memory = self.offline_memory.lock().unwrap(); - let aux_cols_factory = memory.aux_cols_factory(); - - self.records - .into_par_iter() - .zip(flat_trace.par_chunks_mut(OVERALL_WIDTH)) - .for_each(|(record, slice)| { - record_to_row( - record, - &aux_cols_factory, - self.range_checker_chip.as_ref(), - slice, - &memory, - ); - }); - - let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH); - AirProofInput::simple_no_pis(matrix) - } -} - -fn record_to_row( - record: JalRangeCheckRecord, - aux_cols_factory: &MemoryAuxColsFactory, - range_checker_chip: &VariableRangeCheckerChip, - slice: &mut [F], - memory: &OfflineMemory, -) { - let a_record = memory.record_by_id(record.a_rw); - let col: &mut JalRangeCheckCols<_> = slice.borrow_mut(); - col.is_jal = F::from_bool(record.is_jal); - col.is_range_check = F::from_bool(!record.is_jal); - col.a_pointer = a_record.pointer; - col.state = ExecutionState { - pc: F::from_canonical_u32(record.state.pc), - timestamp: F::from_canonical_u32(record.state.timestamp), - }; - aux_cols_factory.generate_write_aux(a_record, &mut col.writes_aux); - col.b = F::from_canonical_u32(record.b); - if !record.is_jal { - let a_val = a_record.data_at(0); - let a_val_u32 = a_val.as_canonical_u32(); - let y = a_val_u32 >> 16; - let x = a_val_u32 & ((1 << 16) - 1); - range_checker_chip.add_count(x, record.b as usize); - range_checker_chip.add_count(y, record.c as usize); - col.c = F::from_canonical_u32(record.c as u32); - col.y = F::from_canonical_u32(y); - } -} diff --git a/extensions/native/circuit/src/jal/tests.rs b/extensions/native/circuit/src/jal/tests.rs deleted file mode 100644 index dd56b73c8f..0000000000 --- a/extensions/native/circuit/src/jal/tests.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::borrow::BorrowMut; - -use openvm_circuit::arch::{testing::VmChipTestBuilder, ExecutionBridge}; -use openvm_instructions::{ - instruction::Instruction, - program::{DEFAULT_PC_STEP, PC_BITS}, - LocalOpcode, -}; -use openvm_native_compiler::{NativeJalOpcode::*, NativeRangeCheckOpcode::RANGE_CHECK}; -use openvm_stark_backend::{ - p3_field::{FieldAlgebra, PrimeField32}, - utils::disable_debug_builder, - verifier::VerificationError, - Chip, -}; -use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::{rngs::StdRng, Rng}; - -use crate::{jal::JalRangeCheckCols, JalRangeCheckChip}; -type F = BabyBear; - -fn set_and_execute( - tester: &mut VmChipTestBuilder, - chip: &mut JalRangeCheckChip, - rng: &mut StdRng, - initial_imm: Option, - initial_pc: Option, -) { - let imm = initial_imm.unwrap_or(rng.gen_range(0..20)); - let a = rng.gen_range(0..32) << 2; - let d = 4usize; - - tester.execute_with_pc( - chip, - &Instruction::from_usize(JAL.global_opcode(), [a, imm as usize, 0, d, 0, 0, 0]), - initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))), - ); - let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); - let final_pc = tester.execution.last_to_pc().as_canonical_u32(); - - let next_pc = initial_pc + imm; - let rd_data = initial_pc + DEFAULT_PC_STEP; - - assert_eq!(next_pc, final_pc); - assert_eq!(rd_data, tester.read::<1>(d, a)[0].as_canonical_u32()); -} - -struct RangeCheckTestCase { - val: u32, - x_bit: u32, - y_bit: u32, -} - -fn set_and_execute_range_check( - tester: &mut VmChipTestBuilder, - chip: &mut JalRangeCheckChip, - rng: &mut StdRng, - test_cases: Vec, -) { - let a = rng.gen_range(0..32) << 2; - for RangeCheckTestCase { val, x_bit, y_bit } in test_cases { - let d = 4usize; - - tester.write_cell(d, a, F::from_canonical_u32(val)); - tester.execute_with_pc( - chip, - &Instruction::from_usize( - RANGE_CHECK.global_opcode(), - [a, x_bit as usize, y_bit as usize, d, 0, 0, 0], - ), - rng.gen_range(0..(1 << PC_BITS)), - ); - } -} - -fn setup() -> (StdRng, VmChipTestBuilder, JalRangeCheckChip) { - let rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); - let execution_bridge = ExecutionBridge::new(tester.execution_bus(), tester.program_bus()); - let offline_memory = tester.offline_memory_mutex_arc(); - let range_checker = tester.range_checker(); - let chip = JalRangeCheckChip::::new(execution_bridge, offline_memory, range_checker); - (rng, tester, chip) -} - -#[test] -fn rand_jal_test() { - let (mut rng, mut tester, mut chip) = setup(); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, None, None); - } - - let tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn rand_range_check_test() { - let (mut rng, mut tester, mut chip) = setup(); - let f = |x: u32, y: u32| RangeCheckTestCase { - val: x + y * (1 << 16), - x_bit: 32 - x.leading_zeros(), - y_bit: 32 - y.leading_zeros(), - }; - let mut test_cases: Vec<_> = (0..10) - .map(|_| { - let x = 0; - let y = rng.gen_range(0..1 << 14); - f(x, y) - }) - .collect(); - test_cases.extend((0..10).map(|_| { - let x = rng.gen_range(0..1 << 16); - let y = 0; - f(x, y) - })); - test_cases.extend((0..10).map(|_| { - let x = rng.gen_range(0..1 << 16); - let y = rng.gen_range(0..1 << 14); - f(x, y) - })); - test_cases.push(f((1 << 16) - 1, (1 << 14) - 1)); - set_and_execute_range_check(&mut tester, &mut chip, &mut rng, test_cases); - let tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn negative_range_check_test() { - { - let (mut rng, mut tester, chip) = setup(); - let mut chip = chip.with_debug(); - set_and_execute_range_check( - &mut tester, - &mut chip, - &mut rng, - vec![RangeCheckTestCase { - x_bit: 1, - y_bit: 1, - val: 2, - }], - ); - let tester = tester.build().load(chip).finalize(); - disable_debug_builder(); - let result = tester.simple_test(); - assert!(result.is_err()); - } - { - let (mut rng, mut tester, chip) = setup(); - let mut chip = chip.with_debug(); - set_and_execute_range_check( - &mut tester, - &mut chip, - &mut rng, - vec![RangeCheckTestCase { - x_bit: 1, - y_bit: 0, - val: 1 << 16, - }], - ); - let tester = tester.build().load(chip).finalize(); - disable_debug_builder(); - let result = tester.simple_test(); - assert!(result.is_err()); - } -} - -#[test] -fn negative_jal_test() { - let (mut rng, mut tester, mut chip) = setup(); - set_and_execute(&mut tester, &mut chip, &mut rng, None, None); - - let tester = tester.build(); - - let chip_air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let jal_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let col: &mut JalRangeCheckCols<_> = jal_trace.row_mut(0).borrow_mut(); - col.b = F::from_canonical_u32(rng.gen_range(1 << 11..1 << 12)); - } - disable_debug_builder(); - let tester = tester - .load_air_proof_input((chip_air, chip_input)) - .finalize(); - let msg = format!( - "Expected verification to fail with {:?}, but it didn't", - VerificationError::ChallengePhaseError - ); - let result = tester.simple_test(); - assert_eq!( - result.err(), - Some(VerificationError::ChallengePhaseError), - "{}", - msg - ); -} diff --git a/extensions/native/circuit/src/jal_rangecheck/mod.rs b/extensions/native/circuit/src/jal_rangecheck/mod.rs new file mode 100644 index 0000000000..678638da67 --- /dev/null +++ b/extensions/native/circuit/src/jal_rangecheck/mod.rs @@ -0,0 +1,515 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + ops::Deref, +}; + +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, E2PreCompute, EmptyMultiRowLayout, ExecuteFunc, ExecutionBridge, + ExecutionError::{self, InvalidInstruction}, + ExecutionState, MatrixRecordArena, NewVmChipWrapper, PcIncOrSet, RecordArena, Result, + StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, VmSegmentState, VmStateMut, + }, + system::{ + memory::{ + offline_checker::{MemoryBridge, MemoryWriteAuxCols, MemoryWriteAuxRecord}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, + }, + native_adapter::util::{memory_read_native, tracing_write_native}, + }, +}; +use openvm_circuit_primitives::{ + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, +}; +use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_native_compiler::{conversion::AS, NativeJalOpcode, NativeRangeCheckOpcode}; +use openvm_stark_backend::{ + interaction::InteractionBuilder, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; +use static_assertions::const_assert_eq; +use AS::Native; + +#[cfg(test)] +mod tests; + +#[repr(C)] +#[derive(AlignedBorrow)] +struct JalRangeCheckCols { + is_jal: T, + is_range_check: T, + a_pointer: T, + state: ExecutionState, + // Write when is_jal, read when is_range_check. + writes_aux: MemoryWriteAuxCols, + b: T, + // Only used by range check. + c: T, + // Only used by range check. + y: T, +} + +const OVERALL_WIDTH: usize = JalRangeCheckCols::::width(); +const_assert_eq!(OVERALL_WIDTH, 12); + +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct JalRangeCheckAir { + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + range_bus: VariableRangeCheckerBus, +} + +impl BaseAir for JalRangeCheckAir { + fn width(&self) -> usize { + OVERALL_WIDTH + } +} + +impl BaseAirWithPublicValues for JalRangeCheckAir {} +impl PartitionedBaseAir for JalRangeCheckAir {} +impl Air for JalRangeCheckAir +where + AB::F: PrimeField32, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local_slice = local.deref(); + let local: &JalRangeCheckCols = local_slice.borrow(); + builder.assert_bool(local.is_jal); + builder.assert_bool(local.is_range_check); + let is_valid = local.is_jal + local.is_range_check; + builder.assert_bool(is_valid.clone()); + + let d = AB::Expr::from_canonical_u32(Native as u32); + let a_val = local.writes_aux.prev_data()[0]; + // if is_jal, write pc + DEFAULT_PC_STEP, else if is_range_check, read a_val. + let write_val = local.is_jal + * (local.state.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_STEP)) + + local.is_range_check * a_val; + self.memory_bridge + .write( + MemoryAddress::new(d.clone(), local.a_pointer), + [write_val], + local.state.timestamp, + &local.writes_aux, + ) + .eval(builder, is_valid.clone()); + + let opcode = local.is_jal + * AB::F::from_canonical_usize(NativeJalOpcode::JAL.global_opcode().as_usize()) + + local.is_range_check + * AB::F::from_canonical_usize( + NativeRangeCheckOpcode::RANGE_CHECK + .global_opcode() + .as_usize(), + ); + // Increment pc by b if is_jal, else by DEFAULT_PC_STEP if is_range_check. + let pc_inc = local.is_jal * local.b + + local.is_range_check * AB::F::from_canonical_u32(DEFAULT_PC_STEP); + builder.when(local.is_jal).assert_zero(local.c); + self.execution_bridge + .execute_and_increment_or_set_pc( + opcode, + [local.a_pointer.into(), local.b.into(), local.c.into(), d], + local.state, + AB::F::ONE, + PcIncOrSet::Inc(pc_inc), + ) + .eval(builder, is_valid); + + // Range check specific: + // a_val = x + y * (1 << 16) + let x = a_val - local.y * AB::Expr::from_canonical_u32(1 << 16); + self.range_bus + .send(x.clone(), local.b) + .eval(builder, local.is_range_check); + // Assert y < (1 << c), where c <= 14. + self.range_bus + .send(local.y, local.c) + .eval(builder, local.is_range_check); + } +} + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct JalRangeCheckRecord { + pub is_jal: bool, + pub a: F, + pub from_pc: u32, + pub from_timestamp: u32, + pub write: MemoryWriteAuxRecord, + pub b: F, + pub c: F, +} + +/// Chip for JAL and RANGE_CHECK. These opcodes are logically irrelevant. Putting these opcodes into +/// the same chip is just to save columns. +#[derive(derive_new::new)] +pub struct JalRangeCheckStep { + range_checker_chip: SharedVariableRangeCheckerChip, +} + +impl TraceStep for JalRangeCheckStep +where + F: PrimeField32, +{ + type RecordLayout = EmptyMultiRowLayout; + type RecordMut<'a> = &'a mut JalRangeCheckRecord; + + fn get_opcode_name(&self, opcode: usize) -> String { + let jal_opcode = NativeJalOpcode::JAL.global_opcode().as_usize(); + let range_check_opcode = NativeRangeCheckOpcode::RANGE_CHECK + .global_opcode() + .as_usize(); + if opcode == jal_opcode { + return String::from("JAL"); + } + if opcode == range_check_opcode { + return String::from("RANGE_CHECK"); + } + panic!("Unknown opcode {opcode}"); + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { + opcode, a, b, c, .. + } = instruction; + + debug_assert!( + opcode == NativeJalOpcode::JAL.global_opcode() + || opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() + ); + + let record = arena.alloc(EmptyMultiRowLayout::default()); + + record.from_pc = *state.pc; + record.from_timestamp = state.memory.timestamp; + + record.a = a; + record.b = b; + + if opcode == NativeJalOpcode::JAL.global_opcode() { + record.is_jal = true; + record.c = F::ZERO; + + tracing_write_native( + state.memory, + a.as_canonical_u32(), + [F::from_canonical_u32( + state.pc.wrapping_add(DEFAULT_PC_STEP), + )], + &mut record.write.prev_timestamp, + &mut record.write.prev_data, + ); + *state.pc = (F::from_canonical_u32(*state.pc) + b).as_canonical_u32(); + } else if opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { + record.is_jal = false; + record.c = c; + + let a_ptr = a.as_canonical_u32(); + let [a_val]: [F; 1] = memory_read_native(state.memory.data(), a_ptr); + tracing_write_native( + state.memory, + a_ptr, + [a_val], + &mut record.write.prev_timestamp, + &mut record.write.prev_data, + ); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + } + + Ok(()) + } +} + +impl TraceFiller for JalRangeCheckStep { + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut row_slice: &mut [F]) { + let record: &mut JalRangeCheckRecord = + unsafe { get_record_from_slice(&mut row_slice, ()) }; + let cols: &mut JalRangeCheckCols = row_slice.borrow_mut(); + + // Writing in reverse order to avoid overwriting the `record` + if record.is_jal { + cols.y = F::ZERO; + cols.c = F::ZERO; + cols.b = record.b; + cols.writes_aux.set_prev_data(record.write.prev_data); + mem_helper.fill( + record.write.prev_timestamp, + record.from_timestamp, + cols.writes_aux.as_mut(), + ); + cols.state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.state.pc = F::from_canonical_u32(record.from_pc); + cols.a_pointer = record.a; + cols.is_range_check = F::ZERO; + cols.is_jal = F::ONE; + } else { + let a_val = record.write.prev_data[0].as_canonical_u32(); + let b = record.b.as_canonical_u32(); + let c = record.c.as_canonical_u32(); + let x = a_val & 0xffff; + let y = a_val >> 16; + #[cfg(debug_assertions)] + { + assert!(b <= 16); + assert!(c <= 14); + assert!(x < (1 << b)); + assert!(y < (1 << c)); + } + + self.range_checker_chip.add_count(x, b as usize); + self.range_checker_chip.add_count(y, c as usize); + + cols.y = F::from_canonical_u32(y); + cols.c = record.c; + cols.b = record.b; + cols.writes_aux.set_prev_data(record.write.prev_data); + mem_helper.fill( + record.write.prev_timestamp, + record.from_timestamp, + cols.writes_aux.as_mut(), + ); + cols.state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.state.pc = F::from_canonical_u32(record.from_pc); + cols.a_pointer = record.a; + cols.is_range_check = F::ONE; + cols.is_jal = F::ZERO; + } + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct JalPreCompute { + a: u32, + b: F, + return_pc: F, +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct RangeCheckPreCompute { + a: u32, + b: u8, + c: u8, +} + +impl JalRangeCheckStep { + #[inline(always)] + fn pre_compute_jal_impl( + &self, + pc: u32, + inst: &Instruction, + jal_data: &mut JalPreCompute, + ) -> Result<()> { + let &Instruction { opcode, a, b, .. } = inst; + + if opcode != NativeJalOpcode::JAL.global_opcode() { + return Err(InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let return_pc = F::from_canonical_u32(pc.wrapping_add(DEFAULT_PC_STEP)); + + *jal_data = JalPreCompute { a, b, return_pc }; + Ok(()) + } + + #[inline(always)] + fn pre_compute_range_check_impl( + &self, + pc: u32, + inst: &Instruction, + range_check_data: &mut RangeCheckPreCompute, + ) -> Result<()> { + let &Instruction { + opcode, a, b, c, .. + } = inst; + + if opcode != NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { + return Err(InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + if b > 16 || c > 14 { + return Err(InvalidInstruction(pc)); + } + + *range_check_data = RangeCheckPreCompute { + a, + b: b as u8, + c: c as u8, + }; + Ok(()) + } +} + +impl StepExecutorE1 for JalRangeCheckStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::cmp::max( + size_of::>(), + size_of::(), + ) + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let &Instruction { opcode, .. } = inst; + + let is_jal = opcode == NativeJalOpcode::JAL.global_opcode(); + + if is_jal { + let jal_data: &mut JalPreCompute = data.borrow_mut(); + self.pre_compute_jal_impl(pc, inst, jal_data)?; + Ok(execute_jal_e1_impl) + } else { + let range_check_data: &mut RangeCheckPreCompute = data.borrow_mut(); + self.pre_compute_range_check_impl(pc, inst, range_check_data)?; + Ok(execute_range_check_e1_impl) + } + } +} + +impl StepExecutorE2 for JalRangeCheckStep +where + F: PrimeField32, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + std::cmp::max( + size_of::>>(), + size_of::>(), + ) + } + + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let &Instruction { opcode, .. } = inst; + + let is_jal = opcode == NativeJalOpcode::JAL.global_opcode(); + + if is_jal { + let pre_compute: &mut E2PreCompute> = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_jal_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_jal_e2_impl) + } else { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_range_check_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_range_check_e2_impl) + } + } +} + +unsafe fn execute_jal_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &JalPreCompute = pre_compute.borrow(); + execute_jal_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_jal_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_jal_e12_impl(&pre_compute.data, vm_state); +} + +unsafe fn execute_range_check_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &RangeCheckPreCompute = pre_compute.borrow(); + execute_range_check_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_range_check_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_range_check_e12_impl(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_jal_e12_impl( + pre_compute: &JalPreCompute, + vm_state: &mut VmSegmentState, +) { + vm_state.vm_write(AS::Native as u32, pre_compute.a, &[pre_compute.return_pc]); + // TODO(ayush): better way to do this + vm_state.pc = (F::from_canonical_u32(vm_state.pc) + pre_compute.b).as_canonical_u32(); +} + +#[inline(always)] +unsafe fn execute_range_check_e12_impl( + pre_compute: &RangeCheckPreCompute, + vm_state: &mut VmSegmentState, +) { + let [a_val]: [F; 1] = vm_state.host_read(AS::Native as u32, pre_compute.a); + + vm_state.vm_write(AS::Native as u32, pre_compute.a, &[a_val]); + { + let a_val = a_val.as_canonical_u32(); + let b = pre_compute.b; + let c = pre_compute.c; + let x = a_val & 0xffff; + let y = a_val >> 16; + + // The range of `b`,`c` had already been checked in `pre_compute_e1`. + if !(x < (1 << b) && y < (1 << c)) { + vm_state.exit_code = Err(ExecutionError::Fail { pc: vm_state.pc }); + return; + } + } + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +pub type JalRangeCheckChip = + NewVmChipWrapper>; diff --git a/extensions/native/circuit/src/jal_rangecheck/tests.rs b/extensions/native/circuit/src/jal_rangecheck/tests.rs new file mode 100644 index 0000000000..53b6ff4445 --- /dev/null +++ b/extensions/native/circuit/src/jal_rangecheck/tests.rs @@ -0,0 +1,295 @@ +use std::borrow::BorrowMut; + +use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + LocalOpcode, VmOpcode, +}; +use openvm_native_compiler::{ + conversion::AS, NativeJalOpcode::*, NativeRangeCheckOpcode::RANGE_CHECK, +}; +use openvm_stark_backend::{ + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; + +use super::{JalRangeCheckAir, JalRangeCheckStep}; +use crate::{ + jal_rangecheck::{JalRangeCheckChip, JalRangeCheckCols}, + test_utils::write_native_array, +}; + +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; + +fn create_test_chip(tester: &VmChipTestBuilder) -> JalRangeCheckChip { + let mut chip = JalRangeCheckChip::::new( + JalRangeCheckAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.range_checker().bus(), + ), + JalRangeCheckStep::new(tester.range_checker().clone()), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + chip +} + +// `a_val` and `c` will be disregarded if opcode is JAL +fn set_and_execute( + tester: &mut VmChipTestBuilder, + chip: &mut JalRangeCheckChip, + rng: &mut StdRng, + opcode: VmOpcode, + a_val: Option, + b: Option, + c: Option, +) { + if opcode == JAL.global_opcode() { + let initial_pc = rng.gen_range(0..(1 << PC_BITS)); + let a = gen_pointer(rng, 1); + let final_pc = F::from_canonical_u32(rng.gen_range(0..(1 << PC_BITS))); + let b = b.unwrap_or((final_pc - F::from_canonical_u32(initial_pc)).as_canonical_u32()); + tester.execute_with_pc( + chip, + &Instruction::from_usize(opcode, [a, b as usize, 0, AS::Native as usize, 0, 0, 0]), + initial_pc, + ); + + let final_pc = tester.execution.last_to_pc(); + let expected_final_pc = F::from_canonical_u32(initial_pc) + F::from_canonical_u32(b); + assert_eq!(final_pc, expected_final_pc); + let result_a_val = tester.read::<1>(AS::Native as usize, a)[0].as_canonical_u32(); + let expected_a_val = initial_pc + DEFAULT_PC_STEP; + assert_eq!(result_a_val, expected_a_val); + } else { + let a_val = a_val.unwrap_or(rng.gen_range(0..(1 << 30))); + let a = write_native_array(tester, rng, Some([F::from_canonical_u32(a_val)])).1; + let x = a_val & 0xffff; + let y = a_val >> 16; + + let min_b = 32 - x.leading_zeros(); + let min_c = 32 - y.leading_zeros(); + let b = b.unwrap_or(rng.gen_range(min_b..=16)); + let c = c.unwrap_or(rng.gen_range(min_c..=14)); + tester.execute( + chip, + &Instruction::from_usize( + opcode, + [a, b as usize, c as usize, AS::Native as usize, 0, 0, 0], + ), + ); + // There is nothing to assert for range check since it doesn't write to the memory + }; +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(JAL.global_opcode(), 100)] +#[test_case(RANGE_CHECK.global_opcode(), 100)] +fn rand_jal_range_check_test(opcode: VmOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); + + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); + } + let tester = tester.build().load(chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn range_check_edge_cases_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); + + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + RANGE_CHECK.global_opcode(), + Some(0), + None, + None, + ); + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + RANGE_CHECK.global_opcode(), + Some((1 << 30) - 1), + None, + None, + ); + + // x = 0 + let a = rng.gen_range(0..(1 << 14)) << 16; + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + RANGE_CHECK.global_opcode(), + Some(a), + None, + None, + ); + + // y = 0 + let a = rng.gen_range(0..(1 << 16)); + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + RANGE_CHECK.global_opcode(), + Some(a), + None, + None, + ); + + let tester = tester.build().load(chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[derive(Clone, Copy, Default)] +struct JalRangeCheckPrankValues { + pub flags: Option<[bool; 2]>, + pub a_val: Option, + pub b: Option, + pub c: Option, + pub y: Option, +} + +fn run_negative_jal_range_check_test( + opcode: VmOpcode, + a_val: Option, + b: Option, + c: Option, + prank_vals: JalRangeCheckPrankValues, + error: VerificationError, +) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, a_val, b, c); + + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let cols: &mut JalRangeCheckCols = values[..].borrow_mut(); + + if let Some(flags) = prank_vals.flags { + cols.is_jal = F::from_bool(flags[0]); + cols.is_range_check = F::from_bool(flags[1]); + } + if let Some(a_val) = prank_vals.a_val { + cols.writes_aux + .set_prev_data([F::from_canonical_u32(a_val)]); + } + + if let Some(b) = prank_vals.b { + cols.b = F::from_canonical_u32(b); + } + if let Some(c) = prank_vals.c { + cols.c = F::from_canonical_u32(c); + } + if let Some(y) = prank_vals.y { + cols.y = F::from_canonical_u32(y); + } + + *trace = RowMajorMatrix::new(values, trace.width()); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); +} + +#[test] +fn negative_range_check_test() { + run_negative_jal_range_check_test( + RANGE_CHECK.global_opcode(), + Some(2), + Some(2), + Some(1), + JalRangeCheckPrankValues { + b: Some(1), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); + run_negative_jal_range_check_test( + RANGE_CHECK.global_opcode(), + Some(1 << 16), + None, + None, + JalRangeCheckPrankValues { + c: Some(0), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); + run_negative_jal_range_check_test( + RANGE_CHECK.global_opcode(), + Some((1 << 30) - 1), + None, + None, + JalRangeCheckPrankValues { + a_val: Some(1 << 30), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); + run_negative_jal_range_check_test( + RANGE_CHECK.global_opcode(), + Some(1 << 17), + None, + None, + JalRangeCheckPrankValues { + y: Some(1), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); +} + +#[test] +fn negative_jal_test() { + run_negative_jal_range_check_test( + JAL.global_opcode(), + None, + None, + None, + JalRangeCheckPrankValues { + b: Some(0), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); +} diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index 46c6bc890f..78f3cee4c3 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -5,7 +5,7 @@ mod castf; mod field_arithmetic; mod field_extension; mod fri; -mod jal; +mod jal_rangecheck; mod loadstore; mod poseidon2; @@ -14,7 +14,7 @@ pub use castf::*; pub use field_arithmetic::*; pub use field_extension::*; pub use fri::*; -pub use jal::*; +pub use jal_rangecheck::*; pub use loadstore::*; pub use poseidon2::*; @@ -22,4 +22,6 @@ mod extension; pub use extension::*; mod utils; +#[cfg(any(test, feature = "test-utils"))] +pub use utils::test_utils::*; pub use utils::*; diff --git a/extensions/native/circuit/src/loadstore/core.rs b/extensions/native/circuit/src/loadstore/core.rs index 094a57dccc..4860dd4493 100644 --- a/extensions/native/circuit/src/loadstore/core.rs +++ b/extensions/native/circuit/src/loadstore/core.rs @@ -1,27 +1,34 @@ use std::{ array, borrow::{Borrow, BorrowMut}, - sync::{Arc, Mutex, OnceLock}, }; -use openvm_circuit::arch::{ - instructions::LocalOpcode, AdapterAirContext, AdapterRuntimeContext, ExecutionError, Result, - Streams, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, + instructions::LocalOpcode, + AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, E2PreCompute, + EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::{self, InvalidInstruction}, + RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, + VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::instruction::Instruction; -use openvm_native_compiler::NativeLoadStoreOpcode; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_native_compiler::{conversion::AS, NativeLoadStoreOpcode}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; -use super::super::adapters::loadstore_native_adapter::NativeLoadStoreInstruction; +use crate::adapters::NativeLoadStoreInstruction; #[repr(C)] #[derive(AlignedBorrow)] @@ -34,17 +41,7 @@ pub struct NativeLoadStoreCoreCols { pub data: [T; NUM_CELLS], } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct NativeLoadStoreCoreRecord { - pub opcode: NativeLoadStoreOpcode, - - pub pointer_read: F, - #[serde(with = "BigArray")] - pub data: [F; NUM_CELLS], -} - -#[derive(Clone, Debug)] +#[derive(Clone, Debug, derive_new::new)] pub struct NativeLoadStoreCoreAir { pub offset: usize, } @@ -113,89 +110,323 @@ where } } -pub struct NativeLoadStoreCoreChip { - pub air: NativeLoadStoreCoreAir, - pub streams: OnceLock>>>, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct NativeLoadStoreCoreRecord { + pub pointer_read: F, + pub data: [F; NUM_CELLS], + pub local_opcode: u8, } -impl NativeLoadStoreCoreChip { - pub fn new(offset: usize) -> Self { - Self { - air: NativeLoadStoreCoreAir:: { offset }, - streams: OnceLock::new(), - } - } - pub fn set_streams(&mut self, streams: Arc>>) { - self.streams - .set(streams) - .map_err(|_| "streams have already been set.") - .unwrap(); - } +#[derive(Debug)] +pub struct NativeLoadStoreCoreStep { + adapter: A, + offset: usize, } -impl Default for NativeLoadStoreCoreChip { - fn default() -> Self { - Self::new(NativeLoadStoreOpcode::CLASS_OFFSET) +impl NativeLoadStoreCoreStep { + pub fn new(adapter: A, offset: usize) -> Self { + Self { adapter, offset } } } -impl, const NUM_CELLS: usize> VmCoreChip - for NativeLoadStoreCoreChip +impl TraceStep for NativeLoadStoreCoreStep where - I::Reads: Into<(F, [F; NUM_CELLS])>, - I::Writes: From<[F; NUM_CELLS]>, + F: PrimeField32, + A: 'static + + AdapterTraceStep, { - type Record = NativeLoadStoreCoreRecord; - type Air = NativeLoadStoreCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = ( + A::RecordMut<'a>, + &'a mut NativeLoadStoreCoreRecord, + ); - fn execute_instruction( - &self, + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + NativeLoadStoreOpcode::from_usize(opcode - self.offset) + ) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = *instruction; - let local_opcode = - NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let (pointer_read, data_read) = reads.into(); - - let data = if local_opcode == NativeLoadStoreOpcode::HINT_STOREW { - let mut streams = self.streams.get().unwrap().lock().unwrap(); - if streams.hint_stream.len() < NUM_CELLS { - return Err(ExecutionError::HintOutOfBounds { pc: from_pc }); + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { opcode, .. } = instruction; + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let (pointer_read, data_read) = + self.adapter + .read(state.memory, instruction, &mut adapter_record); + + core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8; + let opcode = NativeLoadStoreOpcode::from_usize(core_record.local_opcode as usize); + + let data = if opcode == NativeLoadStoreOpcode::HINT_STOREW { + if state.streams.hint_stream.len() < NUM_CELLS { + return Err(ExecutionError::HintOutOfBounds { pc: *state.pc }); } - array::from_fn(|_| streams.hint_stream.pop_front().unwrap()) + array::from_fn(|_| state.streams.hint_stream.pop_front().unwrap()) } else { data_read }; - let output = AdapterRuntimeContext::without_pc(data); - let record = NativeLoadStoreCoreRecord { - opcode: NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)), - pointer_read, - data, + self.adapter + .write(state.memory, instruction, data, &mut adapter_record); + + core_record.pointer_read = pointer_read; + core_record.data = data; + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for NativeLoadStoreCoreStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &NativeLoadStoreCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut NativeLoadStoreCoreCols = core_row.borrow_mut(); + + let opcode = NativeLoadStoreOpcode::from_usize(record.local_opcode as usize); + + // Writing in reverse order to avoid overwriting the `record` + core_row.data = record.data; + core_row.pointer_read = record.pointer_read; + core_row.is_hint_storew = F::from_bool(opcode == NativeLoadStoreOpcode::HINT_STOREW); + core_row.is_storew = F::from_bool(opcode == NativeLoadStoreOpcode::STOREW); + core_row.is_loadw = F::from_bool(opcode == NativeLoadStoreOpcode::LOADW); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct NativeLoadStorePreCompute { + a: u32, + b: F, + c: u32, +} + +impl NativeLoadStoreCoreStep { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut NativeLoadStorePreCompute, + ) -> Result { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let a = a.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + + if d != AS::Native as u32 || e != AS::Native as u32 { + return Err(InvalidInstruction(pc)); + } + + *data = NativeLoadStorePreCompute { a, b, c }; + + Ok(local_opcode) + } +} + +impl StepExecutorE1 for NativeLoadStoreCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut NativeLoadStorePreCompute = data.borrow_mut(); + + let local_opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + + let fn_ptr = match local_opcode { + NativeLoadStoreOpcode::LOADW => execute_e1_loadw::, + NativeLoadStoreOpcode::STOREW => execute_e1_storew::, + NativeLoadStoreOpcode::HINT_STOREW => execute_e1_hint_storew::, }; - Ok((output, record)) + + Ok(fn_ptr) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - NativeLoadStoreOpcode::from_usize(opcode - self.air.offset) - ) +impl StepExecutorE2 for NativeLoadStoreCoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + size_of::>>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let cols: &mut NativeLoadStoreCoreCols<_, NUM_CELLS> = row_slice.borrow_mut(); - cols.is_loadw = F::from_bool(record.opcode == NativeLoadStoreOpcode::LOADW); - cols.is_storew = F::from_bool(record.opcode == NativeLoadStoreOpcode::STOREW); - cols.is_hint_storew = F::from_bool(record.opcode == NativeLoadStoreOpcode::HINT_STOREW); + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut E2PreCompute> = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + let local_opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; - cols.pointer_read = record.pointer_read; - cols.data = record.data; + let fn_ptr = match local_opcode { + NativeLoadStoreOpcode::LOADW => execute_e2_loadw::, + NativeLoadStoreOpcode::STOREW => execute_e2_storew::, + NativeLoadStoreOpcode::HINT_STOREW => execute_e2_hint_storew::, + }; + + Ok(fn_ptr) } +} - fn air(&self) -> &Self::Air { - &self.air +unsafe fn execute_e1_loadw( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &NativeLoadStorePreCompute = pre_compute.borrow(); + execute_e12_loadw::<_, _, NUM_CELLS>(pre_compute, vm_state); +} + +unsafe fn execute_e1_storew( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &NativeLoadStorePreCompute = pre_compute.borrow(); + execute_e12_storew::<_, _, NUM_CELLS>(pre_compute, vm_state); +} + +unsafe fn execute_e1_hint_storew( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &NativeLoadStorePreCompute = pre_compute.borrow(); + execute_e12_hint_storew::<_, _, NUM_CELLS>(pre_compute, vm_state); +} + +unsafe fn execute_e2_loadw( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_loadw::<_, _, NUM_CELLS>(&pre_compute.data, vm_state); +} + +unsafe fn execute_e2_storew( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_storew::<_, _, NUM_CELLS>(&pre_compute.data, vm_state); +} + +unsafe fn execute_e2_hint_storew( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_hint_storew::<_, _, NUM_CELLS>(&pre_compute.data, vm_state); +} + +#[inline(always)] +unsafe fn execute_e12_loadw( + pre_compute: &NativeLoadStorePreCompute, + vm_state: &mut VmSegmentState, +) { + let [read_cell]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.c); + + let data_read_ptr = (read_cell + pre_compute.b).as_canonical_u32(); + let data_read: [F; NUM_CELLS] = vm_state.vm_read(AS::Native as u32, data_read_ptr); + + vm_state.vm_write(AS::Native as u32, pre_compute.a, &data_read); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +#[inline(always)] +unsafe fn execute_e12_storew( + pre_compute: &NativeLoadStorePreCompute, + vm_state: &mut VmSegmentState, +) { + let [read_cell]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.c); + let data_read: [F; NUM_CELLS] = vm_state.vm_read(AS::Native as u32, pre_compute.a); + + let data_write_ptr = (read_cell + pre_compute.b).as_canonical_u32(); + vm_state.vm_write(AS::Native as u32, data_write_ptr, &data_read); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +#[inline(always)] +unsafe fn execute_e12_hint_storew( + pre_compute: &NativeLoadStorePreCompute, + vm_state: &mut VmSegmentState, +) { + let [read_cell]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.c); + + if vm_state.streams.hint_stream.len() < NUM_CELLS { + vm_state.exit_code = Err(ExecutionError::HintOutOfBounds { pc: vm_state.pc }); + return; } + let data: [F; NUM_CELLS] = + array::from_fn(|_| vm_state.streams.hint_stream.pop_front().unwrap()); + + let data_write_ptr = (read_cell + pre_compute.b).as_canonical_u32(); + vm_state.vm_write(AS::Native as u32, data_write_ptr, &data); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; } diff --git a/extensions/native/circuit/src/loadstore/mod.rs b/extensions/native/circuit/src/loadstore/mod.rs index 3dd51113a9..c075e0e0f2 100644 --- a/extensions/native/circuit/src/loadstore/mod.rs +++ b/extensions/native/circuit/src/loadstore/mod.rs @@ -1,19 +1,20 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -#[cfg(test)] -mod tests; +use crate::adapters::{NativeLoadStoreAdapterAir, NativeLoadStoreAdapterStep}; mod core; pub use core::*; -use super::adapters::loadstore_native_adapter::{ - NativeLoadStoreAdapterAir, NativeLoadStoreAdapterChip, -}; +#[cfg(test)] +mod tests; pub type NativeLoadStoreAir = VmAirWrapper, NativeLoadStoreCoreAir>; -pub type NativeLoadStoreChip = VmChipWrapper< +pub type NativeLoadStoreStep = + NativeLoadStoreCoreStep, NUM_CELLS>; +pub type NativeLoadStoreChip = NewVmChipWrapper< F, - NativeLoadStoreAdapterChip, - NativeLoadStoreCoreChip, + NativeLoadStoreAir, + NativeLoadStoreStep, + MatrixRecordArena, >; diff --git a/extensions/native/circuit/src/loadstore/tests.rs b/extensions/native/circuit/src/loadstore/tests.rs index cd653c2fc0..9bd8000441 100644 --- a/extensions/native/circuit/src/loadstore/tests.rs +++ b/extensions/native/circuit/src/loadstore/tests.rs @@ -1,175 +1,234 @@ -use std::sync::{Arc, Mutex}; +use std::{array, borrow::BorrowMut}; -use openvm_circuit::arch::{testing::VmChipTestBuilder, Streams}; +use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::NativeLoadStoreOpcode::{self, *}; -use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_native_compiler::{ + conversion::AS, + NativeLoadStoreOpcode::{self, *}, +}; +use openvm_stark_backend::{ + p3_air::BaseAir, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{ - super::adapters::loadstore_native_adapter::NativeLoadStoreAdapterChip, NativeLoadStoreChip, - NativeLoadStoreCoreChip, +use super::{NativeLoadStoreChip, NativeLoadStoreCoreAir}; +use crate::{ + adapters::{NativeLoadStoreAdapterAir, NativeLoadStoreAdapterCols, NativeLoadStoreAdapterStep}, + test_utils::write_native_array, + NativeLoadStoreAir, NativeLoadStoreCoreCols, NativeLoadStoreStep, }; +const MAX_INS_CAPACITY: usize = 128; +const NUM_CELLS: usize = 1; type F = BabyBear; -#[derive(Debug)] -struct TestData { - a: F, - b: F, - c: F, - d: F, - e: F, - ad_val: F, - cd_val: F, - data_val: F, - is_load: bool, - is_hint: bool, +fn create_test_chip(tester: &VmChipTestBuilder) -> NativeLoadStoreChip { + let mut chip = NativeLoadStoreChip::::new( + NativeLoadStoreAir::new( + NativeLoadStoreAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + NativeLoadStoreCoreAir::new(NativeLoadStoreOpcode::CLASS_OFFSET), + ), + NativeLoadStoreStep::new( + NativeLoadStoreAdapterStep::new(NativeLoadStoreOpcode::CLASS_OFFSET), + NativeLoadStoreOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + chip } -fn setup() -> (StdRng, VmChipTestBuilder, NativeLoadStoreChip) { - let rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); +fn set_and_execute( + tester: &mut VmChipTestBuilder, + chip: &mut NativeLoadStoreChip, + rng: &mut StdRng, + opcode: NativeLoadStoreOpcode, +) { + let a = gen_pointer(rng, NUM_CELLS); + let ([c_val], c) = write_native_array(tester, rng, None); + + let mem_ptr = gen_pointer(rng, NUM_CELLS); + let b = F::from_canonical_usize(mem_ptr) - c_val; + let data: [F; NUM_CELLS] = array::from_fn(|_| rng.gen()); + + match opcode { + LOADW => { + tester.write(AS::Native as usize, mem_ptr, data); + } + STOREW => { + tester.write(AS::Native as usize, a, data); + } + HINT_STOREW => { + tester.streams.hint_stream.extend(data); + } + } - let adapter = NativeLoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - NativeLoadStoreOpcode::CLASS_OFFSET, + tester.execute( + chip, + &Instruction::from_usize( + opcode.global_opcode(), + [ + a, + b.as_canonical_u32() as usize, + c, + AS::Native as usize, + AS::Native as usize, + ], + ), ); - let mut inner = NativeLoadStoreCoreChip::new(NativeLoadStoreOpcode::CLASS_OFFSET); - inner.set_streams(Arc::new(Mutex::new(Streams::default()))); - let chip = NativeLoadStoreChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - (rng, tester, chip) + + let result = match opcode { + STOREW | HINT_STOREW => tester.read(AS::Native as usize, mem_ptr), + LOADW => tester.read(AS::Native as usize, a), + }; + assert_eq!(result, data); } -fn gen_test_data(rng: &mut StdRng, opcode: NativeLoadStoreOpcode) -> TestData { - let is_load = matches!(opcode, NativeLoadStoreOpcode::LOADW); - - let a = rng.gen_range(0..1 << 20); - let b = rng.gen_range(0..1 << 20); - let c = rng.gen_range(0..1 << 20); - let d = F::from_canonical_u32(4u32); - let e = F::from_canonical_u32(4u32); - - TestData { - a: F::from_canonical_u32(a), - b: F::from_canonical_u32(b), - c: F::from_canonical_u32(c), - d, - e, - ad_val: F::from_canonical_u32(111), - cd_val: F::from_canonical_u32(222), - data_val: F::from_canonical_u32(444), - is_load, - is_hint: matches!(opcode, NativeLoadStoreOpcode::HINT_STOREW), +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(STOREW, 100)] +#[test_case(HINT_STOREW, 100)] +#[test_case(LOADW, 100)] +fn rand_native_loadstore_test(opcode: NativeLoadStoreOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); + + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode); } + let tester = tester.build().load(chip).finalize(); + tester.simple_test().expect("Verification failed"); } -fn get_data_pointer(data: &TestData) -> F { - if data.d != F::ZERO { - data.cd_val + data.b - } else { - data.c + data.b - } +////////////////////////////////////////////////////////////////////////////////////// +// NEGATIVE TESTS +// +// Given a fake trace of a single operation, setup a chip and run the test. We replace +// part of the trace and check that the chip throws the expected error. +////////////////////////////////////////////////////////////////////////////////////// + +#[derive(Clone, Copy, Default)] +struct NativeLoadStorePrankValues { + // Core cols + pub data: Option<[F; NUM_CELLS]>, + pub opcode_flags: Option<[bool; 3]>, + pub pointer_read: Option, + // Adapter cols + pub data_write_pointer: Option, } -fn set_values( - tester: &mut VmChipTestBuilder, - chip: &mut NativeLoadStoreChip, - data: &TestData, +fn run_negative_native_loadstore_test( + opcode: NativeLoadStoreOpcode, + prank_vals: NativeLoadStorePrankValues, + error: VerificationError, ) { - if data.d != F::ZERO { - tester.write( - data.d.as_canonical_u32() as usize, - data.a.as_canonical_u32() as usize, - [data.ad_val], - ); - tester.write( - data.d.as_canonical_u32() as usize, - data.c.as_canonical_u32() as usize, - [data.cd_val], - ); - } - if data.is_load { - let data_pointer = get_data_pointer(data); - tester.write( - data.e.as_canonical_u32() as usize, - data_pointer.as_canonical_u32() as usize, - [data.data_val], - ); - } - if data.is_hint { - for _ in 0..data.e.as_canonical_u32() { - chip.core - .streams - .get() - .unwrap() - .lock() - .unwrap() - .hint_stream - .push_back(data.data_val); - } - } -} + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = create_test_chip(&tester); -fn check_values(tester: &mut VmChipTestBuilder, data: &TestData) { - let data_pointer = get_data_pointer(data); - - let written_data_val = if data.is_load { - tester.read::<1>( - data.d.as_canonical_u32() as usize, - data.a.as_canonical_u32() as usize, - )[0] - } else { - tester.read::<1>( - data.e.as_canonical_u32() as usize, - data_pointer.as_canonical_u32() as usize, - )[0] - }; + set_and_execute(&mut tester, &mut chip, &mut rng, opcode); - let correct_data_val = if data.is_load || data.is_hint { - data.data_val - } else if data.d != F::ZERO { - data.ad_val - } else { - data.a + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut values = trace.row_slice(0).to_vec(); + let (adapter_row, core_row) = values.split_at_mut(adapter_width); + let adapter_cols: &mut NativeLoadStoreAdapterCols = adapter_row.borrow_mut(); + let core_cols: &mut NativeLoadStoreCoreCols = core_row.borrow_mut(); + + if let Some(data) = prank_vals.data { + core_cols.data = data; + } + if let Some(pointer_read) = prank_vals.pointer_read { + core_cols.pointer_read = pointer_read; + } + if let Some(opcode_flags) = prank_vals.opcode_flags { + [ + core_cols.is_loadw, + core_cols.is_storew, + core_cols.is_hint_storew, + ] = opcode_flags.map(F::from_bool); + } + if let Some(data_write_pointer) = prank_vals.data_write_pointer { + adapter_cols.data_write_pointer = data_write_pointer; + } + + *trace = RowMajorMatrix::new(values, trace.width()); }; - assert_eq!(written_data_val, correct_data_val, "{:?}", data); + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(error); } -fn set_and_execute( - tester: &mut VmChipTestBuilder, - chip: &mut NativeLoadStoreChip, - rng: &mut StdRng, - opcode: NativeLoadStoreOpcode, -) { - let data = gen_test_data(rng, opcode); - set_values(tester, chip, &data); +#[test] +fn negative_native_loadstore_tests() { + run_negative_native_loadstore_test( + STOREW, + NativeLoadStorePrankValues { + data_write_pointer: Some(F::ZERO), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, + ); - tester.execute_with_pc( - chip, - &Instruction::from_usize( - opcode.global_opcode(), - [data.a, data.b, data.c, data.d, data.e].map(|x| x.as_canonical_u32() as usize), - ), - 0u32, + run_negative_native_loadstore_test( + LOADW, + NativeLoadStorePrankValues { + data_write_pointer: Some(F::ZERO), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, ); +} - check_values(tester, &data); +#[test] +fn invalid_flags_native_loadstore_tests() { + run_negative_native_loadstore_test( + HINT_STOREW, + NativeLoadStorePrankValues { + opcode_flags: Some([false, false, false]), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); + + run_negative_native_loadstore_test( + LOADW, + NativeLoadStorePrankValues { + opcode_flags: Some([false, false, true]), + ..Default::default() + }, + VerificationError::OodEvaluationMismatch, + ); } #[test] -fn rand_native_loadstore_test() { - setup_tracing(); - let (mut rng, mut tester, mut chip) = setup(); - for _ in 0..20 { - set_and_execute(&mut tester, &mut chip, &mut rng, STOREW); - set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); - set_and_execute(&mut tester, &mut chip, &mut rng, LOADW); - } - let tester = tester.build().load(chip).finalize(); - tester.simple_test().expect("Verification failed"); +fn invalid_data_native_loadstore_tests() { + run_negative_native_loadstore_test( + LOADW, + NativeLoadStorePrankValues { + data: Some([F::ZERO; NUM_CELLS]), + ..Default::default() + }, + VerificationError::ChallengePhaseError, + ); } diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 5ed28abd60..9d24966fe7 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -20,15 +20,13 @@ use openvm_stark_backend::{ rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use crate::{ +use crate::poseidon2::{ chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, - poseidon2::{ - columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, - }, - CHUNK, + columns::{ + InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + TopLevelSpecificCols, }, + CHUNK, }; #[derive(Clone, Debug)] diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 426b089a9c..b652d42256 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -1,189 +1,152 @@ -use std::sync::{Arc, Mutex}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort, + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + CustomBorrow, E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MultiRowLayout, MultiRowMetadata, RecordArena, Result, SizedRecord, StepExecutorE1, + StepExecutorE2, TraceFiller, TraceStep, VmSegmentState, VmStateMut, + }, + system::{ + memory::{offline_checker::MemoryBaseAuxCols, online::TracingMemory, MemoryAuxColsFactory}, + native_adapter::util::{ + memory_read_native, tracing_read_native, tracing_write_native_inplace, + }, }, - system::memory::{MemoryController, OfflineMemory, RecordId}, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, Poseidon2Opcode::{COMP_POS2, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; -use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip}; +use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip}; use openvm_stark_backend::{ + p3_air::BaseAir, p3_field::{Field, PrimeField32}, p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, }; -use serde::{Deserialize, Serialize}; use crate::poseidon2::{ - air::{NativePoseidon2Air, VerifyBatchBus}, + columns::{ + InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + TopLevelSpecificCols, + }, CHUNK, }; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct VerifyBatchRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - - pub dim_base_pointer: F, - pub opened_base_pointer: F, - pub opened_length: usize, - pub index_base_pointer: F, - pub commit_pointer: F, - - pub dim_base_pointer_read: RecordId, - pub opened_base_pointer_read: RecordId, - pub opened_length_read: RecordId, - pub index_base_pointer_read: RecordId, - pub commit_pointer_read: RecordId, - - pub commit_read: RecordId, - pub initial_log_height: usize, - pub top_level: Vec>, +pub struct NativePoseidon2Step { + // pre-computed Poseidon2 sub cols for dummy rows. + empty_poseidon2_sub_cols: Vec, + pub(super) subchip: Poseidon2SubChip, } -impl VerifyBatchRecord { - pub fn opened_element_size_inv(&self) -> F { - self.instruction.g +impl NativePoseidon2Step { + pub fn new(poseidon2_config: Poseidon2Config) -> Self { + let subchip = Poseidon2SubChip::new(poseidon2_config.constants); + let empty_poseidon2_sub_cols = subchip.generate_trace(vec![[F::ZERO; CHUNK * 2]]).values; + Self { + empty_poseidon2_sub_cols, + subchip, + } } } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct TopLevelRecord { - // must be present in first record - pub incorporate_row: Option>, - // must be present in all bust last record - pub incorporate_sibling: Option>, +fn compress( + subchip: &Poseidon2SubChip, + left: [F; CHUNK], + right: [F; CHUNK], +) -> ([F; 2 * CHUNK], [F; CHUNK]) { + let concatenated = std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); + let permuted = subchip.permute(concatenated); + (concatenated, std::array::from_fn(|i| permuted[i])) } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct IncorporateSiblingRecord { - pub read_sibling_is_on_right: RecordId, - pub sibling_is_on_right: bool, - pub p2_input: [F; 2 * CHUNK], -} +pub(super) const NUM_INITIAL_READS: usize = 6; +pub(super) const NUM_SIMPLE_ACCESSES: u32 = 7; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct IncorporateRowRecord { - pub chunks: Vec>, - pub initial_opened_index: usize, - pub final_opened_index: usize, - pub initial_height_read: RecordId, - pub final_height_read: RecordId, - pub p2_input: [F; 2 * CHUNK], +#[derive(Debug, Clone, Default)] +pub struct NativePoseidon2Metadata { + num_rows: usize, } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct InsideRowRecord { - pub cells: Vec, - pub p2_input: [F; 2 * CHUNK], +impl MultiRowMetadata for NativePoseidon2Metadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_rows + } } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CellRecord { - pub read: RecordId, - pub opened_index: usize, - pub read_row_pointer_and_length: Option, - pub row_pointer: usize, - pub row_end: usize, -} +type NativePoseidon2RecordLayout = MultiRowLayout; -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct SimplePoseidonRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - - pub read_input_pointer_1: RecordId, - pub read_input_pointer_2: Option, - pub read_output_pointer: RecordId, - pub read_data_1: RecordId, - pub read_data_2: RecordId, - pub write_data_1: RecordId, - pub write_data_2: Option, - - pub input_pointer_1: F, - pub input_pointer_2: F, - pub output_pointer: F, - pub p2_input: [F; 2 * CHUNK], -} +pub struct NativePoseidon2RecordMut<'a, F, const SBOX_REGISTERS: usize>( + &'a mut [NativePoseidon2Cols], +); -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(bound = "F: Field")] -pub struct NativePoseidon2RecordSet { - pub verify_batch_records: Vec>, - pub simple_permute_records: Vec>, -} +impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> + CustomBorrow<'a, NativePoseidon2RecordMut<'a, F, SBOX_REGISTERS>, NativePoseidon2RecordLayout> + for [u8] +{ + fn custom_borrow( + &'a mut self, + layout: NativePoseidon2RecordLayout, + ) -> NativePoseidon2RecordMut<'a, F, SBOX_REGISTERS> { + let arr = unsafe { + self.align_to_mut::>() + .1 + }; + NativePoseidon2RecordMut(&mut arr[..layout.metadata.num_rows]) + } -pub struct NativePoseidon2Chip { - pub(super) air: NativePoseidon2Air, - pub record_set: NativePoseidon2RecordSet, - pub height: usize, - pub(super) offline_memory: Arc>>, - pub(super) subchip: Poseidon2SubChip, - pub(super) streams: Arc>>, + unsafe fn extract_layout(&self) -> NativePoseidon2RecordLayout { + // Each instruction record consists solely of some number of contiguously + // stored NativePoseidon2Cols<...> structs, each of which corresponds to a + // single trace row. Trace fillers don't actually need to know how many rows + // each instruction uses, and can thus treat each NativePoseidon2Cols<...> + // as a single record. + NativePoseidon2RecordLayout { + metadata: NativePoseidon2Metadata { num_rows: 1 }, + } + } } -impl NativePoseidon2Chip { - pub fn new( - port: SystemPort, - offline_memory: Arc>>, - poseidon2_config: Poseidon2Config, - verify_batch_bus: VerifyBatchBus, - streams: Arc>>, - ) -> Self { - let air = NativePoseidon2Air { - execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), - memory_bridge: port.memory_bridge, - internal_bus: verify_batch_bus, - subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), - address_space: F::from_canonical_u32(AS::Native as u32), - }; - Self { - record_set: Default::default(), - air, - height: 0, - offline_memory, - subchip: Poseidon2SubChip::new(poseidon2_config.constants), - streams, - } +impl SizedRecord + for NativePoseidon2RecordMut<'_, F, SBOX_REGISTERS> +{ + fn size(layout: &NativePoseidon2RecordLayout) -> usize { + layout.metadata.num_rows * size_of::>() } - fn compress(&self, left: [F; CHUNK], right: [F; CHUNK]) -> ([F; 2 * CHUNK], [F; CHUNK]) { - let concatenated = - std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); - let permuted = self.subchip.permute(concatenated); - (concatenated, std::array::from_fn(|i| permuted[i])) + fn alignment(_layout: &NativePoseidon2RecordLayout) -> usize { + align_of::>() } } -pub(super) const NUM_INITIAL_READS: usize = 6; -pub(super) const NUM_SIMPLE_ACCESSES: u32 = 7; - -impl InstructionExecutor - for NativePoseidon2Chip +impl TraceStep + for NativePoseidon2Step { - fn execute( + type RecordLayout = MultiRowLayout; + type RecordMut<'a> = NativePoseidon2RecordMut<'a, F, SBOX_REGISTERS>; + fn execute<'buf, RA>( &mut self, - memory: &mut MemoryController, + state: VmStateMut, CTX>, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + arena: &'buf mut RA, + ) -> openvm_circuit::arch::Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let init_timestamp_u32 = state.memory.timestamp; if instruction.opcode == PERM_POS2.global_opcode() || instruction.opcode == COMP_POS2.global_opcode() { + let cols = &mut arena + .alloc(MultiRowLayout::new(NativePoseidon2Metadata { num_rows: 1 })) + .0[0]; + let simple_cols: &mut SimplePoseidonSpecificCols = + cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); let &Instruction { a: output_register, b: input_register_1, @@ -192,22 +155,45 @@ impl InstructionExecutor e: data_address_space, .. } = instruction; + debug_assert_eq!( + register_address_space, + F::from_canonical_u32(AS::Native as u32) + ); + debug_assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); + let [output_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + output_register.as_canonical_u32(), + simple_cols.read_output_pointer.as_mut(), + ); + let output_pointer_u32 = output_pointer.as_canonical_u32(); + let [input_pointer_1]: [F; 1] = tracing_read_native_helper( + state.memory, + input_register_1.as_canonical_u32(), + simple_cols.read_input_pointer_1.as_mut(), + ); + let input_pointer_1_u32 = input_pointer_1.as_canonical_u32(); + let [input_pointer_2]: [F; 1] = if instruction.opcode == PERM_POS2.global_opcode() { + state.memory.increment_timestamp(); + [input_pointer_1 + F::from_canonical_usize(CHUNK)] + } else { + tracing_read_native_helper( + state.memory, + input_register_2.as_canonical_u32(), + simple_cols.read_input_pointer_2.as_mut(), + ) + }; + let input_pointer_2_u32 = input_pointer_2.as_canonical_u32(); + let data_1: [F; CHUNK] = tracing_read_native_helper( + state.memory, + input_pointer_1_u32, + simple_cols.read_data_1.as_mut(), + ); + let data_2: [F; CHUNK] = tracing_read_native_helper( + state.memory, + input_pointer_2_u32, + simple_cols.read_data_2.as_mut(), + ); - let (read_output_pointer, output_pointer) = - memory.read_cell(register_address_space, output_register); - let (read_input_pointer_1, input_pointer_1) = - memory.read_cell(register_address_space, input_register_1); - let (read_input_pointer_2, input_pointer_2) = - if instruction.opcode == PERM_POS2.global_opcode() { - memory.increment_timestamp(); - (None, input_pointer_1 + F::from_canonical_usize(CHUNK)) - } else { - let (read_input_pointer_2, input_pointer_2) = - memory.read_cell(register_address_space, input_register_2); - (Some(read_input_pointer_2), input_pointer_2) - }; - let (read_data_1, data_1) = memory.read::(data_address_space, input_pointer_1); - let (read_data_2, data_2) = memory.read::(data_address_space, input_pointer_2); let p2_input = std::array::from_fn(|i| { if i < CHUNK { data_1[i] @@ -216,50 +202,51 @@ impl InstructionExecutor } }); let output = self.subchip.permute(p2_input); - let (write_data_1, _) = memory.write::( - data_address_space, - output_pointer, + tracing_write_native_inplace( + state.memory, + output_pointer_u32, std::array::from_fn(|i| output[i]), + &mut simple_cols.write_data_1, ); - let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() { - Some( - memory - .write::( - data_address_space, - output_pointer + F::from_canonical_usize(CHUNK), - std::array::from_fn(|i| output[CHUNK + i]), - ) - .0, - ) + if instruction.opcode == PERM_POS2.global_opcode() { + tracing_write_native_inplace( + state.memory, + output_pointer_u32 + CHUNK as u32, + std::array::from_fn(|i| output[i + CHUNK]), + &mut simple_cols.write_data_2, + ); } else { - memory.increment_timestamp(); - None - }; - - assert_eq!( - memory.timestamp(), - from_state.timestamp + NUM_SIMPLE_ACCESSES + state.memory.increment_timestamp(); + } + debug_assert_eq!( + state.memory.timestamp, + init_timestamp_u32 + NUM_SIMPLE_ACCESSES ); + cols.incorporate_row = F::ZERO; + cols.incorporate_sibling = F::ZERO; + cols.inside_row = F::ZERO; + cols.simple = F::ONE; + cols.end_inside_row = F::ZERO; + cols.end_top_level = F::ZERO; + cols.is_exhausted = [F::ZERO; CHUNK - 1]; + cols.start_timestamp = F::from_canonical_u32(init_timestamp_u32); - self.record_set - .simple_permute_records - .push(SimplePoseidonRecord { - from_state, - instruction: instruction.clone(), - read_input_pointer_1, - read_input_pointer_2, - read_output_pointer, - read_data_1, - read_data_2, - write_data_1, - write_data_2, - input_pointer_1, - input_pointer_2, - output_pointer, - p2_input, - }); - self.height += 1; + cols.inner.inputs = p2_input; + simple_cols.pc = F::from_canonical_u32(*state.pc); + simple_cols.is_compress = F::from_bool(instruction.opcode == COMP_POS2.global_opcode()); + simple_cols.output_register = output_register; + simple_cols.input_register_1 = input_register_1; + simple_cols.input_register_2 = input_register_2; + simple_cols.output_pointer = output_pointer; + simple_cols.input_pointer_1 = input_pointer_1; + simple_cols.input_pointer_2 = input_pointer_2; } else if instruction.opcode == VERIFY_BATCH.global_opcode() { + let init_timestamp = F::from_canonical_u32(init_timestamp_u32); + let mut col_buffer = vec![F::ZERO; NativePoseidon2Cols::::width()]; + let last_top_level_cols: &mut NativePoseidon2Cols = + col_buffer.as_mut_slice().borrow_mut(); + let ltl_specific_cols: &mut TopLevelSpecificCols = + last_top_level_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); let &Instruction { a: dim_register, b: opened_register, @@ -270,58 +257,123 @@ impl InstructionExecutor g: opened_element_size_inv, .. } = instruction; - let address_space = self.air.address_space; // calc inverse fast assuming opened_element_size in {1, 4} let mut opened_element_size = F::ONE; while opened_element_size * opened_element_size_inv != F::ONE { opened_element_size += F::ONE; } - let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr); - let (dim_base_pointer_read, dim_base_pointer) = - memory.read_cell(address_space, dim_register); - let (opened_base_pointer_read, opened_base_pointer) = - memory.read_cell(address_space, opened_register); - let (opened_length_read, opened_length) = - memory.read_cell(address_space, opened_length_register); - let (index_base_pointer_read, index_base_pointer) = - memory.read_cell(address_space, index_register); - let (commit_pointer_read, commit_pointer) = - memory.read_cell(address_space, commit_register); - let (commit_read, commit) = memory.read(address_space, commit_pointer); + let [proof_id]: [F; 1] = + memory_read_native(state.memory.data(), proof_id_ptr.as_canonical_u32()); + let [dim_base_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + dim_register.as_canonical_u32(), + ltl_specific_cols.dim_base_pointer_read.as_mut(), + ); + let dim_base_pointer_u32 = dim_base_pointer.as_canonical_u32(); + let [opened_base_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + opened_register.as_canonical_u32(), + ltl_specific_cols.opened_base_pointer_read.as_mut(), + ); + let opened_base_pointer_u32 = opened_base_pointer.as_canonical_u32(); + let [opened_length]: [F; 1] = tracing_read_native_helper( + state.memory, + opened_length_register.as_canonical_u32(), + ltl_specific_cols.opened_length_read.as_mut(), + ); + let [index_base_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + index_register.as_canonical_u32(), + ltl_specific_cols.index_base_pointer_read.as_mut(), + ); + let index_base_pointer_u32 = index_base_pointer.as_canonical_u32(); + let [commit_pointer]: [F; 1] = tracing_read_native_helper( + state.memory, + commit_register.as_canonical_u32(), + ltl_specific_cols.commit_pointer_read.as_mut(), + ); + let commit = tracing_read_native_helper( + state.memory, + commit_pointer.as_canonical_u32(), + ltl_specific_cols.commit_read.as_mut(), + ); let opened_length = opened_length.as_canonical_u32() as usize; + let [initial_log_height]: [F; 1] = + memory_read_native(state.memory.data(), dim_base_pointer_u32); + let initial_log_height_u32 = initial_log_height.as_canonical_u32(); + let mut log_height = initial_log_height_u32 as i32; - let initial_log_height = memory - .unsafe_read_cell(address_space, dim_base_pointer) - .as_canonical_u32(); - let mut log_height = initial_log_height as i32; - let mut sibling_index = 0; + // Number of non-inside rows, this is used to compute the offset of the inside row + // section. + let (num_inside_rows, num_non_inside_rows) = { + let opened_element_size_u32 = opened_element_size.as_canonical_u32(); + let mut num_non_inside_rows = initial_log_height_u32 as usize; + let mut num_inside_rows = 0; + let mut log_height = initial_log_height_u32; + let mut opened_index = 0; + loop { + let mut total_len = 0; + while opened_index < opened_length { + let [height]: [F; 1] = memory_read_native( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + ); + if height.as_canonical_u32() != log_height { + break; + } + let [row_len]: [F; 1] = memory_read_native( + state.memory.data(), + opened_base_pointer_u32 + 2 * opened_index as u32 + 1, + ); + total_len += row_len.as_canonical_u32() * opened_element_size_u32; + opened_index += 1; + } + if total_len != 0 { + num_non_inside_rows += 1; + num_inside_rows += (total_len as usize).div_ceil(CHUNK); + } + if log_height == 0 { + break; + } + log_height -= 1; + } + (num_inside_rows, num_non_inside_rows) + }; + let mut proof_index = 0; let mut opened_index = 0; - let mut top_level = vec![]; let mut root = [F::ZERO; CHUNK]; let sibling_proof: Vec<[F; CHUNK]> = { - let streams = self.streams.lock().unwrap(); let proof_idx = proof_id.as_canonical_u32() as usize; - streams.hint_space[proof_idx] + state.streams.hint_space[proof_idx] .par_chunks(CHUNK) .map(|c| c.try_into().unwrap()) .collect() }; + let total_num_row = num_inside_rows + num_non_inside_rows; + let allocated_rows = arena + .alloc(MultiRowLayout::new(NativePoseidon2Metadata { + num_rows: total_num_row, + })) + .0; + let mut inside_row_idx = num_non_inside_rows; + let mut non_inside_row_idx = 0; + while log_height >= 0 { - let incorporate_row = if opened_index < opened_length - && memory.unsafe_read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(opened_index), - ) == F::from_canonical_u32(log_height as u32) + if opened_index < opened_length + && memory_read_native::( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + )[0] == F::from_canonical_u32(log_height as u32) { + state + .memory + .increment_timestamp_by(NUM_INITIAL_READS as u32); + let incorporate_start_timestamp = state.memory.timestamp; let initial_opened_index = opened_index; - for _ in 0..NUM_INITIAL_READS { - memory.increment_timestamp(); - } - let mut chunks = vec![]; let mut row_pointer = 0; let mut row_end = 0; @@ -332,166 +384,248 @@ impl InstructionExecutor let mut is_first_in_segment = true; loop { - let mut cells = vec![]; + if inside_row_idx == total_num_row { + opened_index += 1; + break; + } + let inside_cols = &mut allocated_rows[inside_row_idx]; + let inside_specific_cols: &mut InsideRowSpecificCols = inside_cols + .specific[..InsideRowSpecificCols::::width()] + .borrow_mut(); + let start_timestamp_u32 = state.memory.timestamp; + + let mut cells_idx = 0; for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { - let read_row_pointer_and_length = if is_first_in_segment - || row_pointer == row_end - { + let cell_cols = &mut inside_specific_cols.cells[cells_idx]; + if is_first_in_segment || row_pointer == row_end { if is_first_in_segment { is_first_in_segment = false; } else { opened_index += 1; if opened_index == opened_length - || memory.unsafe_read_cell( - address_space, - dim_base_pointer - + F::from_canonical_usize(opened_index), - ) != F::from_canonical_u32(log_height as u32) + || memory_read_native::( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + )[0] != F::from_canonical_u32(log_height as u32) { break; } } - let (result, [new_row_pointer, row_len]) = memory.read( - address_space, - opened_base_pointer + F::from_canonical_usize(2 * opened_index), + let [new_row_pointer, row_len]: [F; 2] = tracing_read_native_helper( + state.memory, + opened_base_pointer_u32 + 2 * opened_index as u32, + cell_cols.read_row_pointer_and_length.as_mut(), ); row_pointer = new_row_pointer.as_canonical_u32() as usize; row_end = row_pointer + (opened_element_size * row_len).as_canonical_u32() as usize; - Some(result) + cell_cols.is_first_in_row = F::ONE; } else { - memory.increment_timestamp(); - None - }; - let (read, value) = memory - .read_cell(address_space, F::from_canonical_usize(row_pointer)); - cells.push(CellRecord { - read, - opened_index, - read_row_pointer_and_length, - row_pointer, - row_end, - }); + state.memory.increment_timestamp(); + } + let [value]: [F; 1] = tracing_read_native_helper( + state.memory, + row_pointer as u32, + cell_cols.read.as_mut(), + ); + + cell_cols.opened_index = F::from_canonical_usize(opened_index); + cell_cols.row_pointer = F::from_canonical_usize(row_pointer); + cell_cols.row_end = F::from_canonical_usize(row_end); + *chunk_elem = value; row_pointer += 1; + cells_idx += 1; } - if cells.is_empty() { + if cells_idx == 0 { break; } - let cells_len = cells.len(); - chunks.push(InsideRowRecord { - cells, - p2_input: rolling_hash, - }); - self.height += 1; + let p2_input = rolling_hash; prev_rolling_hash = Some(rolling_hash); self.subchip.permute_mut(&mut rolling_hash); - if cells_len < CHUNK { - for _ in 0..CHUNK - cells_len { - memory.increment_timestamp(); - memory.increment_timestamp(); + if cells_idx < CHUNK { + state + .memory + .increment_timestamp_by(2 * (CHUNK - cells_idx) as u32); + } + + inside_row_idx += 1; + inside_cols.inner.inputs = p2_input; + inside_cols.incorporate_row = F::ZERO; + inside_cols.incorporate_sibling = F::ZERO; + inside_cols.inside_row = F::ONE; + inside_cols.simple = F::ZERO; + // `end_inside_row` of the last row will be set to 1 after this loop. + inside_cols.end_inside_row = F::ZERO; + inside_cols.end_top_level = F::ZERO; + inside_cols.opened_element_size_inv = opened_element_size_inv; + inside_cols.very_first_timestamp = + F::from_canonical_u32(incorporate_start_timestamp); + inside_cols.start_timestamp = F::from_canonical_u32(start_timestamp_u32); + + inside_cols.initial_opened_index = + F::from_canonical_usize(initial_opened_index); + inside_cols.opened_base_pointer = opened_base_pointer; + if cells_idx < CHUNK { + let exhausted_opened_idx = F::from_canonical_usize(opened_index - 1); + for exhausted_idx in cells_idx..CHUNK { + inside_cols.is_exhausted[exhausted_idx - 1] = F::ONE; + inside_specific_cols.cells[exhausted_idx].opened_index = + exhausted_opened_idx; } break; } } + { + let inside_cols = &mut allocated_rows[inside_row_idx - 1]; + inside_cols.end_inside_row = F::ONE; + } + + let incorporate_cols = &mut allocated_rows[non_inside_row_idx]; + let top_level_specific_cols: &mut TopLevelSpecificCols = incorporate_cols + .specific[..TopLevelSpecificCols::::width()] + .borrow_mut(); + let final_opened_index = opened_index - 1; - let (initial_height_read, height_check) = memory.read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(initial_opened_index), + let [height_check]: [F; 1] = tracing_read_native_helper( + state.memory, + dim_base_pointer_u32 + initial_opened_index as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), ); assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); - let (final_height_read, height_check) = memory.read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(final_opened_index), + let final_height_read_timestamp = state.memory.timestamp; + let [height_check]: [F; 1] = tracing_read_native_helper( + state.memory, + dim_base_pointer_u32 + final_opened_index as u32, + top_level_specific_cols.read_final_height.as_mut(), ); assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); - - let (p2_input, new_root) = if log_height as u32 == initial_log_height { + let (p2_input, new_root) = if log_height as u32 == initial_log_height_u32 { (prev_rolling_hash.unwrap(), hash) } else { - self.compress(root, hash) + compress(&self.subchip, root, hash) }; root = new_root; + non_inside_row_idx += 1; - self.height += 1; - Some(IncorporateRowRecord { - chunks, - initial_opened_index, - final_opened_index, - initial_height_read, - final_height_read, - p2_input, - }) - } else { - None - }; + incorporate_cols.incorporate_row = F::ONE; + incorporate_cols.incorporate_sibling = F::ZERO; + incorporate_cols.inside_row = F::ZERO; + incorporate_cols.simple = F::ZERO; + incorporate_cols.end_inside_row = F::ZERO; + incorporate_cols.end_top_level = F::ZERO; + incorporate_cols.start_top_level = F::from_bool(proof_index == 0); + incorporate_cols.opened_element_size_inv = opened_element_size_inv; + incorporate_cols.very_first_timestamp = init_timestamp; + incorporate_cols.start_timestamp = F::from_canonical_u32( + incorporate_start_timestamp - NUM_INITIAL_READS as u32, + ); + top_level_specific_cols.end_timestamp = + F::from_canonical_u32(final_height_read_timestamp + 1); - let incorporate_sibling = if log_height == 0 { - None - } else { - for _ in 0..NUM_INITIAL_READS { - memory.increment_timestamp(); - } + incorporate_cols.inner.inputs = p2_input; + incorporate_cols.initial_opened_index = + F::from_canonical_usize(initial_opened_index); + top_level_specific_cols.final_opened_index = + F::from_canonical_usize(final_opened_index); + top_level_specific_cols.log_height = F::from_canonical_u32(log_height as u32); + top_level_specific_cols.opened_length = F::from_canonical_usize(opened_length); + top_level_specific_cols.dim_base_pointer = dim_base_pointer; + incorporate_cols.opened_base_pointer = opened_base_pointer; + top_level_specific_cols.index_base_pointer = index_base_pointer; + top_level_specific_cols.proof_index = F::from_canonical_usize(proof_index); + } + + if log_height != 0 { + let row_start_timestamp = state.memory.timestamp; + state + .memory + .increment_timestamp_by(NUM_INITIAL_READS as u32); + + let sibling_cols = &mut allocated_rows[non_inside_row_idx]; + let top_level_specific_cols: &mut TopLevelSpecificCols = + sibling_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell( - address_space, - index_base_pointer + F::from_canonical_usize(sibling_index), + let read_sibling_is_on_right_timestamp = state.memory.timestamp; + let [sibling_is_on_right]: [F; 1] = tracing_read_native_helper( + state.memory, + index_base_pointer_u32 + proof_index as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), ); - let sibling_is_on_right = sibling_is_on_right == F::ONE; - let sibling = sibling_proof[sibling_index]; - let (p2_input, new_root) = if sibling_is_on_right { - self.compress(sibling, root) + let sibling = sibling_proof[proof_index]; + let (p2_input, new_root) = if sibling_is_on_right == F::ONE { + compress(&self.subchip, sibling, root) } else { - self.compress(root, sibling) + compress(&self.subchip, root, sibling) }; root = new_root; - self.height += 1; - Some(IncorporateSiblingRecord { - read_sibling_is_on_right, - sibling_is_on_right, - p2_input, - }) - }; + non_inside_row_idx += 1; + + sibling_cols.inner.inputs = p2_input; - top_level.push(TopLevelRecord { - incorporate_row, - incorporate_sibling, - }); + sibling_cols.incorporate_row = F::ZERO; + sibling_cols.incorporate_sibling = F::ONE; + sibling_cols.inside_row = F::ZERO; + sibling_cols.simple = F::ZERO; + sibling_cols.end_inside_row = F::ZERO; + sibling_cols.end_top_level = F::ZERO; + sibling_cols.start_top_level = F::ZERO; + sibling_cols.opened_element_size_inv = opened_element_size_inv; + sibling_cols.very_first_timestamp = init_timestamp; + sibling_cols.start_timestamp = F::from_canonical_u32(row_start_timestamp); + + top_level_specific_cols.end_timestamp = + F::from_canonical_u32(read_sibling_is_on_right_timestamp + 1); + sibling_cols.initial_opened_index = F::from_canonical_usize(opened_index); + top_level_specific_cols.final_opened_index = + F::from_canonical_usize(opened_index - 1); + top_level_specific_cols.log_height = F::from_canonical_u32(log_height as u32); + top_level_specific_cols.opened_length = F::from_canonical_usize(opened_length); + top_level_specific_cols.dim_base_pointer = dim_base_pointer; + sibling_cols.opened_base_pointer = opened_base_pointer; + top_level_specific_cols.index_base_pointer = index_base_pointer; + + top_level_specific_cols.proof_index = F::from_canonical_usize(proof_index); + top_level_specific_cols.sibling_is_on_right = sibling_is_on_right; + }; log_height -= 1; - sibling_index += 1; + proof_index += 1; } - + let ltl_trace_cols = &mut allocated_rows[non_inside_row_idx - 1]; + let ltl_trace_specific_cols: &mut TopLevelSpecificCols = + ltl_trace_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); + ltl_trace_cols.end_top_level = F::ONE; + ltl_trace_specific_cols.pc = F::from_canonical_u32(*state.pc); + ltl_trace_specific_cols.dim_register = dim_register; + ltl_trace_specific_cols.opened_register = opened_register; + ltl_trace_specific_cols.opened_length_register = opened_length_register; + ltl_trace_specific_cols.proof_id = proof_id_ptr; + ltl_trace_specific_cols.index_register = index_register; + ltl_trace_specific_cols.commit_register = commit_register; + ltl_trace_specific_cols.commit_pointer = commit_pointer; + ltl_trace_specific_cols.dim_base_pointer_read = ltl_specific_cols.dim_base_pointer_read; + ltl_trace_specific_cols.opened_base_pointer_read = + ltl_specific_cols.opened_base_pointer_read; + ltl_trace_specific_cols.opened_length_read = ltl_specific_cols.opened_length_read; + ltl_trace_specific_cols.index_base_pointer_read = + ltl_specific_cols.index_base_pointer_read; + ltl_trace_specific_cols.commit_pointer_read = ltl_specific_cols.commit_pointer_read; + ltl_trace_specific_cols.commit_read = ltl_specific_cols.commit_read; assert_eq!(commit, root); - self.record_set - .verify_batch_records - .push(VerifyBatchRecord { - from_state, - instruction: instruction.clone(), - dim_base_pointer, - opened_base_pointer, - opened_length, - index_base_pointer, - commit_pointer, - dim_base_pointer_read, - opened_base_pointer_read, - opened_length_read, - index_base_pointer_read, - commit_pointer_read, - commit_read, - initial_log_height: initial_log_height as usize, - top_level, - }); } else { unreachable!() } - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) + + *state.pc += DEFAULT_PC_STEP; + Ok(()) } fn get_opcode_name(&self, opcode: usize) -> String { @@ -506,3 +640,645 @@ impl InstructionExecutor } } } + +impl TraceFiller + for NativePoseidon2Step +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let inner_cols = { + let cols: &NativePoseidon2Cols = row_slice.as_ref().borrow(); + &self.subchip.generate_trace(vec![cols.inner.inputs]).values + }; + let inner_width = self.subchip.air.width(); + row_slice[..inner_width].copy_from_slice(inner_cols); + let cols: &mut NativePoseidon2Cols = row_slice.borrow_mut(); + + // Simple poseidon2 row + if cols.simple.is_one() { + let simple_cols: &mut SimplePoseidonSpecificCols = + cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + mem_fill_helper( + mem_helper, + start_timestamp_u32, + simple_cols.read_output_pointer.as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 1, + simple_cols.read_input_pointer_1.as_mut(), + ); + if simple_cols.is_compress.is_one() { + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2, + simple_cols.read_input_pointer_2.as_mut(), + ); + } + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 3, + simple_cols.read_data_1.as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 4, + simple_cols.read_data_2.as_mut(), + ); + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 5, + simple_cols.write_data_1.as_mut(), + ); + if simple_cols.is_compress.is_zero() { + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 6, + simple_cols.write_data_2.as_mut(), + ); + } + } else if cols.inside_row.is_one() { + let inside_row_specific_cols: &mut InsideRowSpecificCols = + cols.specific[..InsideRowSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + for (i, cell) in inside_row_specific_cols.cells.iter_mut().enumerate() { + if i > 0 && cols.is_exhausted[i - 1].is_one() { + break; + } + if cell.is_first_in_row.is_one() { + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2 * i as u32, + cell.read_row_pointer_and_length.as_mut(), + ); + } + mem_fill_helper( + mem_helper, + start_timestamp_u32 + 2 * i as u32 + 1, + cell.read.as_mut(), + ); + } + } else { + let top_level_specific_cols: &mut TopLevelSpecificCols = + cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + if cols.end_top_level.is_one() { + let very_start_timestamp_u32 = cols.very_first_timestamp.as_canonical_u32(); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32, + top_level_specific_cols.dim_base_pointer_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 1, + top_level_specific_cols.opened_base_pointer_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 2, + top_level_specific_cols.opened_length_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 3, + top_level_specific_cols.index_base_pointer_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 4, + top_level_specific_cols.commit_pointer_read.as_mut(), + ); + mem_fill_helper( + mem_helper, + very_start_timestamp_u32 + 5, + top_level_specific_cols.commit_read.as_mut(), + ); + } + if cols.incorporate_row.is_one() { + let end_timestamp = top_level_specific_cols.end_timestamp.as_canonical_u32(); + mem_fill_helper( + mem_helper, + end_timestamp - 2, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), + ); + mem_fill_helper( + mem_helper, + end_timestamp - 1, + top_level_specific_cols.read_final_height.as_mut(), + ); + } else if cols.incorporate_sibling.is_one() { + mem_fill_helper( + mem_helper, + start_timestamp_u32 + NUM_INITIAL_READS as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), + ); + } else { + unreachable!() + } + } + } + + fn fill_dummy_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let width = self.subchip.air.width(); + row_slice[..width].copy_from_slice(&self.empty_poseidon2_sub_cols); + } +} + +fn tracing_read_native_helper( + memory: &mut TracingMemory, + ptr: u32, + base_aux: &mut MemoryBaseAuxCols, +) -> [F; BLOCK_SIZE] { + let mut prev_ts = 0; + let ret = tracing_read_native(memory, ptr, &mut prev_ts); + base_aux.set_prev(F::from_canonical_u32(prev_ts)); + ret +} + +/// Fill `MemoryBaseAuxCols`, assuming that the `prev_timestamp` is already set in `base_aux`. +fn mem_fill_helper( + mem_helper: &MemoryAuxColsFactory, + timestamp: u32, + base_aux: &mut MemoryBaseAuxCols, +) { + let prev_ts = base_aux.prev_timestamp.as_canonical_u32(); + mem_helper.fill(prev_ts, timestamp, base_aux); +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct Pos2PreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { + subchip: &'a Poseidon2SubChip, + output_register: u32, + input_register_1: u32, + input_register_2: u32, +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct VerifyBatchPreCompute<'a, F: Field, const SBOX_REGISTERS: usize> { + subchip: &'a Poseidon2SubChip, + dim_register: u32, + opened_register: u32, + opened_length_register: u32, + proof_id_ptr: u32, + index_register: u32, + commit_register: u32, + opened_element_size: F, +} + +impl<'a, F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Step { + #[inline(always)] + fn pre_compute_pos2_impl( + &'a self, + pc: u32, + inst: &Instruction, + pos2_data: &mut Pos2PreCompute<'a, F, SBOX_REGISTERS>, + ) -> Result<()> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + + if opcode != PERM_POS2.global_opcode() && opcode != COMP_POS2.global_opcode() { + return Err(InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + + if d != AS::Native as u32 { + return Err(InvalidInstruction(pc)); + } + if e != AS::Native as u32 { + return Err(InvalidInstruction(pc)); + } + + *pos2_data = Pos2PreCompute { + subchip: &self.subchip, + output_register: a, + input_register_1: b, + input_register_2: c, + }; + + Ok(()) + } + + #[inline(always)] + fn pre_compute_verify_batch_impl( + &'a self, + pc: u32, + inst: &Instruction, + verify_batch_data: &mut VerifyBatchPreCompute<'a, F, SBOX_REGISTERS>, + ) -> Result<()> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + if opcode != VERIFY_BATCH.global_opcode() { + return Err(InvalidInstruction(pc)); + } + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + let f = f.as_canonical_u32(); + + let opened_element_size_inv = g; + // calc inverse fast assuming opened_element_size in {1, 4} + let mut opened_element_size = F::ONE; + while opened_element_size * opened_element_size_inv != F::ONE { + opened_element_size += F::ONE; + } + + *verify_batch_data = VerifyBatchPreCompute { + subchip: &self.subchip, + dim_register: a, + opened_register: b, + opened_length_register: c, + proof_id_ptr: d, + index_register: e, + commit_register: f, + opened_element_size, + }; + + Ok(()) + } +} + +impl StepExecutorE1 + for NativePoseidon2Step +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + std::cmp::max( + size_of::>(), + size_of::>(), + ) + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let &Instruction { opcode, .. } = inst; + + let is_pos2 = opcode == PERM_POS2.global_opcode() || opcode == COMP_POS2.global_opcode(); + + if is_pos2 { + let pos2_data: &mut Pos2PreCompute = data.borrow_mut(); + self.pre_compute_pos2_impl(pc, inst, pos2_data)?; + if opcode == PERM_POS2.global_opcode() { + Ok(execute_pos2_e1_impl::<_, _, SBOX_REGISTERS, true>) + } else { + Ok(execute_pos2_e1_impl::<_, _, SBOX_REGISTERS, false>) + } + } else { + let verify_batch_data: &mut VerifyBatchPreCompute = + data.borrow_mut(); + self.pre_compute_verify_batch_impl(pc, inst, verify_batch_data)?; + Ok(execute_verify_batch_e1_impl::<_, _, SBOX_REGISTERS>) + } + } +} + +impl StepExecutorE2 + for NativePoseidon2Step +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + std::cmp::max( + size_of::>>(), + size_of::>>(), + ) + } + + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let &Instruction { opcode, .. } = inst; + + let is_pos2 = opcode == PERM_POS2.global_opcode() || opcode == COMP_POS2.global_opcode(); + + if is_pos2 { + let pre_compute: &mut E2PreCompute> = + data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_pos2_impl(pc, inst, &mut pre_compute.data)?; + if opcode == PERM_POS2.global_opcode() { + Ok(execute_pos2_e2_impl::<_, _, SBOX_REGISTERS, true>) + } else { + Ok(execute_pos2_e2_impl::<_, _, SBOX_REGISTERS, false>) + } + } else { + let pre_compute: &mut E2PreCompute> = + data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + + self.pre_compute_verify_batch_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_verify_batch_e2_impl::<_, _, SBOX_REGISTERS>) + } + } +} + +unsafe fn execute_pos2_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const SBOX_REGISTERS: usize, + const IS_PERM: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &Pos2PreCompute = pre_compute.borrow(); + execute_pos2_e12_impl::<_, _, SBOX_REGISTERS, IS_PERM>(pre_compute, vm_state); +} + +unsafe fn execute_pos2_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const SBOX_REGISTERS: usize, + const IS_PERM: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + let height = + execute_pos2_e12_impl::<_, _, SBOX_REGISTERS, IS_PERM>(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +unsafe fn execute_verify_batch_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &VerifyBatchPreCompute = pre_compute.borrow(); + execute_verify_batch_e12_impl::<_, _, SBOX_REGISTERS>(pre_compute, vm_state); +} + +unsafe fn execute_verify_batch_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const SBOX_REGISTERS: usize, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute> = pre_compute.borrow(); + let height = execute_verify_batch_e12_impl::<_, _, SBOX_REGISTERS>(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +#[inline(always)] +unsafe fn execute_pos2_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const SBOX_REGISTERS: usize, + const IS_PERM: bool, +>( + pre_compute: &Pos2PreCompute, + vm_state: &mut VmSegmentState, +) -> u32 { + let subchip = pre_compute.subchip; + + let [output_pointer]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.output_register); + let [input_pointer_1]: [F; 1] = + vm_state.vm_read(AS::Native as u32, pre_compute.input_register_1); + let [input_pointer_2] = if IS_PERM { + [input_pointer_1 + F::from_canonical_usize(CHUNK)] + } else { + vm_state.vm_read(AS::Native as u32, pre_compute.input_register_2) + }; + + let data_1: [F; CHUNK] = + vm_state.vm_read(AS::Native as u32, input_pointer_1.as_canonical_u32()); + let data_2: [F; CHUNK] = + vm_state.vm_read(AS::Native as u32, input_pointer_2.as_canonical_u32()); + + let p2_input = std::array::from_fn(|i| { + if i < CHUNK { + data_1[i] + } else { + data_2[i - CHUNK] + } + }); + let output = subchip.permute(p2_input); + let output_pointer_u32 = output_pointer.as_canonical_u32(); + + vm_state.vm_write::( + AS::Native as u32, + output_pointer_u32, + &std::array::from_fn(|i| output[i]), + ); + if IS_PERM { + vm_state.vm_write::( + AS::Native as u32, + output_pointer_u32 + CHUNK as u32, + &std::array::from_fn(|i| output[i + CHUNK]), + ); + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + 1 +} + +#[inline(always)] +unsafe fn execute_verify_batch_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const SBOX_REGISTERS: usize, +>( + pre_compute: &VerifyBatchPreCompute, + vm_state: &mut VmSegmentState, +) -> u32 { + // TODO: Add a flag `optimistic_execution`. When the flag is true, we trust all inputs + // and skip all input validation computation during E1 execution. + + let subchip = pre_compute.subchip; + let opened_element_size = pre_compute.opened_element_size; + + let [proof_id]: [F; 1] = vm_state.host_read(AS::Native as u32, pre_compute.proof_id_ptr); + let [dim_base_pointer]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.dim_register); + let dim_base_pointer_u32 = dim_base_pointer.as_canonical_u32(); + let [opened_base_pointer]: [F; 1] = + vm_state.vm_read(AS::Native as u32, pre_compute.opened_register); + let opened_base_pointer_u32 = opened_base_pointer.as_canonical_u32(); + let [opened_length]: [F; 1] = + vm_state.vm_read(AS::Native as u32, pre_compute.opened_length_register); + let [index_base_pointer]: [F; 1] = + vm_state.vm_read(AS::Native as u32, pre_compute.index_register); + let index_base_pointer_u32 = index_base_pointer.as_canonical_u32(); + let [commit_pointer]: [F; 1] = vm_state.vm_read(AS::Native as u32, pre_compute.commit_register); + let commit: [F; CHUNK] = vm_state.vm_read(AS::Native as u32, commit_pointer.as_canonical_u32()); + + let opened_length = opened_length.as_canonical_u32() as usize; + + let initial_log_height = { + let [height]: [F; 1] = vm_state.host_read(AS::Native as u32, dim_base_pointer_u32); + height.as_canonical_u32() + }; + + let mut log_height = initial_log_height as i32; + let mut sibling_index = 0; + let mut opened_index = 0; + let mut height = 0; + + let mut root = [F::ZERO; CHUNK]; + let sibling_proof: Vec<[F; CHUNK]> = { + let proof_idx = proof_id.as_canonical_u32() as usize; + vm_state.streams.hint_space[proof_idx] + .par_chunks(CHUNK) + .map(|c| c.try_into().unwrap()) + .collect() + }; + + while log_height >= 0 { + if opened_index < opened_length + && vm_state.host_read::( + AS::Native as u32, + dim_base_pointer_u32 + opened_index as u32, + )[0] == F::from_canonical_u32(log_height as u32) + { + let initial_opened_index = opened_index; + + let mut row_pointer = 0; + let mut row_end = 0; + + let mut rolling_hash = [F::ZERO; 2 * CHUNK]; + + let mut is_first_in_segment = true; + + loop { + let mut cells_len = 0; + for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { + if is_first_in_segment || row_pointer == row_end { + if is_first_in_segment { + is_first_in_segment = false; + } else { + opened_index += 1; + if opened_index == opened_length + || vm_state.host_read::( + AS::Native as u32, + dim_base_pointer_u32 + opened_index as u32, + )[0] != F::from_canonical_u32(log_height as u32) + { + break; + } + } + let [new_row_pointer, row_len]: [F; 2] = vm_state.vm_read( + AS::Native as u32, + opened_base_pointer_u32 + 2 * opened_index as u32, + ); + row_pointer = new_row_pointer.as_canonical_u32() as usize; + row_end = row_pointer + + (opened_element_size * row_len).as_canonical_u32() as usize; + } + let [value]: [F; 1] = vm_state.vm_read(AS::Native as u32, row_pointer as u32); + cells_len += 1; + *chunk_elem = value; + row_pointer += 1; + } + if cells_len == 0 { + break; + } + height += 1; + subchip.permute_mut(&mut rolling_hash); + if cells_len < CHUNK { + break; + } + } + + let final_opened_index = opened_index - 1; + let [height_check]: [F; 1] = vm_state.host_read( + AS::Native as u32, + dim_base_pointer_u32 + initial_opened_index as u32, + ); + assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); + let [height_check]: [F; 1] = vm_state.host_read( + AS::Native as u32, + dim_base_pointer_u32 + final_opened_index as u32, + ); + assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); + + let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); + + let new_root = if log_height as u32 == initial_log_height { + hash + } else { + let (_, new_root) = compress(subchip, root, hash); + new_root + }; + root = new_root; + height += 1; + } + + if log_height != 0 { + let [sibling_is_on_right]: [F; 1] = vm_state.vm_read( + AS::Native as u32, + index_base_pointer_u32 + sibling_index as u32, + ); + let sibling_is_on_right = sibling_is_on_right == F::ONE; + let sibling = sibling_proof[sibling_index]; + let (_, new_root) = if sibling_is_on_right { + compress(subchip, sibling, root) + } else { + compress(subchip, root, sibling) + }; + root = new_root; + height += 1; + } + + log_height -= 1; + sibling_index += 1; + } + + assert_eq!(commit, root); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + height +} diff --git a/extensions/native/circuit/src/poseidon2/mod.rs b/extensions/native/circuit/src/poseidon2/mod.rs index af503e20f4..7e9b3783e0 100644 --- a/extensions/native/circuit/src/poseidon2/mod.rs +++ b/extensions/native/circuit/src/poseidon2/mod.rs @@ -1,8 +1,47 @@ +use std::sync::Arc; + +use openvm_circuit::{ + arch::{ExecutionBridge, MatrixRecordArena, NewVmChipWrapper, SystemPort}, + system::memory::SharedMemoryHelper, +}; +use openvm_native_compiler::conversion::AS; +use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::poseidon2::{ + air::{NativePoseidon2Air, VerifyBatchBus}, + chip::NativePoseidon2Step, +}; + pub mod air; pub mod chip; mod columns; #[cfg(test)] mod tests; -mod trace; const CHUNK: usize = 8; +pub type NativePoseidon2Chip = NewVmChipWrapper< + F, + NativePoseidon2Air, + NativePoseidon2Step, + MatrixRecordArena, +>; + +pub fn new_native_poseidon2_chip( + port: SystemPort, + poseidon2_config: Poseidon2Config, + verify_batch_bus: VerifyBatchBus, + mem_helper: SharedMemoryHelper, +) -> NativePoseidon2Chip { + NativePoseidon2Chip::::new( + NativePoseidon2Air { + execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), + memory_bridge: port.memory_bridge, + internal_bus: verify_batch_bus, + subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), + address_space: F::from_canonical_u32(AS::Native as u32), + }, + NativePoseidon2Step::new(poseidon2_config), + mem_helper, + ) +} diff --git a/extensions/native/circuit/src/poseidon2/tests.rs b/extensions/native/circuit/src/poseidon2/tests.rs index 32a0e483a3..8e0a211b03 100644 --- a/extensions/native/circuit/src/poseidon2/tests.rs +++ b/extensions/native/circuit/src/poseidon2/tests.rs @@ -1,11 +1,8 @@ -use std::{ - cmp::min, - sync::{Arc, Mutex}, -}; +use std::cmp::min; use openvm_circuit::arch::{ testing::{memory::gen_pointer, VmChipTestBuilder, VmChipTester}, - verify_single, Streams, VirtualMachine, + verify_single, VirtualMachine, }; use openvm_instructions::{instruction::Instruction, program::Program, LocalOpcode, SystemOpcode}; use openvm_native_compiler::{ @@ -34,11 +31,12 @@ use rand::{rngs::StdRng, Rng}; use super::air::VerifyBatchBus; use crate::{ - poseidon2::{chip::NativePoseidon2Chip, CHUNK}, + poseidon2::{new_native_poseidon2_chip, CHUNK}, NativeConfig, }; const VERIFY_BATCH_BUS: VerifyBatchBus = VerifyBatchBus::new(7); +const MAX_INS_CAPACITY: usize = 1 << 15; fn compute_commit( dim: &[usize], @@ -153,15 +151,14 @@ fn test(cases: [Case; N]) { // single op let address_space = AS::Native as usize; - let mut tester = VmChipTestBuilder::default(); - let streams = Arc::new(Mutex::new(Streams::default())); - let mut chip = NativePoseidon2Chip::::new( + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = new_native_poseidon2_chip::( tester.system_port(), - tester.offline_memory_mutex_arc(), Poseidon2Config::default(), VERIFY_BATCH_BUS, - streams.clone(), + tester.memory_helper(), ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); let mut rng = create_seeded_rng(); for Case { @@ -169,12 +166,11 @@ fn test(cases: [Case; N]) { opened_element_size, } in cases { - let mut streams = streams.lock().unwrap(); let instance = random_instance(&mut rng, row_lengths, opened_element_size, |left, right| { let concatenated = std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); - let permuted = chip.subchip.permute(concatenated); + let permuted = chip.step.subchip.permute(concatenated); ( std::array::from_fn(|i| permuted[i]), std::array::from_fn(|i| permuted[i + CHUNK]), @@ -203,7 +199,7 @@ fn test(cases: [Case; N]) { tester.write_usize(address_space, dim_register, [dim_base_pointer]); tester.write_usize(address_space, opened_register, [opened_base_pointer]); tester.write_usize(address_space, opened_length_register, [opened.len()]); - tester.write_usize(address_space, proof_id, [streams.hint_space.len()]); + tester.write_usize(address_space, proof_id, [tester.streams.hint_space.len()]); tester.write_usize(address_space, index_register, [index_base_pointer]); tester.write_usize(address_space, commit_register, [commit_pointer]); @@ -218,15 +214,15 @@ fn test(cases: [Case; N]) { [row_pointer, opened_row.len() / opened_element_size], ); for (j, &opened_value) in opened_row.iter().enumerate() { - tester.write_cell(address_space, row_pointer + j, opened_value); + tester.write(address_space, row_pointer + j, [opened_value]); } } - streams + tester + .streams .hint_space .push(proof.iter().flatten().copied().collect()); - drop(streams); for (i, &bit) in sibling_is_on_right.iter().enumerate() { - tester.write_cell(address_space, index_base_pointer + i, F::from_bool(bit)); + tester.write(address_space, index_base_pointer + i, [F::from_bool(bit)]); } tester.write(address_space, commit_pointer, commit); @@ -383,15 +379,14 @@ fn random_instructions(num_ops: usize) -> Vec> { fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { let elem_range = || 1..=100; - let mut tester = VmChipTestBuilder::default(); - let streams = Arc::new(Mutex::new(Streams::default())); - let mut chip = NativePoseidon2Chip::::new( + let mut tester = VmChipTestBuilder::default_native(); + let mut chip = new_native_poseidon2_chip::( tester.system_port(), - tester.offline_memory_mutex_arc(), Poseidon2Config::default(), VERIFY_BATCH_BUS, - streams.clone(), + tester.memory_helper(), ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); let mut rng = create_seeded_rng(); @@ -417,23 +412,24 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { - let data_left: [_; CHUNK] = std::array::from_fn(|i| data[i]); - let data_right: [_; CHUNK] = std::array::from_fn(|i| data[CHUNK + i]); tester.write(e, lhs, data_left); tester.write(e, rhs, data_right); } PERM_POS2 => { - tester.write(e, lhs, data); + tester.write(e, lhs, data_left); + tester.write(e, lhs + CHUNK, data_right); } } @@ -446,8 +442,10 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester { - let actual = tester.read::<{ 2 * CHUNK }>(e, dst); - assert_eq!(hash, actual); + let actual_0 = tester.read::<{ CHUNK }>(e, dst); + let actual_1 = tester.read::<{ CHUNK }>(e, dst + CHUNK); + let actual = [actual_0, actual_1].concat(); + assert_eq!(&hash, &actual[..]); } } } @@ -497,7 +495,11 @@ fn air_test_with_compress_poseidon2( let vm = VirtualMachine::new(engine, config); let pk = vm.keygen(); - let result = vm.execute_and_generate(program, vec![]).unwrap(); + let vk = pk.get_vk(); + let segments = vm + .execute_metered(program.clone(), vec![], &vk.num_interactions()) + .unwrap(); + let result = vm.execute_and_generate(program, vec![], &segments).unwrap(); let proofs = vm.prove(&pk, result); for proof in proofs { verify_single(&vm.engine, &pk.get_vk(), &proof).expect("Verification failed"); diff --git a/extensions/native/circuit/src/poseidon2/trace.rs b/extensions/native/circuit/src/poseidon2/trace.rs deleted file mode 100644 index df8547767f..0000000000 --- a/extensions/native/circuit/src/poseidon2/trace.rs +++ /dev/null @@ -1,485 +0,0 @@ -use std::{borrow::BorrowMut, sync::Arc}; - -use openvm_circuit::system::memory::{MemoryAuxColsFactory, OfflineMemory}; -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_native_compiler::Poseidon2Opcode::COMP_POS2; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_air::BaseAir, - p3_field::{Field, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, -}; - -use crate::{ - chip::{SimplePoseidonRecord, NUM_INITIAL_READS}, - poseidon2::{ - chip::{ - CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord, - NativePoseidon2Chip, VerifyBatchRecord, - }, - columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, - }, - CHUNK, - }, -}; -impl ChipUsageGetter - for NativePoseidon2Chip -{ - fn air_name(&self) -> String { - "VerifyBatchAir".to_string() - } - - fn current_trace_height(&self) -> usize { - self.height - } - - fn trace_width(&self) -> usize { - NativePoseidon2Cols::::width() - } -} - -impl NativePoseidon2Chip { - fn generate_subair_cols(&self, input: [F; 2 * CHUNK], cols: &mut [F]) { - let inner_trace = self.subchip.generate_trace(vec![input]); - let inner_width = self.air.subair.width(); - cols[..inner_width].copy_from_slice(inner_trace.values.as_slice()); - } - #[allow(clippy::too_many_arguments)] - fn incorporate_sibling_record_to_row( - &self, - record: &IncorporateSiblingRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &VerifyBatchRecord, - proof_index: usize, - opened_index: usize, - log_height: usize, - ) { - let &IncorporateSiblingRecord { - read_sibling_is_on_right, - sibling_is_on_right, - p2_input, - } = record; - - let read_sibling_is_on_right = memory.record_by_id(read_sibling_is_on_right); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ONE; - cols.inside_row = F::ZERO; - cols.simple = F::ZERO; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.start_top_level = F::ZERO; - cols.opened_element_size_inv = parent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp); - cols.start_timestamp = - F::from_canonical_u32(read_sibling_is_on_right.timestamp - NUM_INITIAL_READS as u32); - - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.end_timestamp = - F::from_canonical_usize(read_sibling_is_on_right.timestamp as usize + 1); - cols.initial_opened_index = F::from_canonical_usize(opened_index); - specific.final_opened_index = F::from_canonical_usize(opened_index - 1); - specific.log_height = F::from_canonical_usize(log_height); - specific.opened_length = F::from_canonical_usize(parent.opened_length); - specific.dim_base_pointer = parent.dim_base_pointer; - cols.opened_base_pointer = parent.opened_base_pointer; - specific.index_base_pointer = parent.index_base_pointer; - - specific.proof_index = F::from_canonical_usize(proof_index); - aux_cols_factory.generate_read_aux( - read_sibling_is_on_right, - &mut specific.read_initial_height_or_sibling_is_on_right, - ); - specific.sibling_is_on_right = F::from_bool(sibling_is_on_right); - } - fn correct_last_top_level_row( - &self, - record: &VerifyBatchRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) { - let &VerifyBatchRecord { - from_state, - commit_pointer, - dim_base_pointer_read, - opened_base_pointer_read, - opened_length_read, - index_base_pointer_read, - commit_pointer_read, - commit_read, - .. - } = record; - let instruction = &record.instruction; - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.end_top_level = F::ONE; - - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.pc = F::from_canonical_u32(from_state.pc); - specific.dim_register = instruction.a; - specific.opened_register = instruction.b; - specific.opened_length_register = instruction.c; - specific.proof_id = instruction.d; - specific.index_register = instruction.e; - specific.commit_register = instruction.f; - specific.commit_pointer = commit_pointer; - aux_cols_factory.generate_read_aux( - memory.record_by_id(dim_base_pointer_read), - &mut specific.dim_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(opened_base_pointer_read), - &mut specific.opened_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(opened_length_read), - &mut specific.opened_length_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(index_base_pointer_read), - &mut specific.index_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(commit_pointer_read), - &mut specific.commit_pointer_read, - ); - aux_cols_factory - .generate_read_aux(memory.record_by_id(commit_read), &mut specific.commit_read); - } - #[allow(clippy::too_many_arguments)] - fn incorporate_row_record_to_row( - &self, - record: &IncorporateRowRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &VerifyBatchRecord, - proof_index: usize, - log_height: usize, - ) { - let &IncorporateRowRecord { - initial_opened_index, - final_opened_index, - initial_height_read, - final_height_read, - p2_input, - .. - } = record; - - let initial_height_read = memory.record_by_id(initial_height_read); - let final_height_read = memory.record_by_id(final_height_read); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ONE; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ZERO; - cols.simple = F::ZERO; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.start_top_level = F::from_bool(proof_index == 0); - cols.opened_element_size_inv = parent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp); - cols.start_timestamp = F::from_canonical_u32( - memory - .record_by_id( - record.chunks[0].cells[0] - .read_row_pointer_and_length - .unwrap(), - ) - .timestamp - - NUM_INITIAL_READS as u32, - ); - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.end_timestamp = F::from_canonical_u32(final_height_read.timestamp + 1); - - cols.initial_opened_index = F::from_canonical_usize(initial_opened_index); - specific.final_opened_index = F::from_canonical_usize(final_opened_index); - specific.log_height = F::from_canonical_usize(log_height); - specific.opened_length = F::from_canonical_usize(parent.opened_length); - specific.dim_base_pointer = parent.dim_base_pointer; - cols.opened_base_pointer = parent.opened_base_pointer; - specific.index_base_pointer = parent.index_base_pointer; - - specific.proof_index = F::from_canonical_usize(proof_index); - aux_cols_factory.generate_read_aux( - initial_height_read, - &mut specific.read_initial_height_or_sibling_is_on_right, - ); - aux_cols_factory.generate_read_aux(final_height_read, &mut specific.read_final_height); - } - #[allow(clippy::too_many_arguments)] - fn inside_row_record_to_row( - &self, - record: &InsideRowRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &IncorporateRowRecord, - grandparent: &VerifyBatchRecord, - is_last: bool, - ) { - let InsideRowRecord { cells, p2_input } = record; - - self.generate_subair_cols(*p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ONE; - cols.simple = F::ZERO; - cols.end_inside_row = F::from_bool(is_last); - cols.end_top_level = F::ZERO; - cols.opened_element_size_inv = grandparent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32( - memory - .record_by_id( - parent.chunks[0].cells[0] - .read_row_pointer_and_length - .unwrap(), - ) - .timestamp, - ); - cols.start_timestamp = - F::from_canonical_u32(memory.record_by_id(cells[0].read).timestamp - 1); - let specific: &mut InsideRowSpecificCols = - cols.specific[..InsideRowSpecificCols::::width()].borrow_mut(); - - for (record, cell) in cells.iter().zip(specific.cells.iter_mut()) { - let &CellRecord { - read, - opened_index, - read_row_pointer_and_length, - row_pointer, - row_end, - } = record; - aux_cols_factory.generate_read_aux(memory.record_by_id(read), &mut cell.read); - cell.opened_index = F::from_canonical_usize(opened_index); - if let Some(read_row_pointer_and_length) = read_row_pointer_and_length { - aux_cols_factory.generate_read_aux( - memory.record_by_id(read_row_pointer_and_length), - &mut cell.read_row_pointer_and_length, - ); - } - cell.row_pointer = F::from_canonical_usize(row_pointer); - cell.row_end = F::from_canonical_usize(row_end); - cell.is_first_in_row = F::from_bool(read_row_pointer_and_length.is_some()); - } - - for cell in specific.cells.iter_mut().skip(cells.len()) { - cell.opened_index = F::from_canonical_usize(parent.final_opened_index); - } - - cols.is_exhausted = std::array::from_fn(|i| F::from_bool(i + 1 >= cells.len())); - - cols.initial_opened_index = F::from_canonical_usize(parent.initial_opened_index); - cols.opened_base_pointer = grandparent.opened_base_pointer; - } - // returns number of used cells - fn verify_batch_record_to_rows( - &self, - record: &VerifyBatchRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) -> usize { - let width = NativePoseidon2Cols::::width(); - let mut used_cells = 0; - - let mut opened_index = 0; - for (proof_index, top_level) in record.top_level.iter().enumerate() { - let log_height = record.initial_log_height - proof_index; - if let Some(incorporate_row) = &top_level.incorporate_row { - self.incorporate_row_record_to_row( - incorporate_row, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - record, - proof_index, - log_height, - ); - opened_index = incorporate_row.final_opened_index + 1; - used_cells += width; - } - if let Some(incorporate_sibling) = &top_level.incorporate_sibling { - self.incorporate_sibling_record_to_row( - incorporate_sibling, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - record, - proof_index, - opened_index, - log_height, - ); - used_cells += width; - } - } - self.correct_last_top_level_row( - record, - aux_cols_factory, - &mut slice[used_cells - width..used_cells], - memory, - ); - - for top_level in record.top_level.iter() { - if let Some(incorporate_row) = &top_level.incorporate_row { - for (i, chunk) in incorporate_row.chunks.iter().enumerate() { - self.inside_row_record_to_row( - chunk, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - incorporate_row, - record, - i == incorporate_row.chunks.len() - 1, - ); - used_cells += width; - } - } - } - - used_cells - } - fn simple_record_to_row( - &self, - record: &SimplePoseidonRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) { - let &SimplePoseidonRecord { - from_state, - instruction: - Instruction { - opcode, - a: output_register, - b: input_register_1, - c: input_register_2, - .. - }, - read_input_pointer_1, - read_input_pointer_2, - read_output_pointer, - read_data_1, - read_data_2, - write_data_1, - write_data_2, - input_pointer_1, - input_pointer_2, - output_pointer, - p2_input, - } = record; - - let read_input_pointer_1 = memory.record_by_id(read_input_pointer_1); - let read_output_pointer = memory.record_by_id(read_output_pointer); - let read_data_1 = memory.record_by_id(read_data_1); - let read_data_2 = memory.record_by_id(read_data_2); - let write_data_1 = memory.record_by_id(write_data_1); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ZERO; - cols.simple = F::ONE; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.is_exhausted = [F::ZERO; CHUNK - 1]; - - cols.start_timestamp = F::from_canonical_u32(from_state.timestamp); - let specific: &mut SimplePoseidonSpecificCols = - cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); - - specific.pc = F::from_canonical_u32(from_state.pc); - specific.is_compress = F::from_bool(opcode == COMP_POS2.global_opcode()); - specific.output_register = output_register; - specific.input_register_1 = input_register_1; - specific.input_register_2 = input_register_2; - specific.output_pointer = output_pointer; - specific.input_pointer_1 = input_pointer_1; - specific.input_pointer_2 = input_pointer_2; - aux_cols_factory.generate_read_aux(read_output_pointer, &mut specific.read_output_pointer); - aux_cols_factory - .generate_read_aux(read_input_pointer_1, &mut specific.read_input_pointer_1); - aux_cols_factory.generate_read_aux(read_data_1, &mut specific.read_data_1); - aux_cols_factory.generate_read_aux(read_data_2, &mut specific.read_data_2); - aux_cols_factory.generate_write_aux(write_data_1, &mut specific.write_data_1); - - if opcode == COMP_POS2.global_opcode() { - let read_input_pointer_2 = memory.record_by_id(read_input_pointer_2.unwrap()); - aux_cols_factory - .generate_read_aux(read_input_pointer_2, &mut specific.read_input_pointer_2); - } else { - let write_data_2 = memory.record_by_id(write_data_2.unwrap()); - aux_cols_factory.generate_write_aux(write_data_2, &mut specific.write_data_2); - } - } - - fn generate_trace(self) -> RowMajorMatrix { - let width = self.trace_width(); - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = F::zero_vec(width * height); - - let memory = self.offline_memory.lock().unwrap(); - - let aux_cols_factory = memory.aux_cols_factory(); - - let mut used_cells = 0; - for record in self.record_set.verify_batch_records.iter() { - used_cells += self.verify_batch_record_to_rows( - record, - &aux_cols_factory, - &mut flat_trace[used_cells..], - &memory, - ); - } - for record in self.record_set.simple_permute_records.iter() { - self.simple_record_to_row( - record, - &aux_cols_factory, - &mut flat_trace[used_cells..used_cells + width], - &memory, - ); - used_cells += width; - } - // poseidon2 constraints are always checked - // following can be optimized to only hash [0; _] once - flat_trace[used_cells..] - .par_chunks_mut(width) - .for_each(|row| { - self.generate_subair_cols([F::ZERO; 2 * CHUNK], row); - }); - - RowMajorMatrix::new(flat_trace, width) - } -} - -impl Chip - for NativePoseidon2Chip, SBOX_REGISTERS> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - fn generate_air_proof_input(self) -> AirProofInput { - AirProofInput::simple_no_pis(self.generate_trace()) - } -} diff --git a/extensions/native/circuit/src/utils.rs b/extensions/native/circuit/src/utils.rs index 2815427336..99fb5fa54f 100644 --- a/extensions/native/circuit/src/utils.rs +++ b/extensions/native/circuit/src/utils.rs @@ -1,19 +1,130 @@ -use openvm_circuit::arch::{Streams, SystemConfig, VmExecutor}; +use openvm_circuit::arch::{Streams, SystemConfig, VirtualMachine}; use openvm_instructions::program::Program; -use openvm_stark_sdk::p3_baby_bear::BabyBear; +use openvm_stark_sdk::{config::baby_bear_poseidon2::default_engine, p3_baby_bear::BabyBear}; use crate::{Native, NativeConfig}; +pub(crate) const CASTF_MAX_BITS: usize = 30; + +pub fn execute_program_with_system_config( + program: Program, + input_stream: impl Into>, + system_config: SystemConfig, +) { + let config = NativeConfig::new(system_config, Native); + let input = input_stream.into(); + + let vm = VirtualMachine::new(default_engine(), config); + let pk = vm.keygen(); + let vk = pk.get_vk(); + let segments = vm + .executor + .execute_metered(program.clone(), input.clone(), &vk.num_interactions()) + .unwrap(); + vm.execute(program, input, &segments).unwrap(); +} + pub fn execute_program(program: Program, input_stream: impl Into>) { let system_config = SystemConfig::default() .with_public_values(4) .with_max_segment_len((1 << 25) - 100); - let config = NativeConfig::new(system_config, Native); - let executor = VmExecutor::::new(config); - - executor.execute(program, input_stream).unwrap(); + execute_program_with_system_config(program, input_stream, system_config); } pub(crate) const fn const_max(a: usize, b: usize) -> usize { [a, b][(a < b) as usize] } + +/// Testing framework +#[cfg(any(test, feature = "test-utils"))] +pub mod test_utils { + use std::array; + + use openvm_circuit::{ + arch::{ + testing::{memory::gen_pointer, VmChipTestBuilder}, + Streams, + }, + utils::test_system_config, + }; + use openvm_instructions::{ + program::Program, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + }; + use openvm_native_compiler::conversion::AS; + use openvm_stark_backend::p3_field::PrimeField32; + use openvm_stark_sdk::p3_baby_bear::BabyBear; + use rand::{distributions::Standard, prelude::Distribution, rngs::StdRng, Rng}; + + use super::execute_program_with_system_config; + use crate::{extension::NativeConfig, Rv32WithKernelsConfig}; + + // If immediate, returns (value, AS::Immediate). Otherwise, writes to native memory and returns + // (ptr, AS::Native). If is_imm is None, randomizes it. + pub fn write_native_or_imm( + tester: &mut VmChipTestBuilder, + rng: &mut StdRng, + value: F, + is_imm: Option, + ) -> (F, usize) { + let is_imm = is_imm.unwrap_or(rng.gen_bool(0.5)); + if is_imm { + (value, AS::Immediate as usize) + } else { + let ptr = gen_pointer(rng, 1); + tester.write::<1>(AS::Native as usize, ptr, [value]); + (F::from_canonical_usize(ptr), AS::Native as usize) + } + } + + // Writes value to native memory and returns a pointer to the first element together with the + // value If `value` is None, randomizes it. + pub fn write_native_array( + tester: &mut VmChipTestBuilder, + rng: &mut StdRng, + value: Option<[F; N]>, + ) -> ([F; N], usize) + where + Standard: Distribution, // Needed for `rng.gen` + { + let value = value.unwrap_or(array::from_fn(|_| rng.gen())); + let ptr = gen_pointer(rng, N); + tester.write::(AS::Native as usize, ptr, value); + (value, ptr) + } + + pub fn test_execute_program( + program: Program, + input_stream: impl Into>, + ) { + let system_config = test_native_config() + .system + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + execute_program_with_system_config(program, input_stream, system_config); + } + + pub fn test_native_config() -> NativeConfig { + let mut system = test_system_config(); + system.memory_config.addr_space_sizes[RV32_REGISTER_AS as usize] = 0; + system.memory_config.addr_space_sizes[RV32_MEMORY_AS as usize] = 0; + NativeConfig { + system, + native: Default::default(), + } + } + + pub fn test_native_continuations_config() -> NativeConfig { + NativeConfig { + system: test_system_config().with_continuations(), + native: Default::default(), + } + } + + pub fn test_rv32_with_kernels_config() -> Rv32WithKernelsConfig { + Rv32WithKernelsConfig { + system: test_system_config().with_continuations(), + ..Default::default() + } + } +} diff --git a/extensions/native/compiler/Cargo.toml b/extensions/native/compiler/Cargo.toml index cb41c17f63..3b020d76fb 100644 --- a/extensions/native/compiler/Cargo.toml +++ b/extensions/native/compiler/Cargo.toml @@ -34,7 +34,7 @@ strum = { workspace = true } [dev-dependencies] p3-symmetric = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-native-circuit = { workspace = true } +openvm-native-circuit = { workspace = true, features = ["test-utils"]} openvm-stark-sdk = { workspace = true } rand.workspace = true diff --git a/extensions/native/compiler/tests/arithmetic.rs b/extensions/native/compiler/tests/arithmetic.rs index cd68fab563..3dc4c7c7fd 100644 --- a/extensions/native/compiler/tests/arithmetic.rs +++ b/extensions/native/compiler/tests/arithmetic.rs @@ -1,5 +1,5 @@ -use openvm_circuit::arch::{ExecutionError, VmExecutor}; -use openvm_native_circuit::{execute_program, NativeConfig}; +use openvm_circuit::arch::{ExecutionError, VirtualMachine}; +use openvm_native_circuit::{test_execute_program, NativeConfig}; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler, AsmConfig}, conversion::{convert_program, CompilerOptions}, @@ -8,7 +8,7 @@ use openvm_native_compiler::{ use openvm_stark_backend::p3_field::{ extension::BinomialExtensionField, Field, FieldAlgebra, FieldExtensionAlgebra, }; -use openvm_stark_sdk::p3_baby_bear::BabyBear; +use openvm_stark_sdk::{config::baby_bear_poseidon2::default_engine, p3_baby_bear::BabyBear}; use rand::{thread_rng, Rng}; const WORD_SIZE: usize = 1; @@ -93,7 +93,7 @@ fn test_compiler_arithmetic() { builder.halt(); let program = builder.clone().compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -116,7 +116,7 @@ fn test_compiler_arithmetic_2() { builder.halt(); let program = builder.clone().compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -152,7 +152,7 @@ fn test_in_place_arithmetic() { builder.halt(); let program = builder.clone().compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -177,7 +177,7 @@ fn test_field_immediate() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -249,10 +249,10 @@ fn test_ext_immediate() { builder.halt(); let program = builder.clone().compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -302,10 +302,10 @@ fn test_ext_felt_arithmetic() { builder.halt(); let program = builder.clone().compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -332,7 +332,7 @@ fn test_felt_equality() { println!("{}", asm_code); let program = convert_program::(asm_code, CompilerOptions::default()); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -369,7 +369,7 @@ fn test_ext_equality() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -392,7 +392,12 @@ fn assert_failed_assertion( ) { let program = builder.compile_isa(); - let executor = VmExecutor::::new(NativeConfig::aggregation(4, 3)); - let result = executor.execute(program, vec![]); + let config = NativeConfig::aggregation(4, 3); + let vm = VirtualMachine::new(default_engine(), config); + + let vm_pk = vm.keygen(); + let vm_vk = vm_pk.get_vk(); + + let result = vm.execute_metered(program, vec![], &vm_vk.num_interactions()); assert!(matches!(result, Err(ExecutionError::Fail { .. }))); } diff --git a/extensions/native/compiler/tests/array.rs b/extensions/native/compiler/tests/array.rs index 9ef5eca8ea..cd26b3ce61 100644 --- a/extensions/native/compiler/tests/array.rs +++ b/extensions/native/compiler/tests/array.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmConfig}, ir::{Array, Config, Ext, Felt, RVar, Usize, Var}, @@ -104,7 +104,7 @@ fn test_array_eq() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[should_panic] @@ -125,7 +125,7 @@ fn test_array_eq_neg() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -161,7 +161,7 @@ fn test_slice_variable_impl_happy_path() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -183,5 +183,5 @@ fn test_slice_assert_eq_neg() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/compiler/tests/conditionals.rs b/extensions/native/compiler/tests/conditionals.rs index 29fa85386a..b6ab8cf8ad 100644 --- a/extensions/native/compiler/tests/conditionals.rs +++ b/extensions/native/compiler/tests/conditionals.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{asm::AsmBuilder, ir::Var}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -50,7 +50,7 @@ fn test_compiler_conditionals() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -79,7 +79,7 @@ fn test_compiler_conditionals_v2() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] diff --git a/extensions/native/compiler/tests/cycle_tracker.rs b/extensions/native/compiler/tests/cycle_tracker.rs index 3561dfd2ec..e2a00feab3 100644 --- a/extensions/native/compiler/tests/cycle_tracker.rs +++ b/extensions/native/compiler/tests/cycle_tracker.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions, ir::Var}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -48,5 +48,5 @@ fn test_cycle_tracker() { } println!("{}", program); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/compiler/tests/ext.rs b/extensions/native/compiler/tests/ext.rs index 5da70cb53b..70494bb6bd 100644 --- a/extensions/native/compiler/tests/ext.rs +++ b/extensions/native/compiler/tests/ext.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::AsmBuilder, ir::{Ext, Felt}, @@ -31,7 +31,7 @@ fn test_ext2felt() { let program = builder.compile_isa(); println!("{}", program); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -60,5 +60,5 @@ fn test_ext_from_base_slice() { let program = builder.compile_isa(); println!("{}", program); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/compiler/tests/for_loops.rs b/extensions/native/compiler/tests/for_loops.rs index 123a416cdb..709105ee32 100644 --- a/extensions/native/compiler/tests/for_loops.rs +++ b/extensions/native/compiler/tests/for_loops.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmConfig}, ir::{Array, Var}, @@ -46,7 +46,7 @@ fn test_compiler_for_loops() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -83,7 +83,7 @@ fn test_compiler_zip_fixed() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -125,7 +125,7 @@ fn test_compiler_zip_dyn() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -162,7 +162,7 @@ fn test_compiler_nested_array_loop() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -182,5 +182,5 @@ fn test_compiler_bneinc() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/compiler/tests/fri_ro_eval.rs b/extensions/native/compiler/tests/fri_ro_eval.rs index 6f332d22b6..dcba950a08 100644 --- a/extensions/native/compiler/tests/fri_ro_eval.rs +++ b/extensions/native/compiler/tests/fri_ro_eval.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, conversion::{convert_program, CompilerOptions}, @@ -89,5 +89,5 @@ fn test_single_reduced_opening_eval() { let asm_code = compiler.code(); let program = convert_program::(asm_code, CompilerOptions::default()); - execute_program(program, vec![mat_opening]); + test_execute_program(program, vec![mat_opening]); } diff --git a/extensions/native/compiler/tests/hint.rs b/extensions/native/compiler/tests/hint.rs index 05aff8a390..5ac1722577 100644 --- a/extensions/native/compiler/tests/hint.rs +++ b/extensions/native/compiler/tests/hint.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{asm::AsmBuilder, ir::Felt}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, Field, FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -29,5 +29,5 @@ fn test_hint_bits_felt() { let program = builder.compile_isa(); println!("{}", program); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/compiler/tests/io.rs b/extensions/native/compiler/tests/io.rs index 58b2c46b8f..ab16bc276a 100644 --- a/extensions/native/compiler/tests/io.rs +++ b/extensions/native/compiler/tests/io.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, conversion::{convert_program, CompilerOptions}, @@ -61,5 +61,5 @@ fn test_io() { println!("{}", asm_code); let program = convert_program::(asm_code, CompilerOptions::default()); - execute_program(program, witness_stream); + test_execute_program(program, witness_stream); } diff --git a/extensions/native/compiler/tests/poseidon2.rs b/extensions/native/compiler/tests/poseidon2.rs index de7dbc6e9b..b3dee9f4c1 100644 --- a/extensions/native/compiler/tests/poseidon2.rs +++ b/extensions/native/compiler/tests/poseidon2.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::AsmBuilder, ir::{Array, Var, PERMUTATION_WIDTH}, @@ -49,7 +49,7 @@ fn test_compiler_poseidon2_permute() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -80,5 +80,5 @@ fn test_compiler_poseidon2_hash_1() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/compiler/tests/public_values.rs b/extensions/native/compiler/tests/public_values.rs index 7c7abe3bc6..ba98a030d9 100644 --- a/extensions/native/compiler/tests/public_values.rs +++ b/extensions/native/compiler/tests/public_values.rs @@ -1,8 +1,8 @@ -use openvm_circuit::arch::{SingleSegmentVmExecutor, SystemConfig}; -use openvm_native_circuit::{execute_program, Native, NativeConfig}; +use openvm_circuit::arch::{SingleSegmentVmExecutor, VirtualMachine}; +use openvm_native_circuit::{test_execute_program, test_native_config}; use openvm_native_compiler::{asm::AsmBuilder, prelude::*}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; -use openvm_stark_sdk::p3_baby_bear::BabyBear; +use openvm_stark_sdk::{config::baby_bear_poseidon2::default_engine, p3_baby_bear::BabyBear}; type F = BabyBear; type EF = BinomialExtensionField; @@ -28,13 +28,20 @@ fn test_compiler_public_values() { } let program = builder.compile_isa(); - let executor = SingleSegmentVmExecutor::new(NativeConfig::new( - SystemConfig::default().with_public_values(2), - Native, - )); + let config = test_native_config(); + + let vm = VirtualMachine::new(default_engine(), config.clone()); + let vm_pk = vm.keygen(); + let vm_vk = vm_pk.get_vk(); + + let executor = SingleSegmentVmExecutor::new(config); + + let max_trace_heights = executor + .execute_metered(program.clone().into(), vec![], &vm_vk.num_interactions()) + .unwrap(); let exe_result = executor - .execute_and_compute_heights(program, vec![]) + .execute_and_compute_heights(program, vec![], &max_trace_heights) .unwrap(); assert_eq!( exe_result @@ -66,5 +73,5 @@ fn test_compiler_public_values_no_initial() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/compiler/tests/range_check.rs b/extensions/native/compiler/tests/range_check.rs index 959f2bae9f..f1e7d9948f 100644 --- a/extensions/native/compiler/tests/range_check.rs +++ b/extensions/native/compiler/tests/range_check.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{asm::AsmBuilder, prelude::*}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -24,7 +24,7 @@ fn test_range_check_v() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -38,5 +38,5 @@ fn test_range_check_v_neg() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/recursion/Cargo.toml b/extensions/native/recursion/Cargo.toml index c799671a55..dd20f833d3 100644 --- a/extensions/native/recursion/Cargo.toml +++ b/extensions/native/recursion/Cargo.toml @@ -8,7 +8,7 @@ repository.workspace = true [dependencies] openvm-stark-backend = { workspace = true } -openvm-native-circuit = { workspace = true } +openvm-native-circuit = { workspace = true, features = ["test-utils"] } openvm-native-compiler = { workspace = true } openvm-native-compiler-derive = { workspace = true } openvm-stark-sdk = { workspace = true } diff --git a/extensions/native/recursion/src/challenger/duplex.rs b/extensions/native/recursion/src/challenger/duplex.rs index 7c0cd4dd88..309a95f766 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -189,7 +189,7 @@ impl ChallengerVariable for DuplexChallengerVariable { #[cfg(test)] mod tests { - use openvm_native_circuit::execute_program; + use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmConfig}, ir::Felt, @@ -241,7 +241,7 @@ mod tests { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] diff --git a/extensions/native/recursion/src/fri/domain.rs b/extensions/native/recursion/src/fri/domain.rs index cdc8fc242c..8b6aceefb7 100644 --- a/extensions/native/recursion/src/fri/domain.rs +++ b/extensions/native/recursion/src/fri/domain.rs @@ -156,7 +156,7 @@ where #[cfg(test)] pub(crate) mod tests { - use openvm_native_circuit::execute_program; + use openvm_native_circuit::test_execute_program; use openvm_native_compiler::asm::AsmBuilder; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig, Val}, @@ -276,7 +276,7 @@ pub(crate) mod tests { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] fn test_domain_static() { diff --git a/extensions/native/recursion/src/fri/two_adic_pcs.rs b/extensions/native/recursion/src/fri/two_adic_pcs.rs index 676da7493f..e8e1c296be 100644 --- a/extensions/native/recursion/src/fri/two_adic_pcs.rs +++ b/extensions/native/recursion/src/fri/two_adic_pcs.rs @@ -745,6 +745,6 @@ pub mod tests { #[test] fn test_two_adic_fri_pcs_single_batch() { let (program, witness) = build_test_fri_with_cols_and_log2_rows(10, 10); - openvm_native_circuit::execute_program(program, witness); + openvm_native_circuit::test_execute_program(program, witness); } } diff --git a/extensions/native/recursion/src/hints.rs b/extensions/native/recursion/src/hints.rs index b65f6ba647..642d3d5379 100644 --- a/extensions/native/recursion/src/hints.rs +++ b/extensions/native/recursion/src/hints.rs @@ -446,7 +446,7 @@ impl Hintable for Commitments> { #[cfg(test)] mod test { - use openvm_native_circuit::execute_program; + use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::AsmBuilder, ir::{Ext, Felt, Var}, @@ -480,7 +480,7 @@ mod test { builder.halt(); let program = builder.compile_isa(); - execute_program(program, stream); + test_execute_program(program, stream); } #[test] @@ -527,6 +527,6 @@ mod test { builder.halt(); let program = builder.compile_isa(); - execute_program(program, stream); + test_execute_program(program, stream); } } diff --git a/extensions/native/recursion/src/testing_utils.rs b/extensions/native/recursion/src/testing_utils.rs index 380b2aa9a3..9e8d725b35 100644 --- a/extensions/native/recursion/src/testing_utils.rs +++ b/extensions/native/recursion/src/testing_utils.rs @@ -1,6 +1,11 @@ use inner::build_verification_program; -use openvm_circuit::{arch::instructions::program::Program, utils::execute_and_prove_program}; -use openvm_native_circuit::NativeConfig; +use openvm_circuit::{ + arch::{ + execution_mode::e1::E1Ctx, instructions::program::Program, interpreter::InterpretedInstance, + }, + utils::execute_and_prove_program, +}; +use openvm_native_circuit::{test_native_config, NativeConfig}; use openvm_native_compiler::conversion::CompilerOptions; use openvm_stark_backend::{ config::{Com, Domain, PcsProof, PcsProverData, StarkGenericConfig}, @@ -20,7 +25,6 @@ use crate::hints::InnerVal; type InnerSC = BabyBearPoseidon2Config; pub mod inner { - use openvm_native_circuit::NativeConfig; use openvm_native_compiler::conversion::CompilerOptions; use openvm_stark_sdk::{ config::{ @@ -75,7 +79,7 @@ pub mod inner { recursive_stark_test( vparams, CompilerOptions::default(), - NativeConfig::aggregation(4, 7), + test_native_config(), &BabyBearPoseidon2Engine::new(fri_params), ) .unwrap(); @@ -103,5 +107,9 @@ where { let (program, witness_stream) = build_verification_program(vparams, compiler_options); + let interpreter = InterpretedInstance::new(vm_config.clone(), program.clone()); + interpreter + .execute(E1Ctx::new(None), witness_stream.clone()) + .unwrap(); execute_and_prove_program(program, witness_stream, vm_config, engine) } diff --git a/extensions/native/recursion/src/tests.rs b/extensions/native/recursion/src/tests.rs index 4077ee6f1d..ce94bd0201 100644 --- a/extensions/native/recursion/src/tests.rs +++ b/extensions/native/recursion/src/tests.rs @@ -1,7 +1,7 @@ use std::{panic::catch_unwind, sync::Arc}; use openvm_circuit::utils::gen_vm_program_test_proof_input; -use openvm_native_circuit::NativeConfig; +use openvm_native_circuit::{test_native_config, NativeConfig}; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, interaction::BusIndex, @@ -148,7 +148,7 @@ fn test_optional_air() { let pk = keygen_builder.generate_pk(); let m_advice = new_from_inner_multi_vk(&pk.get_vk()); - let vm_config = NativeConfig::aggregation(4, 7); + let vm_config = test_native_config(); let program = VerifierProgram::build(m_advice, &fri_params); // Case 1: All AIRs are present. @@ -184,11 +184,11 @@ fn test_optional_air() { .verify(&pk.get_vk(), &proof) .expect("Verification failed"); // The VM program will panic when the program cannot verify the proof. - gen_vm_program_test_proof_input::( - program.clone(), - proof.write(), - vm_config.clone(), - ); + gen_vm_program_test_proof_input::< + BabyBearPoseidon2Config, + NativeConfig, + BabyBearPoseidon2Engine, + >(program.clone(), proof.write(), vm_config.clone()); } // Case 2: The second AIR is not presented. { @@ -215,11 +215,11 @@ fn test_optional_air() { .verify(&pk.get_vk(), &proof) .expect("Verification failed"); // The VM program will panic when the program cannot verify the proof. - gen_vm_program_test_proof_input::( - program.clone(), - proof.write(), - vm_config.clone(), - ); + gen_vm_program_test_proof_input::< + BabyBearPoseidon2Config, + NativeConfig, + BabyBearPoseidon2Engine, + >(program.clone(), proof.write(), vm_config.clone()); } // Case 3: Negative - unbalanced interactions. { @@ -238,11 +238,11 @@ fn test_optional_air() { assert!(engine.verify(&pk.get_vk(), &proof).is_err()); // The VM program should panic when the proof cannot be verified. let unwind_res = catch_unwind(|| { - gen_vm_program_test_proof_input::( - program.clone(), - proof.write(), - vm_config, - ) + gen_vm_program_test_proof_input::< + BabyBearPoseidon2Config, + NativeConfig, + BabyBearPoseidon2Engine, + >(program.clone(), proof.write(), vm_config) }); assert!(unwind_res.is_err()); } diff --git a/extensions/native/recursion/tests/recursion.rs b/extensions/native/recursion/tests/recursion.rs index 8f354f3316..4ef33a5cc1 100644 --- a/extensions/native/recursion/tests/recursion.rs +++ b/extensions/native/recursion/tests/recursion.rs @@ -1,4 +1,6 @@ -use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmConfig, VmExecutor}; +use openvm_circuit::arch::{ + instructions::program::Program, SystemConfig, VirtualMachine, VmConfig, +}; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::{asm::AsmBuilder, ir::Felt}; use openvm_native_recursion::testing_utils::inner::run_recursive_test; @@ -7,7 +9,11 @@ use openvm_stark_backend::{ p3_commit::PolynomialSpace, p3_field::{extension::BinomialExtensionField, FieldAlgebra}, }; -use openvm_stark_sdk::{config::FriParameters, p3_baby_bear::BabyBear, utils::ProofInputForTest}; +use openvm_stark_sdk::{ + config::{baby_bear_poseidon2::default_engine, FriParameters}, + p3_baby_bear::BabyBear, + utils::ProofInputForTest, +}; fn fibonacci_program(a: u32, b: u32, n: u32) -> Program { type F = BabyBear; @@ -47,9 +53,18 @@ where let vm_config = NativeConfig::new(SystemConfig::default().with_public_values(3), Native); let airs = vm_config.create_chip_complex().unwrap().airs(); - let executor = VmExecutor::::new(vm_config); + let vm = VirtualMachine::new(default_engine(), vm_config); + let pk = vm.keygen(); + let vk = pk.get_vk(); + let segments = vm + .executor + .execute_metered(fib_program.clone(), vec![], &vk.num_interactions()) + .unwrap(); - let mut result = executor.execute_and_generate(fib_program, vec![]).unwrap(); + let mut result = vm + .executor + .execute_and_generate(fib_program, vec![], &segments) + .unwrap(); assert_eq!(result.per_segment.len(), 1, "unexpected continuation"); let proof_input = result.per_segment.remove(0); // Filter out unused AIRS (where trace is empty) diff --git a/extensions/pairing/circuit/Cargo.toml b/extensions/pairing/circuit/Cargo.toml index af16f7eeab..ebc1dc6c75 100644 --- a/extensions/pairing/circuit/Cargo.toml +++ b/extensions/pairing/circuit/Cargo.toml @@ -23,7 +23,6 @@ openvm-mod-circuit-builder = { workspace = true } openvm-stark-backend = { workspace = true } openvm-rv32im-circuit = { workspace = true } openvm-algebra-circuit = { workspace = true } -openvm-rv32-adapters = { workspace = true } openvm-ecc-circuit = { workspace = true } openvm-pairing-transpiler = { workspace = true } @@ -33,7 +32,6 @@ strum = { workspace = true } derive_more = { workspace = true } derive-new = { workspace = true } rand = { workspace = true } -itertools = { workspace = true } eyre = { workspace = true } serde = { workspace = true, features = ["derive", "std"] } halo2curves-axiom = { workspace = true } @@ -45,7 +43,6 @@ openvm-pairing-guest = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-mod-circuit-builder = { workspace = true, features = ["test-utils"] } openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } halo2curves-axiom = { workspace = true } openvm-ecc-guest = { workspace = true } openvm-pairing-guest = { workspace = true, features = [ diff --git a/extensions/pairing/circuit/src/fp12_chip/add.rs b/extensions/pairing/circuit/src/fp12_chip/add.rs deleted file mode 100644 index 643c68ef27..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/add.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::{cell::RefCell, rc::Rc}; - -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; - -use crate::Fp12; - -pub fn fp12_add_expr(config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.add(&mut y); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/fp12_chip/mod.rs b/extensions/pairing/circuit/src/fp12_chip/mod.rs deleted file mode 100644 index c6894d0d27..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -mod add; -mod mul; -mod sub; - -pub use add::*; -pub use mul::*; -pub use sub::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/fp12_chip/mul.rs b/extensions/pairing/circuit/src/fp12_chip/mul.rs deleted file mode 100644 index 0736981de7..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/mul.rs +++ /dev/null @@ -1,175 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::Fp12Opcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; -// Input: Fp12 * 2 -// Output: Fp12 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp12MulChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp12MulChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = fp12_mul_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![Fp12Opcode::MUL as usize], - vec![], - range_checker, - "Fp12Mul", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn fp12_mul_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.mul(&mut y, xi); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use halo2curves_axiom::{bn256::Fq12, ff::Field}; - use itertools::Itertools; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::algebra::field::FieldExtension; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::{ - test_utils::{biguint_to_limbs, bn254_fq12_to_biguint_vec, bn254_fq2_to_biguint_vec}, - ExprBuilderConfig, - }; - use openvm_pairing_guest::bn254::{BN254_MODULUS, BN254_XI_ISIZE}; - use openvm_rv32_adapters::rv32_write_heap_default_with_increment; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - const LIMB_BITS: usize = 8; - type F = BabyBear; - - #[test] - fn test_fp12_mul_bn254() { - const NUM_LIMBS: usize = 32; - const BLOCK_SIZE: usize = 32; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - - let mut chip = Fp12MulChip::new( - adapter, - config, - BN254_XI_ISIZE, - Fp12Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(64); - let x = Fq12::random(&mut rng); - let y = Fq12::random(&mut rng); - let inputs = [x.to_coeffs(), y.to_coeffs()] - .concat() - .iter() - .flat_map(|&x| bn254_fq2_to_biguint_vec(x)) - .collect::>(); - - let cmp = bn254_fq12_to_biguint_vec(x * y); - let res = chip - .0 - .core - .expr() - .execute_with_output(inputs.clone(), vec![true]); - assert_eq!(res.len(), cmp.len()); - for i in 0..res.len() { - assert_eq!(res[i], cmp[i]); - } - - let x_limbs = inputs[..12] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect_vec(); - let y_limbs = inputs[12..] - .iter() - .map(|y| { - biguint_to_limbs::(y.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect_vec(); - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - x_limbs, - y_limbs, - 512, - chip.0.core.air.offset + Fp12Opcode::MUL as usize, - ); - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/fp12_chip/sub.rs b/extensions/pairing/circuit/src/fp12_chip/sub.rs deleted file mode 100644 index 470e700910..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/sub.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::{cell::RefCell, rc::Rc}; - -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; - -use crate::Fp12; - -pub fn fp12_sub_expr(config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.sub(&mut y); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/fp12_chip/tests.rs b/extensions/pairing/circuit/src/fp12_chip/tests.rs deleted file mode 100644 index a9f6b235d5..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/tests.rs +++ /dev/null @@ -1,271 +0,0 @@ -use num_bigint::BigUint; -use openvm_circuit::arch::{ - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmChipWrapper, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{ - test_utils::{ - biguint_to_limbs, bls12381_fq12_random, bn254_fq12_random, bn254_fq12_to_biguint_vec, - }, - ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_guest::{ - bls12_381::{ - BLS12_381_BLOCK_SIZE, BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS, - BLS12_381_XI_ISIZE, - }, - bn254::{BN254_BLOCK_SIZE, BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS, BN254_XI_ISIZE}, -}; -use openvm_pairing_transpiler::{Bls12381Fp12Opcode, Bn254Fp12Opcode, Fp12Opcode}; -use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; - -use super::{fp12_add_expr, fp12_mul_expr, fp12_sub_expr}; - -type F = BabyBear; - -#[allow(clippy::too_many_arguments)] -fn test_fp12_fn< - const INPUT_SIZE: usize, - const NUM_LIMBS: usize, - const LIMB_BITS: usize, - const BLOCK_SIZE: usize, ->( - mut tester: VmChipTestBuilder, - expr: FieldExpr, - offset: usize, - local_opcode_idx: usize, - name: &str, - x: Vec, - y: Vec, - var_len: usize, -) { - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![local_opcode_idx], - vec![], - tester.memory_controller().borrow().range_checker.clone(), - name, - false, - ); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let adapter = - Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - - let x_limbs = x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let y_limbs = y - .iter() - .map(|y| { - biguint_to_limbs::(y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let mut chip = VmChipWrapper::new(adapter, core, tester.offline_memory_mutex_arc()); - - let res = chip.core.air.expr.execute([x, y].concat(), vec![]); - assert_eq!(res.len(), var_len); - - let instruction = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.core.air.offset + local_opcode_idx, - ); - tester.execute(&mut chip, &instruction); - - let run_tester = tester.build().load(chip).load(bitwise_chip).finalize(); - run_tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_fp12_add_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let expr = fp12_add_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(1)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(2)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::ADD as usize, - "Bn254Fp12Add", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_sub_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let expr = fp12_sub_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(59)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(3)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::SUB as usize, - "Bn254Fp12Sub", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_mul_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let xi = BN254_XI_ISIZE; - let expr = fp12_mul_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - xi, - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(5)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(25)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::MUL as usize, - "Bn254Fp12Mul", - x, - y, - 33, - ); -} - -#[test] -fn test_fp12_add_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let expr = fp12_add_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bls12381_fq12_random(3); - let y = bls12381_fq12_random(99); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::ADD as usize, - "Bls12381Fp12Add", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_sub_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let expr = fp12_sub_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bls12381_fq12_random(8); - let y = bls12381_fq12_random(9); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::SUB as usize, - "Bls12381Fp12Sub", - x, - y, - 12, - ); -} - -// NOTE[yj]: This test requires RUST_MIN_STACK=8388608 to run without overflowing the stack, so it -// is ignored by the test runner for now -#[test] -#[ignore] -fn test_fp12_mul_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let xi = BLS12_381_XI_ISIZE; - let expr = fp12_mul_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - xi, - ); - - let x = bls12381_fq12_random(5); - let y = bls12381_fq12_random(25); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::MUL as usize, - "Bls12381Fp12Mul", - x, - y, - 46, - ); -} diff --git a/extensions/pairing/circuit/src/lib.rs b/extensions/pairing/circuit/src/lib.rs index b2b962b7f7..f96d126555 100644 --- a/extensions/pairing/circuit/src/lib.rs +++ b/extensions/pairing/circuit/src/lib.rs @@ -1,11 +1,7 @@ mod config; mod fp12; -mod fp12_chip; -mod pairing_chip; mod pairing_extension; pub use config::*; pub use fp12::*; -pub use fp12_chip::*; -pub use pairing_chip::*; pub use pairing_extension::*; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs deleted file mode 100644 index 08857995f3..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod mul_013_by_013; -mod mul_by_01234; - -pub use mul_013_by_013::*; -pub use mul_by_01234::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs deleted file mode 100644 index 36d1012e9b..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements -// Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMul013By013Chip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMul013By013Chip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_013_by_013_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_013_BY_013 as usize], - vec![], - range_checker, - "Mul013By013", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_013_by_013_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut b0 = Fp2::new(builder.clone()); - let mut c0 = Fp2::new(builder.clone()); - let mut b1 = Fp2::new(builder.clone()); - let mut c1 = Fp2::new(builder.clone()); - - // where w⁶ = xi - // l0 * l1 = 1 + (b0 + b1)w + (b0b1)w² + (c0 + c1)w³ + (b0c1 + b1c0)w⁴ + (c0c1)w⁶ - // = (1 + c0c1 * xi) + (b0 + b1)w + (b0b1)w² + (c0 + c1)w³ + (b0c1 + b1c0)w⁴ - let l0 = c0.mul(&mut c1).int_mul(xi).int_add([1, 0]); - let l1 = b0.add(&mut b1); - let l2 = b0.mul(&mut b1); - let l3 = c0.add(&mut c1); - let l4 = b0.mul(&mut c1).add(&mut b1.mul(&mut c0)); - - [l0, l1, l2, l3, l4].map(|mut l| l.save_output()); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs deleted file mode 100644 index 996372e994..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; - -// Input: Fp12 (12 field elements), [Fp2; 5] (5 x 2 field elements) -// Output: Fp12 (12 field elements) -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMulBy01234Chip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMulBy01234Chip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_by_01234_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_BY_01234 as usize], - vec![], - range_checker.clone(), - "MulBy01234", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_by_01234_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut f = Fp12::new(builder.clone()); - let mut x0 = Fp2::new(builder.clone()); - let mut x1 = Fp2::new(builder.clone()); - let mut x2 = Fp2::new(builder.clone()); - let mut x3 = Fp2::new(builder.clone()); - let mut x4 = Fp2::new(builder.clone()); - - let mut r = f.mul_by_01234(&mut x0, &mut x1, &mut x2, &mut x3, &mut x4, xi); - r.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs deleted file mode 100644 index 81da3169fa..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs +++ /dev/null @@ -1,287 +0,0 @@ -use halo2curves_axiom::{ - bn256::{Fq, Fq12, Fq2, G1Affine}, - ff::Field, -}; -use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_ecc_guest::AffinePoint; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{ - test_utils::{ - biguint_to_limbs, bn254_fq12_to_biguint_vec, bn254_fq2_to_biguint_vec, bn254_fq_to_biguint, - }, - ExprBuilderConfig, -}; -use openvm_pairing_guest::{ - bn254::{BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS, BN254_XI_ISIZE}, - halo2curves_shims::bn254::{tangent_line_013, Bn254}, - pairing::{Evaluatable, LineMulDType, UnevaluatedLine}, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::{ - rv32_write_heap_default, rv32_write_heap_default_with_increment, Rv32VecHeapAdapterChip, - Rv32VecHeapTwoReadsAdapterChip, -}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use rand::{rngs::StdRng, SeedableRng}; - -use super::{super::EvaluateLineChip, *}; - -type F = BabyBear; -const NUM_LIMBS: usize = 32; -const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 32; - -#[test] -fn test_mul_013_by_013() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMul013By013Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }, - BN254_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(8); - let mut rng1 = StdRng::seed_from_u64(95); - let rnd_pt_0 = G1Affine::random(&mut rng0); - let rnd_pt_1 = G1Affine::random(&mut rng1); - let ec_pt_0 = AffinePoint:: { - x: rnd_pt_0.x, - y: rnd_pt_0.y, - }; - let ec_pt_1 = AffinePoint:: { - x: rnd_pt_1.x, - y: rnd_pt_1.y, - }; - let line0 = tangent_line_013::(ec_pt_0); - let line1 = tangent_line_013::(ec_pt_1); - let input_line0 = [ - bn254_fq2_to_biguint_vec(line0.b), - bn254_fq2_to_biguint_vec(line0.c), - ] - .concat(); - let input_line1 = [ - bn254_fq2_to_biguint_vec(line1.b), - bn254_fq2_to_biguint_vec(line1.c), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_line0.clone(), input_line1.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 10); - - let r_cmp = Bn254::mul_013_by_013(&line0, &line1); - let r_cmp_bigint = r_cmp - .map(|x| [bn254_fq_to_biguint(x.c0), bn254_fq_to_biguint(x.c1)]) - .concat(); - - for i in 0..10 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_line0_limbs = input_line0 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_line1_limbs = input_line1 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default( - &mut tester, - input_line0_limbs, - input_line1_limbs, - chip.0.core.air.offset + PairingOpcode::MUL_013_BY_013 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_mul_by_01234() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMulBy01234Chip::new( - adapter, - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }, - BN254_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(8); - let f = Fq12::random(&mut rng); - let x0 = Fq2::random(&mut rng); - let x1 = Fq2::random(&mut rng); - let x2 = Fq2::random(&mut rng); - let x3 = Fq2::random(&mut rng); - let x4 = Fq2::random(&mut rng); - - let input_f = bn254_fq12_to_biguint_vec(f); - let input_x = [ - bn254_fq2_to_biguint_vec(x0), - bn254_fq2_to_biguint_vec(x1), - bn254_fq2_to_biguint_vec(x2), - bn254_fq2_to_biguint_vec(x3), - bn254_fq2_to_biguint_vec(x4), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_f.clone(), input_x.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 12); - - let r_cmp = Bn254::mul_by_01234(&f, &[x0, x1, x2, x3, x4]); - let r_cmp_bigint = bn254_fq12_to_biguint_vec(r_cmp); - - for i in 0..12 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_f_limbs = input_f - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_x_limbs = input_x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_f_limbs, - input_x_limbs, - 512, - chip.0.core.air.offset + PairingOpcode::MUL_BY_01234 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_evaluate_line() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: BN254_LIMB_BITS, - num_limbs: BN254_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EvaluateLineChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(42); - let uneval_b = Fq2::random(&mut rng); - let uneval_c = Fq2::random(&mut rng); - let x_over_y = Fq::random(&mut rng); - let y_inv = Fq::random(&mut rng); - let mut inputs = vec![]; - inputs.extend(bn254_fq2_to_biguint_vec(uneval_b)); - inputs.extend(bn254_fq2_to_biguint_vec(uneval_c)); - inputs.push(bn254_fq_to_biguint(x_over_y)); - inputs.push(bn254_fq_to_biguint(y_inv)); - let input_limbs = inputs - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect(); - - let uneval: UnevaluatedLine = UnevaluatedLine { - b: uneval_b, - c: uneval_c, - }; - let evaluated = uneval.evaluate(&(x_over_y, y_inv)); - - let result = chip.0.core.expr().execute_with_output(inputs, vec![]); - assert_eq!(result.len(), 4); - assert_eq!(result[0], bn254_fq_to_biguint(evaluated.b.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(evaluated.b.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(evaluated.c.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(evaluated.c.c1)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs, - vec![], - chip.0.core.air.offset + PairingOpcode::EVALUATE_LINE as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs b/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs deleted file mode 100644 index dc0a8cdfe1..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs +++ /dev/null @@ -1,102 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: UnevaluatedLine, (Fp, Fp) -// Output: EvaluatedLine -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EvaluateLineChip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EvaluateLineChip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = evaluate_line_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::EVALUATE_LINE as usize], - vec![], - range_checker, - "EvaluateLine", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn evaluate_line_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut uneval_b = Fp2::new(builder.clone()); - let mut uneval_c = Fp2::new(builder.clone()); - - let mut x_over_y = ExprBuilder::new_input(builder.clone()); - let mut y_inv = ExprBuilder::new_input(builder.clone()); - - let mut b = uneval_b.scalar_mul(&mut x_over_y); - let mut c = uneval_c.scalar_mul(&mut y_inv); - b.save_output(); - c.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs deleted file mode 100644 index b454d260ce..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod mul_023_by_023; -mod mul_by_02345; - -pub use mul_023_by_023::*; -pub use mul_by_02345::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs deleted file mode 100644 index 0d760b886e..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements -// Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMul023By023Chip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMul023By023Chip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_023_by_023_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_023_BY_023 as usize], - vec![], - range_checker, - "Mul023By023", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_023_by_023_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut b0 = Fp2::new(builder.clone()); // x2 - let mut c0 = Fp2::new(builder.clone()); // x3 - let mut b1 = Fp2::new(builder.clone()); // y2 - let mut c1 = Fp2::new(builder.clone()); // y3 - - // where w⁶ = xi - // l0 * l1 = c0c1 + (c0b1 + c1b0)w² + (c0 + c1)w³ + (b0b1)w⁴ + (b0 +b1)w⁵ + w⁶ - // = (c0c1 + xi) + (c0b1 + c1b0)w² + (c0 + c1)w³ + (b0b1)w⁴ + (b0 + b1)w⁵ - let l0 = c0.mul(&mut c1).int_add(xi); - let l2 = c0.mul(&mut b1).add(&mut c1.mul(&mut b0)); - let l3 = c0.add(&mut c1); - let l4 = b0.mul(&mut b1); - let l5 = b0.add(&mut b1); - - [l0, l2, l3, l4, l5].map(|mut l| l.save_output()); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs deleted file mode 100644 index ad0e91e7bd..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; - -// Input: 2 Fp12: 2 x 12 field elements -// Output: Fp12 -> 12 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMulBy02345Chip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMulBy02345Chip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_by_02345_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_BY_02345 as usize], - vec![], - range_checker, - "MulBy02345", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_by_02345_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut f = Fp12::new(builder.clone()); - let mut x0 = Fp2::new(builder.clone()); - let mut x2 = Fp2::new(builder.clone()); - let mut x3 = Fp2::new(builder.clone()); - let mut x4 = Fp2::new(builder.clone()); - let mut x5 = Fp2::new(builder.clone()); - - let mut r = f.mul_by_02345(&mut x0, &mut x2, &mut x3, &mut x4, &mut x5, xi); - r.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs deleted file mode 100644 index 4331d2278e..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs +++ /dev/null @@ -1,217 +0,0 @@ -use halo2curves_axiom::{ - bls12_381::{Fq, Fq12, Fq2, G1Affine}, - ff::Field, -}; -use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_ecc_guest::AffinePoint; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{test_utils::*, ExprBuilderConfig}; -use openvm_pairing_guest::{ - bls12_381::{BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS, BLS12_381_XI_ISIZE}, - halo2curves_shims::bls12_381::{tangent_line_023, Bls12_381}, - pairing::LineMulMType, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::{ - rv32_write_heap_default_with_increment, Rv32VecHeapAdapterChip, Rv32VecHeapTwoReadsAdapterChip, -}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use rand::{rngs::StdRng, SeedableRng}; - -use super::*; - -type F = BabyBear; -const NUM_LIMBS: usize = 48; -const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 16; - -#[test] -fn test_mul_023_by_023() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMul023By023Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }, - BLS12_381_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(15); - let mut rng1 = StdRng::seed_from_u64(95); - let rnd_pt_0 = G1Affine::random(&mut rng0); - let rnd_pt_1 = G1Affine::random(&mut rng1); - let ec_pt_0 = AffinePoint:: { - x: rnd_pt_0.x, - y: rnd_pt_0.y, - }; - let ec_pt_1 = AffinePoint:: { - x: rnd_pt_1.x, - y: rnd_pt_1.y, - }; - let line0 = tangent_line_023::(ec_pt_0); - let line1 = tangent_line_023::(ec_pt_1); - let input_line0 = [ - bls12381_fq2_to_biguint_vec(line0.b), - bls12381_fq2_to_biguint_vec(line0.c), - ] - .concat(); - let input_line1 = [ - bls12381_fq2_to_biguint_vec(line1.b), - bls12381_fq2_to_biguint_vec(line1.c), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_line0.clone(), input_line1.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 10); - - let r_cmp = Bls12_381::mul_023_by_023(&line0, &line1); - let r_cmp_bigint = r_cmp - .map(|x| [bls12381_fq_to_biguint(x.c0), bls12381_fq_to_biguint(x.c1)]) - .concat(); - - for i in 0..10 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_line0_limbs = input_line0 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_line1_limbs = input_line1 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_line0_limbs, - input_line1_limbs, - 512, - chip.0.core.air.offset + PairingOpcode::MUL_023_BY_023 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -// NOTE[yj]: this test requires `RUST_MIN_STACK=8388608` to run otherwise it will overflow the stack -#[test] -#[ignore] -fn test_mul_by_02345() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMulBy02345Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }, - BLS12_381_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(19); - let f = Fq12::random(&mut rng); - let x0 = Fq2::random(&mut rng); - let x2 = Fq2::random(&mut rng); - let x3 = Fq2::random(&mut rng); - let x4 = Fq2::random(&mut rng); - let x5 = Fq2::random(&mut rng); - - let input_f = bls12381_fq12_to_biguint_vec(f); - let input_x = [ - bls12381_fq2_to_biguint_vec(x0), - bls12381_fq2_to_biguint_vec(x2), - bls12381_fq2_to_biguint_vec(x3), - bls12381_fq2_to_biguint_vec(x4), - bls12381_fq2_to_biguint_vec(x5), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_f.clone(), input_x.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 12); - - let r_cmp = Bls12_381::mul_by_02345(&f, &[x0, x2, x3, x4, x5]); - let r_cmp_bigint = bls12381_fq12_to_biguint_vec(r_cmp); - - for i in 0..12 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_f_limbs = input_f - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_x_limbs = input_x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_f_limbs, - input_x_limbs, - 1024, - chip.0.core.air.offset + PairingOpcode::MUL_BY_02345 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/mod.rs deleted file mode 100644 index acf02c72be..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod d_type; -mod evaluate_line; -mod m_type; - -pub use d_type::*; -pub use evaluate_line::*; -pub use m_type::*; diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs deleted file mode 100644 index 77084428c9..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs +++ /dev/null @@ -1,215 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: two AffinePoint: 4 field elements each -// Output: (AffinePoint, UnevaluatedLine, UnevaluatedLine) -> 2*2 + 2*2 + 2*2 = 12 -// field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct MillerDoubleAndAddStepChip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > MillerDoubleAndAddStepChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = miller_double_and_add_step_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MILLER_DOUBLE_AND_ADD_STEP as usize], - vec![], - range_checker, - "MillerDoubleAndAddStep", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -// Ref: openvm_pairing_guest::miller_step -pub fn miller_double_and_add_step_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x_s = Fp2::new(builder.clone()); - let mut y_s = Fp2::new(builder.clone()); - let mut x_q = Fp2::new(builder.clone()); - let mut y_q = Fp2::new(builder.clone()); - - // λ1 = (y_s - y_q) / (x_s - x_q) - let mut lambda1 = y_s.sub(&mut y_q).div(&mut x_s.sub(&mut x_q)); - let mut x_sq = lambda1.square().sub(&mut x_s).sub(&mut x_q); - // λ2 = -λ1 - 2y_s / (x_{s+q} - x_s) - let mut lambda2 = lambda1 - .neg() - .sub(&mut y_s.int_mul([2, 0]).div(&mut x_sq.sub(&mut x_s))); - let mut x_sqs = lambda2.square().sub(&mut x_s).sub(&mut x_sq); - let mut y_sqs = lambda2.mul(&mut (x_s.sub(&mut x_sqs))).sub(&mut y_s); - - x_sqs.save_output(); - y_sqs.save_output(); - - let mut b0 = lambda1.neg(); - let mut c0 = lambda1.mul(&mut x_s).sub(&mut y_s); - b0.save_output(); - c0.save_output(); - - let mut b1 = lambda2.neg(); - let mut c1 = lambda2.mul(&mut x_s).sub(&mut y_s); - b1.save_output(); - c1.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use halo2curves_axiom::bn256::G2Affine; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::AffinePoint; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::test_utils::{biguint_to_limbs, bn254_fq_to_biguint}; - use openvm_pairing_guest::{ - bn254::BN254_MODULUS, halo2curves_shims::bn254::Bn254, pairing::MillerStep, - }; - use openvm_pairing_transpiler::PairingOpcode; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - type F = BabyBear; - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 32; - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_and_add() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleAndAddStepChip::new( - adapter, - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: LIMB_BITS, - num_limbs: NUM_LIMBS, - }, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(2); - let Q = G2Affine::random(&mut rng0); - let Q2 = G2Affine::random(&mut rng0); - let inputs = [ - Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1, Q2.x.c0, Q2.x.c1, Q2.y.c0, Q2.y.c1, - ] - .map(bn254_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let Q_ecpoint2 = AffinePoint { x: Q2.x, y: Q2.y }; - let (Q_daa, l_qa, l_sqs) = Bn254::miller_double_and_add_step(&Q_ecpoint, &Q_ecpoint2); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 12); // AffinePoint and 4 Fp2 coefficients - assert_eq!(result[0], bn254_fq_to_biguint(Q_daa.x.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(Q_daa.x.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(Q_daa.y.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(Q_daa.y.c1)); - assert_eq!(result[4], bn254_fq_to_biguint(l_qa.b.c0)); - assert_eq!(result[5], bn254_fq_to_biguint(l_qa.b.c1)); - assert_eq!(result[6], bn254_fq_to_biguint(l_qa.c.c0)); - assert_eq!(result[7], bn254_fq_to_biguint(l_qa.c.c1)); - assert_eq!(result[8], bn254_fq_to_biguint(l_sqs.b.c0)); - assert_eq!(result[9], bn254_fq_to_biguint(l_sqs.b.c1)); - assert_eq!(result[10], bn254_fq_to_biguint(l_sqs.c.c0)); - assert_eq!(result[11], bn254_fq_to_biguint(l_sqs.c.c1)); - - let input1_limbs = inputs[0..4] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let input2_limbs = inputs[4..8] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default( - &mut tester, - input1_limbs, - input2_limbs, - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_AND_ADD_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs deleted file mode 100644 index 519eb473a5..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs +++ /dev/null @@ -1,253 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: AffinePoint: 4 field elements -// Output: (AffinePoint, Fp2, Fp2) -> 8 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct MillerDoubleStepChip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > MillerDoubleStepChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = miller_double_step_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MILLER_DOUBLE_STEP as usize], - vec![], - range_checker, - "MillerDoubleStep", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -// Ref: https://github.com/openvm-org/openvm/blob/f7d6fa7b8ef247e579740eb652fcdf5a04259c28/lib/ecc-execution/src/common/miller_step.rs#L7 -pub fn miller_double_step_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x_s = Fp2::new(builder.clone()); - let mut y_s = Fp2::new(builder.clone()); - - let mut three_x_square = x_s.square().int_mul([3, 0]); - let mut lambda = three_x_square.div(&mut y_s.int_mul([2, 0])); - let mut x_2s = lambda.square().sub(&mut x_s.int_mul([2, 0])); - let mut y_2s = lambda.mul(&mut (x_s.sub(&mut x_2s))).sub(&mut y_s); - x_2s.save_output(); - y_2s.save_output(); - - let mut b = lambda.neg(); - let mut c = lambda.mul(&mut x_s).sub(&mut y_s); - b.save_output(); - c.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::AffinePoint; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::test_utils::{ - biguint_to_limbs, bls12381_fq_to_biguint, bn254_fq_to_biguint, - }; - use openvm_pairing_guest::{ - bls12_381::{BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS}, - bn254::{BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS}, - halo2curves_shims::{bls12_381::Bls12_381, bn254::Bn254}, - pairing::MillerStep, - }; - use openvm_pairing_transpiler::PairingOpcode; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - type F = BabyBear; - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_bn254() { - use halo2curves_axiom::bn256::G2Affine; - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 32; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: BN254_LIMB_BITS, - num_limbs: BN254_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleStepChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(2); - let Q = G2Affine::random(&mut rng0); - let inputs = [Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1].map(bn254_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let (Q_acc_init, l_init) = Bn254::miller_double_step(&Q_ecpoint); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 8); // AffinePoint and two Fp2 coefficients - assert_eq!(result[0], bn254_fq_to_biguint(Q_acc_init.x.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(Q_acc_init.x.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(Q_acc_init.y.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(Q_acc_init.y.c1)); - assert_eq!(result[4], bn254_fq_to_biguint(l_init.b.c0)); - assert_eq!(result[5], bn254_fq_to_biguint(l_init.b.c1)); - assert_eq!(result[6], bn254_fq_to_biguint(l_init.c.c0)); - assert_eq!(result[7], bn254_fq_to_biguint(l_init.c.c1)); - - let input_limbs = inputs - .map(|x| biguint_to_limbs::(x, LIMB_BITS).map(BabyBear::from_canonical_u32)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs.to_vec(), - vec![], - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_bls12_381() { - use halo2curves_axiom::bls12_381::G2Affine; - const NUM_LIMBS: usize = 48; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 16; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - limb_bits: BLS12_381_LIMB_BITS, - num_limbs: BLS12_381_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleStepChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(12); - let Q = G2Affine::random(&mut rng0); - let inputs = [Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1].map(bls12381_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let (Q_acc_init, l_init) = Bls12_381::miller_double_step(&Q_ecpoint); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 8); // AffinePoint and two Fp2 coefficients - assert_eq!(result[0], bls12381_fq_to_biguint(Q_acc_init.x.c0)); - assert_eq!(result[1], bls12381_fq_to_biguint(Q_acc_init.x.c1)); - assert_eq!(result[2], bls12381_fq_to_biguint(Q_acc_init.y.c0)); - assert_eq!(result[3], bls12381_fq_to_biguint(Q_acc_init.y.c1)); - assert_eq!(result[4], bls12381_fq_to_biguint(l_init.b.c0)); - assert_eq!(result[5], bls12381_fq_to_biguint(l_init.b.c1)); - assert_eq!(result[6], bls12381_fq_to_biguint(l_init.c.c0)); - assert_eq!(result[7], bls12381_fq_to_biguint(l_init.c.c1)); - - let input_limbs = inputs - .map(|x| biguint_to_limbs::(x, LIMB_BITS).map(BabyBear::from_canonical_u32)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs.to_vec(), - vec![], - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/pairing_chip/mod.rs b/extensions/pairing/circuit/src/pairing_chip/mod.rs deleted file mode 100644 index df00df16ce..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod line; -mod miller_double_step; - -pub use line::*; -pub use miller_double_step::*; - -mod miller_double_and_add_step; -pub use miller_double_and_add_step::*; diff --git a/extensions/pairing/circuit/src/pairing_extension.rs b/extensions/pairing/circuit/src/pairing_extension.rs index c75687f404..f700ca4dc5 100644 --- a/extensions/pairing/circuit/src/pairing_extension.rs +++ b/extensions/pairing/circuit/src/pairing_extension.rs @@ -5,7 +5,7 @@ use openvm_circuit::{ arch::{VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_ecc_circuit::CurveConfig; @@ -21,8 +21,6 @@ use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use strum::FromRepr; -use super::*; - // All the supported pairing curves. #[derive(Clone, Copy, Debug, FromRepr, Serialize, Deserialize)] #[repr(usize)] @@ -64,14 +62,9 @@ pub struct PairingExtension { pub supported_curves: Vec, } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, InsExecutorE1, InsExecutorE2)] pub enum PairingExtensionExecutor { - // bn254 (32 limbs) - MillerDoubleAndAddStepRv32_32(MillerDoubleAndAddStepChip), - EvaluateLineRv32_32(EvaluateLineChip), - // bls12-381 (48 limbs) - MillerDoubleAndAddStepRv32_48(MillerDoubleAndAddStepChip), - EvaluateLineRv32_48(EvaluateLineChip), + Phantom(PhantomChip), } #[derive(ChipUsageGetter, Chip, AnyEnum, From)] @@ -106,7 +99,7 @@ pub(crate) mod phantom { use halo2curves_axiom::ff; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_ecc_guest::{algebra::field::FieldExtension, AffinePoint}; use openvm_instructions::{ @@ -118,8 +111,9 @@ pub(crate) mod phantom { bn254::BN254_NUM_LIMBS, pairing::{FinalExp, MultiMillerLoop}, }; - use openvm_rv32im_circuit::adapters::{compose, unsafe_read_rv32_register}; + use openvm_rv32im_circuit::adapters::{memory_read, read_rv32_register}; use openvm_stark_backend::p3_field::PrimeField32; + use rand::rngs::StdRng; use super::PairingCurve; @@ -127,44 +121,42 @@ pub(crate) mod phantom { impl PhantomSubExecutor for PairingHintSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, c_upper: u16, ) -> eyre::Result<()> { - let rs1 = unsafe_read_rv32_register(memory, a); - let rs2 = unsafe_read_rv32_register(memory, b); + let rs1 = read_rv32_register(memory, a); + let rs2 = read_rv32_register(memory, b); hint_pairing(memory, &mut streams.hint_stream, rs1, rs2, c_upper) } } fn hint_pairing( - memory: &MemoryController, + memory: &GuestMemory, hint_stream: &mut VecDeque, rs1: u32, rs2: u32, c_upper: u16, ) -> eyre::Result<()> { - let p_ptr = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1), - )); + let p_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs1)); // len in bytes - let p_len = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1 + RV32_REGISTER_NUM_LIMBS as u32), - )); - let q_ptr = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs2), + let p_len = u32::from_le_bytes(memory_read( + memory, + RV32_MEMORY_AS, + rs1 + RV32_REGISTER_NUM_LIMBS as u32, )); + + let q_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs2)); // len in bytes - let q_len = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs2 + RV32_REGISTER_NUM_LIMBS as u32), + let q_len = u32::from_le_bytes(memory_read( + memory, + RV32_MEMORY_AS, + rs2 + RV32_REGISTER_NUM_LIMBS as u32, )); match PairingCurve::from_repr(c_upper as usize) { @@ -178,8 +170,8 @@ pub(crate) mod phantom { let p = (0..p_len) .map(|i| -> eyre::Result<_> { let ptr = p_ptr + i * 2 * (N as u32); - let x = read_fp::(memory, ptr)?; - let y = read_fp::(memory, ptr + N as u32)?; + let x = read_fp::(memory, ptr)?; + let y = read_fp::(memory, ptr + N as u32)?; Ok(AffinePoint::new(x, y)) }) .collect::>>()?; @@ -187,8 +179,8 @@ pub(crate) mod phantom { .map(|i| -> eyre::Result<_> { let mut ptr = q_ptr + i * 4 * (N as u32); let mut read_fp2 = || -> eyre::Result<_> { - let c0 = read_fp::(memory, ptr)?; - let c1 = read_fp::(memory, ptr + N as u32)?; + let c0 = read_fp::(memory, ptr)?; + let c1 = read_fp::(memory, ptr + N as u32)?; ptr += 2 * N as u32; Ok(Fq2::new(c0, c1)) }; @@ -220,8 +212,8 @@ pub(crate) mod phantom { let p = (0..p_len) .map(|i| -> eyre::Result<_> { let ptr = p_ptr + i * 2 * (N as u32); - let x = read_fp::(memory, ptr)?; - let y = read_fp::(memory, ptr + N as u32)?; + let x = read_fp::(memory, ptr)?; + let y = read_fp::(memory, ptr + N as u32)?; Ok(AffinePoint::new(x, y)) }) .collect::>>()?; @@ -229,8 +221,8 @@ pub(crate) mod phantom { .map(|i| -> eyre::Result<_> { let mut ptr = q_ptr + i * 4 * (N as u32); let mut read_fp2 = || -> eyre::Result<_> { - let c0 = read_fp::(memory, ptr)?; - let c1 = read_fp::(memory, ptr + N as u32)?; + let c0 = read_fp::(memory, ptr)?; + let c1 = read_fp::(memory, ptr + N as u32)?; ptr += 2 * N as u32; Ok(Fq2 { c0, c1 }) }; @@ -259,24 +251,21 @@ pub(crate) mod phantom { Ok(()) } - fn read_fp( - memory: &MemoryController, + fn read_fp( + memory: &GuestMemory, ptr: u32, ) -> eyre::Result where Fp::Repr: From<[u8; N]>, { - let mut repr = [0u8; N]; - for (i, byte) in repr.iter_mut().enumerate() { - *byte = memory - .unsafe_read_cell( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(ptr + i as u32), - ) - .as_canonical_u32() - .try_into()?; - } - Fp::from_repr(repr.into()) + let repr: &[u8; N] = unsafe { + memory + .memory + .get_slice::((RV32_MEMORY_AS, ptr), N) + .try_into() + .unwrap() + }; + Fp::from_repr((*repr).into()) .into_option() .ok_or(eyre::eyre!("bad ff::PrimeField repr")) } diff --git a/extensions/pairing/transpiler/Cargo.toml b/extensions/pairing/transpiler/Cargo.toml index a5557b03d1..9ce32bc85c 100644 --- a/extensions/pairing/transpiler/Cargo.toml +++ b/extensions/pairing/transpiler/Cargo.toml @@ -14,4 +14,3 @@ openvm-transpiler = { workspace = true } rrs-lib = { workspace = true } strum = { workspace = true } openvm-pairing-guest = { workspace = true } -openvm-instructions-derive = { workspace = true } diff --git a/extensions/pairing/transpiler/src/lib.rs b/extensions/pairing/transpiler/src/lib.rs index 7777c37c91..e80deaf154 100644 --- a/extensions/pairing/transpiler/src/lib.rs +++ b/extensions/pairing/transpiler/src/lib.rs @@ -1,71 +1,11 @@ use openvm_instructions::{ - instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant, + instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, PhantomDiscriminant, }; -use openvm_instructions_derive::LocalOpcode; use openvm_pairing_guest::{PairingBaseFunct7, OPCODE, PAIRING_FUNCT3}; use openvm_stark_backend::p3_field::PrimeField32; use openvm_transpiler::{TranspilerExtension, TranspilerOutput}; use rrs_lib::instruction_formats::RType; -use strum::{EnumCount, EnumIter, FromRepr}; - -// NOTE: the following opcodes are enabled only in testing and not enabled in the VM Extension -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x750] -#[repr(usize)] -#[allow(non_camel_case_types)] -pub enum PairingOpcode { - MILLER_DOUBLE_AND_ADD_STEP, - MILLER_DOUBLE_STEP, - EVALUATE_LINE, - MUL_013_BY_013, - MUL_023_BY_023, - MUL_BY_01234, - MUL_BY_02345, -} - -// NOTE: Fp12 opcodes are only enabled in testing and not enabled in the VM Extension -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x700] -#[repr(usize)] -#[allow(non_camel_case_types)] -pub enum Fp12Opcode { - ADD, - SUB, - MUL, -} -const FP12_OPS: usize = 4; - -pub struct Bn254Fp12Opcode(Fp12Opcode); - -impl LocalOpcode for Bn254Fp12Opcode { - const CLASS_OFFSET: usize = Fp12Opcode::CLASS_OFFSET; - - fn from_usize(value: usize) -> Self { - Self(Fp12Opcode::from_usize(value)) - } - - fn local_usize(&self) -> usize { - self.0.local_usize() - } -} - -pub struct Bls12381Fp12Opcode(Fp12Opcode); - -impl LocalOpcode for Bls12381Fp12Opcode { - const CLASS_OFFSET: usize = Fp12Opcode::CLASS_OFFSET + FP12_OPS; - - fn from_usize(value: usize) -> Self { - Self(Fp12Opcode::from_usize(value - FP12_OPS)) - } - - fn local_usize(&self) -> usize { - self.0.local_usize() + FP12_OPS - } -} +use strum::FromRepr; #[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)] #[repr(u16)] diff --git a/extensions/rv32-adapters/Cargo.toml b/extensions/rv32-adapters/Cargo.toml index adf133555b..54ec529e2c 100644 --- a/extensions/rv32-adapters/Cargo.toml +++ b/extensions/rv32-adapters/Cargo.toml @@ -19,9 +19,6 @@ openvm-instructions = { workspace = true } itertools.workspace = true derive-new.workspace = true rand.workspace = true -serde = { workspace = true, features = ["derive"] } -serde-big-array.workspace = true -serde_with.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } diff --git a/extensions/rv32-adapters/src/eq_mod.rs b/extensions/rv32-adapters/src/eq_mod.rs index ab80481f19..6d67b6114e 100644 --- a/extensions/rv32-adapters/src/eq_mod.rs +++ b/extensions/rv32-adapters/src/eq_mod.rs @@ -1,26 +1,26 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -29,16 +29,13 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; -use serde_with::serde_as; /// This adapter reads from NUM_READS <= 2 pointers and writes to a register. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -47,7 +44,7 @@ use serde_with::serde_as; /// starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). /// * Writes are to 32-bit register rd. #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct Rv32IsEqualModAdapterCols< T, const NUM_READS: usize, @@ -227,209 +224,225 @@ impl< } } -pub struct Rv32IsEqualModAdapterChip< - F: Field, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32IsEqualModAdapterRecord< const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, > { - pub air: Rv32IsEqualModAdapterAir, + pub from_pc: u32, + pub timestamp: u32, + + pub rs_ptr: [u32; NUM_READS], + pub rs_val: [u32; NUM_READS], + pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS], + pub heap_read_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS], + + pub rd_ptr: u32, + pub writes_aux: MemoryWriteBytesAuxRecord, +} + +pub struct Rv32IsEqualModeAdapterStep< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCK_SIZE: usize, + const TOTAL_READ_SIZE: usize, +> { + pointer_max_bits: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, } impl< - F: PrimeField32, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, - > Rv32IsEqualModAdapterChip + > Rv32IsEqualModeAdapterStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); Self { - air: Rv32IsEqualModAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, } } } -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32IsEqualModReadRecord< - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCK_SIZE: usize, -> { - #[serde(with = "BigArray")] - pub rs: [RecordId; NUM_READS], - #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")] - pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS], -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32IsEqualModWriteRecord { - pub from_state: ExecutionState, - pub rd_id: RecordId, -} - impl< F: PrimeField32, + CTX, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, - > VmAdapterChip - for Rv32IsEqualModAdapterChip + > AdapterTraceStep + for Rv32IsEqualModeAdapterStep +where + F: PrimeField32, { - type ReadRecord = Rv32IsEqualModReadRecord; - type WriteRecord = Rv32IsEqualModWriteRecord; - type Air = Rv32IsEqualModAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, + const WIDTH: usize = + Rv32IsEqualModAdapterCols::::width(); + type ReadData = [[u8; TOTAL_READ_SIZE]; NUM_READS]; + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32IsEqualModAdapterRecord< NUM_READS, - 1, + BLOCKS_PER_READ, + BLOCK_SIZE, TOTAL_READ_SIZE, - RV32_REGISTER_NUM_LIMBS, >; - fn preprocess( - &mut self, - memory: &mut MemoryController, + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.timestamp = memory.timestamp; + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { let Instruction { b, c, d, e, .. } = *instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); - - let read_records = rs_vals.map(|address| { - debug_assert!(address < (1 << self.air.address_bits)); - from_fn(|i| { - memory - .read::(e, F::from_canonical_u32(address + (i * BLOCK_SIZE) as u32)) - }) - }); + // Read register values + record.rs_val = from_fn(|i| { + record.rs_ptr[i] = if i == 0 { b } else { c }.as_canonical_u32(); - let read_data = read_records.map(|r| { - let read = r.map(|x| x.1); - let mut read_it = read.iter().flatten(); - from_fn(|_| *(read_it.next().unwrap())) + u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs_ptr[i], + &mut record.rs_read_aux[i].prev_timestamp, + )) }); - let record = Rv32IsEqualModReadRecord { - rs: rs_records, - reads: read_records.map(|r| r.map(|x| x.0)), - }; - Ok((read_data, record)) + // Read memory values + from_fn(|i| { + assert!(record.rs_val[i] as usize + TOTAL_READ_SIZE - 1 < (1 << self.pointer_max_bits)); + from_fn::<_, BLOCKS_PER_READ, _>(|j| { + tracing_read::<_, BLOCK_SIZE>( + memory, + RV32_MEMORY_AS, + record.rs_val[i] + (j * BLOCK_SIZE) as u32, + &mut record.heap_read_aux[i][j].prev_timestamp, + ) + }) + .concat() + .try_into() + .unwrap() + }) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); - - debug_assert!( - memory.timestamp() - from_state.timestamp - == (NUM_READS * (BLOCKS_PER_READ + 1) + 1) as u32, - "timestamp delta is {}, expected {}", - memory.timestamp() - from_state.timestamp, - NUM_READS * (BLOCKS_PER_READ + 1) + 1 + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, + ) { + let Instruction { a, .. } = *instruction; + record.rd_ptr = a.as_canonical_u32(); + tracing_write( + memory, + RV32_REGISTER_AS, + record.rd_ptr, + data, + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, ); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32IsEqualModAdapterCols = - row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rs = read_record.rs.map(|r| memory.record_by_id(r)); - for (i, r) in rs.iter().enumerate() { - row_slice.rs_ptr[i] = r.pointer; - row_slice.rs_val[i].copy_from_slice(r.data_slice()); - aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]); - for (j, x) in read_record.reads[i].iter().enumerate() { - let read = memory.record_by_id(*x); - aux_cols_factory.generate_read_aux(read, &mut row_slice.heap_read_aux[i][j]); - } - } - - let rd = memory.record_by_id(write_record.rd_id); - row_slice.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); - - // Range checks - let need_range_check: [u32; 2] = from_fn(|i| { - if i < NUM_READS { - rs[i] - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() +impl< + F: PrimeField32, + CTX, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCK_SIZE: usize, + const TOTAL_READ_SIZE: usize, + > AdapterTraceFiller + for Rv32IsEqualModeAdapterStep +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32IsEqualModAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCK_SIZE, + TOTAL_READ_SIZE, + > = unsafe { get_record_from_slice(&mut adapter_row, ()) }; + + let cols: &mut Rv32IsEqualModAdapterCols = + adapter_row.borrow_mut(); + + let mut timestamp = record.timestamp + (NUM_READS + NUM_READS * BLOCKS_PER_READ) as u32 + 1; + let mut timestamp_mm = || { + timestamp -= 1; + timestamp + }; + // Do range checks before writing anything: + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + self.bitwise_lookup_chip.request_range( + (record.rs_val[0] >> MSL_SHIFT) << limb_shift_bits, + if NUM_READS > 1 { + (record.rs_val[1] >> MSL_SHIFT) << limb_shift_bits } else { 0 - } - }); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits; - self.bitwise_lookup_chip.request_range( - need_range_check[0] << limb_shift_bits, - need_range_check[1] << limb_shift_bits, + }, ); - } + // Writing in reverse order + cols.writes_aux + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + timestamp_mm(), + cols.writes_aux.as_mut(), + ); + cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + + // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records + cols.heap_read_aux + .iter_mut() + .rev() + .zip(record.heap_read_aux.iter().rev()) + .for_each(|(col_reads, record_reads)| { + col_reads + .iter_mut() + .rev() + .zip(record_reads.iter().rev()) + .for_each(|(col, record)| { + mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut()); + }); + }); + + cols.rs_read_aux + .iter_mut() + .rev() + .zip(record.rs_read_aux.iter().rev()) + .for_each(|(col, record)| { + mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut()); + }); + + cols.rs_val = record + .rs_val + .map(|val| val.to_le_bytes().map(F::from_canonical_u8)); + cols.rs_ptr = record.rs_ptr.map(|ptr| F::from_canonical_u32(ptr)); - fn air(&self) -> &Self::Air { - &self.air + cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32-adapters/src/heap.rs b/extensions/rv32-adapters/src/heap.rs index cd9f93abbc..28f9c812dd 100644 --- a/extensions/rv32-adapters/src/heap.rs +++ b/extensions/rv32-adapters/src/heap.rs @@ -1,38 +1,27 @@ -use std::{ - array::{self, from_fn}, - borrow::Borrow, - marker::PhantomData, -}; +use std::borrow::Borrow; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, - }, - system::{ - memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory}, - program::ProgramBus, + AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, BasicAdapterInterface, + ExecutionBridge, MinimalInstruction, VmAdapterAir, }, + system::memory::{offline_checker::MemoryBridge, online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{ instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, }; -use openvm_rv32im_circuit::adapters::read_rv32_register; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, PrimeField32}, }; -use super::{ - vec_heap_generate_trace_row_impl, Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols, - Rv32VecHeapReadRecord, Rv32VecHeapWriteRecord, +use crate::{ + Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols, Rv32VecHeapAdapterRecord, Rv32VecHeapAdapterStep, }; /// This adapter reads from NUM_READS <= 2 pointers and writes to 1 pointer. @@ -101,137 +90,82 @@ impl< } } -pub struct Rv32HeapAdapterChip< - F: Field, +pub struct Rv32HeapAdapterStep< const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, -> { - pub air: Rv32HeapAdapterAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} +>(Rv32VecHeapAdapterStep); -impl - Rv32HeapAdapterChip +impl + Rv32HeapAdapterStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); - Self { - air: Rv32HeapAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + Rv32HeapAdapterStep(Rv32VecHeapAdapterStep::new( + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, - } + )) } } -impl - VmAdapterChip for Rv32HeapAdapterChip +impl< + F: PrimeField32, + CTX, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceStep for Rv32HeapAdapterStep +where + F: PrimeField32, { - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord<1, WRITE_SIZE>; - type Air = Rv32HeapAdapterAir; - type Interface = - BasicAdapterInterface, NUM_READS, 1, READ_SIZE, WRITE_SIZE>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; - - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); - - let read_records = rs_vals.map(|address| { - debug_assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits)); - [memory.read::(e, F::from_canonical_u32(address))] - }); - let read_data = read_records.map(|r| r[0].1); - - let record = Rv32VecHeapReadRecord { - rs: rs_records, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads: read_records.map(|r| array::from_fn(|i| r[i].0)), - }; - - Ok((read_data, record)) + const WIDTH: usize = + Rv32VecHeapAdapterCols::::width(); + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; 1]; + type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord; + + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let writes = [memory.write(e, read_record.rd_val, output.writes[0]).0]; - - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 6, - "timestamp delta is {}, expected 6", - timestamp_delta - ); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let read_data = AdapterTraceStep::::read(&self.0, memory, instruction, record); + read_data.map(|r| r[0]) } - fn generate_trace_row( + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - vec_heap_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ); + AdapterTraceStep::::write(&self.0, memory, instruction, data, record); } +} - fn air(&self) -> &Self::Air { - &self.air +impl< + F: PrimeField32, + CTX, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceFiller for Rv32HeapAdapterStep +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, adapter_row: &mut [F]) { + AdapterTraceFiller::::fill_trace_row(&self.0, mem_helper, adapter_row); } } diff --git a/extensions/rv32-adapters/src/heap_branch.rs b/extensions/rv32-adapters/src/heap_branch.rs index 29c9a151c9..8f8dab8707 100644 --- a/extensions/rv32-adapters/src/heap_branch.rs +++ b/extensions/rv32-adapters/src/heap_branch.rs @@ -1,27 +1,23 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, - iter::once, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -29,16 +25,12 @@ use openvm_instructions::{ program::DEFAULT_PC_STEP, riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; -use openvm_rv32im_circuit::adapters::{ - read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, -}; +use openvm_rv32im_circuit::adapters::{tracing_read, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; /// This adapter reads from NUM_READS <= 2 pointers. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -170,158 +162,159 @@ impl VmA } } -pub struct Rv32HeapBranchAdapterChip { - pub air: Rv32HeapBranchAdapterAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32HeapBranchAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs_ptr: [u32; NUM_READS], + pub rs_vals: [u32; NUM_READS], + + pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS], + pub heap_read_aux: [MemoryReadAuxRecord; NUM_READS], +} + +pub struct Rv32HeapBranchAdapterStep { + pub pointer_max_bits: usize, + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, } -impl - Rv32HeapBranchAdapterChip +impl + Rv32HeapBranchAdapterStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); Self { - air: Rv32HeapBranchAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, } } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Rv32HeapBranchReadRecord { - #[serde(with = "BigArray")] - pub rs_reads: [RecordId; NUM_READS], - #[serde(with = "BigArray")] - pub heap_reads: [RecordId; NUM_READS], -} - -impl VmAdapterChip - for Rv32HeapBranchAdapterChip +impl AdapterTraceStep + for Rv32HeapBranchAdapterStep { - type ReadRecord = Rv32HeapBranchReadRecord; - type WriteRecord = ExecutionState; - type Air = Rv32HeapBranchAdapterAir; - type Interface = BasicAdapterInterface, NUM_READS, 0, READ_SIZE, 0>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = Rv32HeapBranchAdapterCols::::width(); + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = (); + type RecordMut<'a> = &'a mut Rv32HeapBranchAdapterRecord; + + fn start(pc: u32, memory: &TracingMemory, adapter_record: &mut Self::RecordMut<'_>) { + adapter_record.from_pc = pc; + adapter_record.from_timestamp = memory.timestamp; + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { let Instruction { a, b, d, e, .. } = *instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { a } else { b }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record + // Read register values + record.rs_vals = from_fn(|i| { + record.rs_ptr[i] = if i == 0 { a } else { b }.as_canonical_u32(); + u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs_ptr[i], + &mut record.rs_read_aux[i].prev_timestamp, + )) }); - let heap_records = rs_vals.map(|address| { - assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits)); - memory.read::(e, F::from_canonical_u32(address)) - }); - - let record = Rv32HeapBranchReadRecord { - rs_reads: rs_records, - heap_reads: heap_records.map(|r| r.0), - }; - Ok((heap_records.map(|r| r.1), record)) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 4, - "timestamp delta is {}, expected 4", - timestamp_delta - ); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - from_state, - )) + // Read memory values + from_fn(|i| { + assert!(record.rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)); + tracing_read( + memory, + RV32_MEMORY_AS, + record.rs_vals[i], + &mut record.heap_read_aux[i].prev_timestamp, + ) + }) } - fn generate_trace_row( + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + _memory: &mut TracingMemory, + _instruction: &Instruction, + _data: Self::WriteData, + _record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = - row_slice.borrow_mut(); - row_slice.from_state = write_record.map(F::from_canonical_u32); + // This adapter doesn't write anything + } +} - let rs_reads = read_record.rs_reads.map(|r| memory.record_by_id(r)); +impl + AdapterTraceFiller for Rv32HeapBranchAdapterStep +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32HeapBranchAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let cols: &mut Rv32HeapBranchAdapterCols = + adapter_row.borrow_mut(); - for (i, rs_read) in rs_reads.iter().enumerate() { - row_slice.rs_ptr[i] = rs_read.pointer; - row_slice.rs_val[i].copy_from_slice(rs_read.data_slice()); - aux_cols_factory.generate_read_aux(rs_read, &mut row_slice.rs_read_aux[i]); - } + // Range checks: + // **NOTE**: Must do the range checks before overwriting the records + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + if NUM_READS > 1 { + (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits + } else { + 0 + }, + ); - for (i, heap_read) in read_record.heap_reads.iter().enumerate() { - let record = memory.record_by_id(*heap_read); - aux_cols_factory.generate_read_aux(record, &mut row_slice.heap_read_aux[i]); + // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records + for i in (0..NUM_READS).rev() { + mem_helper.fill( + record.heap_read_aux[i].prev_timestamp, + record.from_timestamp + (i + NUM_READS) as u32, + cols.heap_read_aux[i].as_mut(), + ); } - // Range checks: - let need_range_check: Vec = rs_reads - .iter() - .map(|record| { - record - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }) - .chain(once(0)) // in case NUM_READS is odd - .collect(); - debug_assert!(self.air.address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits; - for pair in need_range_check.chunks_exact(2) { - self.bitwise_lookup_chip - .request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + for i in (0..NUM_READS).rev() { + mem_helper.fill( + record.rs_read_aux[i].prev_timestamp, + record.from_timestamp + i as u32, + cols.rs_read_aux[i].as_mut(), + ); } - } - fn air(&self) -> &Self::Air { - &self.air + cols.rs_val + .iter_mut() + .rev() + .zip(record.rs_vals.iter().rev()) + .for_each(|(col, record)| { + *col = record.to_le_bytes().map(F::from_canonical_u8); + }); + + cols.rs_ptr + .iter_mut() + .rev() + .zip(record.rs_ptr.iter().rev()) + .for_each(|(col, record)| { + *col = F::from_canonical_u32(*record); + }); + + cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32-adapters/src/vec_heap.rs b/extensions/rv32-adapters/src/vec_heap.rs index fab0df3334..79edf0f2eb 100644 --- a/extensions/rv32-adapters/src/vec_heap.rs +++ b/extensions/rv32-adapters/src/vec_heap.rs @@ -2,25 +2,26 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, iter::{once, zip}, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VecHeapAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + ExecutionBridge, ExecutionState, VecHeapAdapterInterface, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -29,15 +30,13 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + abstract_compose, tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; /// This adapter reads from R (R <= 2) pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -46,89 +45,8 @@ use serde_with::serde_as; /// starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). /// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the /// heap, starting from the address in `rd`. -#[derive(Clone)] -pub struct Rv32VecHeapAdapterChip< - F: Field, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub air: - Rv32VecHeapAdapterAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl< - F: PrimeField32, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - Rv32VecHeapAdapterChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - assert!(NUM_READS <= 2); - assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" - ); - Self { - air: Rv32VecHeapAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -#[serde(bound = "F: Field")] -pub struct Rv32VecHeapReadRecord< - F: Field, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const READ_SIZE: usize, -> { - /// Read register value from address space e=1 - #[serde_as(as = "[_; NUM_READS]")] - pub rs: [RecordId; NUM_READS], - /// Read register value from address space d=1 - pub rd: RecordId, - - pub rd_val: F, - - #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")] - pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS], -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32VecHeapWriteRecord { - pub from_state: ExecutionState, - #[serde_as(as = "[_; BLOCKS_PER_WRITE]")] - pub writes: [RecordId; BLOCKS_PER_WRITE], -} - #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct Rv32VecHeapAdapterCols< T, const NUM_READS: usize, @@ -346,29 +264,67 @@ impl< } } +// Intermediate type that should not be copied or cloned and should be directly written to +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32VecHeapAdapterRecord< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs_ptrs: [u32; NUM_READS], + pub rd_ptr: u32, + + pub rs_vals: [u32; NUM_READS], + pub rd_val: u32, + + pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS], + pub rd_read_aux: MemoryReadAuxRecord, + + pub reads_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS], + pub writes_aux: [MemoryWriteBytesAuxRecord; BLOCKS_PER_WRITE], +} + +#[derive(derive_new::new)] +pub struct Rv32VecHeapAdapterStep< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pointer_max_bits: usize, + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + impl< F: PrimeField32, + CTX, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > VmAdapterChip - for Rv32VecHeapAdapterChip< + > AdapterTraceStep + for Rv32VecHeapAdapterStep +{ + const WIDTH: usize = Rv32VecHeapAdapterCols::< F, NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - > -{ - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord; - type Air = - Rv32VecHeapAdapterAir; - type Interface = VecHeapAdapterInterface< - F, + >::width(); + type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; + type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord< NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, @@ -376,171 +332,214 @@ impl< WRITE_SIZE, >; - fn preprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; + record: &mut &mut Rv32VecHeapAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, + ) -> Self::ReadData { + let &Instruction { a, b, c, d, e, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); // Read register values - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record + record.rs_vals = from_fn(|i| { + record.rs_ptrs[i] = if i == 0 { b } else { c }.as_canonical_u32(); + u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs_ptrs[i], + &mut record.rs_read_aux[i].prev_timestamp, + )) }); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); + + record.rd_ptr = a.as_canonical_u32(); + record.rd_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut record.rd_read_aux.prev_timestamp, + )); // Read memory values - let read_records = rs_vals.map(|address| { + from_fn(|i| { assert!( - address as usize + READ_SIZE * BLOCKS_PER_READ - 1 < (1 << self.air.address_bits) + (record.rs_vals[i] + (READ_SIZE * BLOCKS_PER_READ - 1) as u32) + < (1 << self.pointer_max_bits) as u32 ); - from_fn(|i| { - memory.read::(e, F::from_canonical_u32(address + (i * READ_SIZE) as u32)) + from_fn(|j| { + tracing_read( + memory, + RV32_MEMORY_AS, + record.rs_vals[i] + (j * READ_SIZE) as u32, + &mut record.reads_aux[i][j].prev_timestamp, + ) }) - }); - let read_data = read_records.map(|r| r.map(|x| x.1)); - assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits)); - - let record = Rv32VecHeapReadRecord { - rs: rs_records, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads: read_records.map(|r| r.map(|x| x.0)), - }; - - Ok((read_data, record)) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let mut i = 0; - let writes = output.writes.map(|write| { - let (record_id, _) = memory.write( - e, - read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32), - write, - ); - i += 1; - record_id - }); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + }) } - fn generate_trace_row( + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut &mut Rv32VecHeapAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + >, ) { - vec_heap_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ) - } + debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS); - fn air(&self) -> &Self::Air { - &self.air - } -} + assert!( + record.rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 + < (1 << self.pointer_max_bits) + ); -pub(super) fn vec_heap_generate_trace_row_impl< - F: PrimeField32, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, ->( - row_slice: &mut [F], - read_record: &Rv32VecHeapReadRecord, - write_record: &Rv32VecHeapWriteRecord, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - address_bits: usize, - memory: &OfflineMemory, -) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32VecHeapAdapterCols< - F, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rd = memory.record_by_id(read_record.rd); - let rs = read_record - .rs - .into_iter() - .map(|r| memory.record_by_id(r)) - .collect::>(); - - row_slice.rd_ptr = rd.pointer; - row_slice.rd_val.copy_from_slice(rd.data_slice()); - - for (i, r) in rs.iter().enumerate() { - row_slice.rs_ptr[i] = r.pointer; - row_slice.rs_val[i].copy_from_slice(r.data_slice()); - aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]); + #[allow(clippy::needless_range_loop)] + for i in 0..BLOCKS_PER_WRITE { + tracing_write( + memory, + RV32_MEMORY_AS, + record.rd_val + (i * WRITE_SIZE) as u32, + data[i], + &mut record.writes_aux[i].prev_timestamp, + &mut record.writes_aux[i].prev_data, + ); + } } +} - aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); +impl< + F: PrimeField32, + CTX, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceFiller + for Rv32VecHeapAdapterStep +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32VecHeapAdapterRecord< + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = unsafe { get_record_from_slice(&mut adapter_row, ()) }; - for (i, reads) in read_record.reads.iter().enumerate() { - for (j, &x) in reads.iter().enumerate() { - let record = memory.record_by_id(x); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads_aux[i][j]); + let cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + + // Range checks: + // **NOTE**: Must do the range checks before overwriting the records + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + if NUM_READS > 1 { + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits, + ); + self.bitwise_lookup_chip.request_range( + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + ); + } else { + self.bitwise_lookup_chip.request_range( + (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + ); } - } - for (i, &w) in write_record.writes.iter().enumerate() { - let record = memory.record_by_id(w); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); - } + let timestamp_delta = NUM_READS + 1 + NUM_READS * BLOCKS_PER_READ + BLOCKS_PER_WRITE; + let mut timestamp = record.from_timestamp + timestamp_delta as u32; + let mut timestamp_mm = || { + timestamp -= 1; + timestamp + }; - // Range checks: - let need_range_check: Vec = rs - .iter() - .chain(std::iter::repeat_n(&rd, 2)) - .map(|record| { - record - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }) - .collect(); - debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits; - for pair in need_range_check.chunks_exact(2) { - bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records + record + .writes_aux + .iter() + .rev() + .zip(cols.writes_aux.iter_mut().rev()) + .for_each(|(write, cols_write)| { + cols_write.set_prev_data(write.prev_data.map(F::from_canonical_u8)); + mem_helper.fill(write.prev_timestamp, timestamp_mm(), cols_write.as_mut()); + }); + + record + .reads_aux + .iter() + .zip(cols.reads_aux.iter_mut()) + .rev() + .for_each(|(reads, cols_reads)| { + reads + .iter() + .zip(cols_reads.iter_mut()) + .rev() + .for_each(|(read, cols_read)| { + mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut()); + }); + }); + + mem_helper.fill( + record.rd_read_aux.prev_timestamp, + timestamp_mm(), + cols.rd_read_aux.as_mut(), + ); + + record + .rs_read_aux + .iter() + .zip(cols.rs_read_aux.iter_mut()) + .rev() + .for_each(|(aux, cols_aux)| { + mem_helper.fill(aux.prev_timestamp, timestamp_mm(), cols_aux.as_mut()); + }); + + cols.rd_val = record.rd_val.to_le_bytes().map(F::from_canonical_u8); + cols.rs_val + .iter_mut() + .rev() + .zip(record.rs_vals.iter().rev()) + .for_each(|(cols_val, val)| { + *cols_val = val.to_le_bytes().map(F::from_canonical_u8); + }); + cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + cols.rs_ptr + .iter_mut() + .rev() + .zip(record.rs_ptrs.iter().rev()) + .for_each(|(cols_ptr, ptr)| { + *cols_ptr = F::from_canonical_u32(*ptr); + }); + cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32-adapters/src/vec_heap_two_reads.rs b/extensions/rv32-adapters/src/vec_heap_two_reads.rs index f829db8bbc..bee6cf1a13 100644 --- a/extensions/rv32-adapters/src/vec_heap_two_reads.rs +++ b/extensions/rv32-adapters/src/vec_heap_two_reads.rs @@ -2,25 +2,26 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, iter::zip, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VecHeapTwoReadsAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + ExecutionBridge, ExecutionState, VecHeapTwoReadsAdapterInterface, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -29,15 +30,13 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + abstract_compose, tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; /// This adapter reads from 2 pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -47,99 +46,6 @@ use serde_with::serde_as; /// * NOTE that the two reads can read different numbers of blocks. /// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the /// heap, starting from the address in `rd`. -pub struct Rv32VecHeapTwoReadsAdapterChip< - F: Field, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub air: Rv32VecHeapTwoReadsAdapterAir< - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl< - F: PrimeField32, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - Rv32VecHeapTwoReadsAdapterChip< - F, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" - ); - Self { - air: Rv32VecHeapTwoReadsAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32VecHeapTwoReadsReadRecord< - F: Field, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const READ_SIZE: usize, -> { - /// Read register value from address space e=1 - pub rs1: RecordId, - pub rs2: RecordId, - /// Read register value from address space d=1 - pub rd: RecordId, - - pub rd_val: F, - - #[serde_as(as = "[_; BLOCKS_PER_READ1]")] - pub reads1: [RecordId; BLOCKS_PER_READ1], - #[serde_as(as = "[_; BLOCKS_PER_READ2]")] - pub reads2: [RecordId; BLOCKS_PER_READ2], -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Rv32VecHeapTwoReadsWriteRecord { - pub from_state: ExecutionState, - #[serde_as(as = "[_; BLOCKS_PER_WRITE]")] - pub writes: [RecordId; BLOCKS_PER_WRITE], -} - #[repr(C)] #[derive(AlignedBorrow)] pub struct Rv32VecHeapTwoReadsAdapterCols< @@ -372,16 +278,53 @@ impl< } } +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32VecHeapTwoReadsAdapterRecord< + const BLOCKS_PER_READ1: usize, + const BLOCKS_PER_READ2: usize, + const BLOCKS_PER_WRITE: usize, + const WRITE_SIZE: usize, +> { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub rd_ptr: u32, + + pub rs1_val: u32, + pub rs2_val: u32, + pub rd_val: u32, + + pub rs1_read_aux: MemoryReadAuxRecord, + pub rs2_read_aux: MemoryReadAuxRecord, + pub rd_read_aux: MemoryReadAuxRecord, + + pub reads1_aux: [MemoryReadAuxRecord; BLOCKS_PER_READ1], + pub reads2_aux: [MemoryReadAuxRecord; BLOCKS_PER_READ2], + pub writes_aux: [MemoryWriteBytesAuxRecord; BLOCKS_PER_WRITE], +} + +pub struct Rv32VecHeapTwoReadsAdapterStep< + const BLOCKS_PER_READ1: usize, + const BLOCKS_PER_READ2: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pointer_max_bits: usize, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + impl< - F: PrimeField32, const BLOCKS_PER_READ1: usize, const BLOCKS_PER_READ2: usize, const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > VmAdapterChip - for Rv32VecHeapTwoReadsAdapterChip< - F, + > + Rv32VecHeapTwoReadsAdapterStep< BLOCKS_PER_READ1, BLOCKS_PER_READ2, BLOCKS_PER_WRITE, @@ -389,189 +332,256 @@ impl< WRITE_SIZE, > { - type ReadRecord = - Rv32VecHeapTwoReadsReadRecord; - type WriteRecord = Rv32VecHeapTwoReadsWriteRecord; - type Air = Rv32VecHeapTwoReadsAdapterAir< + pub fn new( + pointer_max_bits: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + ) -> Self { + assert!( + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" + ); + Self { + pointer_max_bits, + bitwise_lookup_chip, + } + } +} + +impl< + F: PrimeField32, + CTX, + const BLOCKS_PER_READ1: usize, + const BLOCKS_PER_READ2: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceStep + for Rv32VecHeapTwoReadsAdapterStep< BLOCKS_PER_READ1, BLOCKS_PER_READ2, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - >; - type Interface = VecHeapTwoReadsAdapterInterface< + > +{ + const WIDTH: usize = Rv32VecHeapTwoReadsAdapterCols::< F, BLOCKS_PER_READ1, BLOCKS_PER_READ2, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, + >::width(); + + type ReadData = ( + [[u8; READ_SIZE]; BLOCKS_PER_READ1], + [[u8; READ_SIZE]; BLOCKS_PER_READ2], + ); + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; + type RecordMut<'a> = &'a mut Rv32VecHeapTwoReadsAdapterRecord< + BLOCKS_PER_READ1, + BLOCKS_PER_READ2, + BLOCKS_PER_WRITE, + WRITE_SIZE, >; - fn preprocess( - &mut self, - memory: &mut MemoryController, + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { let Instruction { a, b, c, d, e, .. } = *instruction; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let (rs1_record, rs1_val) = read_rv32_register(memory, d, b); - let (rs2_record, rs2_val) = read_rv32_register(memory, d, c); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); - - assert!(rs1_val as usize + READ_SIZE * BLOCKS_PER_READ1 - 1 < (1 << self.air.address_bits)); - let read1_records = from_fn(|i| { - memory.read::(e, F::from_canonical_u32(rs1_val + (i * READ_SIZE) as u32)) - }); - let read1_data = read1_records.map(|r| r.1); - assert!(rs2_val as usize + READ_SIZE * BLOCKS_PER_READ2 - 1 < (1 << self.air.address_bits)); - let read2_records = from_fn(|i| { - memory.read::(e, F::from_canonical_u32(rs2_val + (i * READ_SIZE) as u32)) - }); - let read2_data = read2_records.map(|r| r.1); - assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits)); - - let record = Rv32VecHeapTwoReadsReadRecord { - rs1: rs1_record, - rs2: rs2_record, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads1: read1_records.map(|r| r.0), - reads2: read2_records.map(|r| r.0), - }; + // Read register values + record.rs1_ptr = b.as_canonical_u32(); + record.rs1_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs1_ptr, + &mut record.rs1_read_aux.prev_timestamp, + )); + record.rs2_ptr = c.as_canonical_u32(); + record.rs2_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs2_ptr, + &mut record.rs2_read_aux.prev_timestamp, + )); - Ok(((read1_data, read2_data), record)) - } + record.rd_ptr = a.as_canonical_u32(); + record.rd_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rd_ptr, + &mut record.rd_read_aux.prev_timestamp, + )); + assert!( + record.rs1_val as usize + READ_SIZE * BLOCKS_PER_READ1 - 1 + < (1 << self.pointer_max_bits) + ); + assert!( + record.rs2_val as usize + READ_SIZE * BLOCKS_PER_READ2 - 1 + < (1 << self.pointer_max_bits) + ); - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let mut i = 0; - let writes = output.writes.map(|write| { - let (record_id, _) = memory.write( - e, - read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32), - write, - ); - i += 1; - record_id - }); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + ( + from_fn(|i| { + tracing_read( + memory, + RV32_MEMORY_AS, + record.rs1_val + (i * READ_SIZE) as u32, + &mut record.reads1_aux[i].prev_timestamp, + ) + }), + from_fn(|i| { + tracing_read( + memory, + RV32_MEMORY_AS, + record.rs2_val + (i * READ_SIZE) as u32, + &mut record.reads2_aux[i].prev_timestamp, + ) + }), + ) } - fn generate_trace_row( + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + _instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - vec_heap_two_reads_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ) - } + assert!( + record.rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 + < (1 << self.pointer_max_bits) + ); - fn air(&self) -> &Self::Air { - &self.air + for (i, block) in data.into_iter().enumerate() { + tracing_write( + memory, + RV32_MEMORY_AS, + record.rd_val + (i * WRITE_SIZE) as u32, + block, + &mut record.writes_aux[i].prev_timestamp, + &mut record.writes_aux[i].prev_data, + ); + } } } -pub(super) fn vec_heap_two_reads_generate_trace_row_impl< - F: PrimeField32, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, ->( - row_slice: &mut [F], - read_record: &Rv32VecHeapTwoReadsReadRecord, - write_record: &Rv32VecHeapTwoReadsWriteRecord, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - address_bits: usize, - memory: &OfflineMemory, -) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32VecHeapTwoReadsAdapterCols< - F, +impl< + F: PrimeField32, + CTX, + const BLOCKS_PER_READ1: usize, + const BLOCKS_PER_READ2: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceFiller + for Rv32VecHeapTwoReadsAdapterStep< BLOCKS_PER_READ1, BLOCKS_PER_READ2, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - > = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); + > +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32VecHeapTwoReadsAdapterRecord< + BLOCKS_PER_READ1, + BLOCKS_PER_READ2, + BLOCKS_PER_WRITE, + WRITE_SIZE, + > = unsafe { get_record_from_slice(&mut adapter_row, ()) }; - let rd = memory.record_by_id(read_record.rd); - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); + let cols: &mut Rv32VecHeapTwoReadsAdapterCols< + F, + BLOCKS_PER_READ1, + BLOCKS_PER_READ2, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); - row_slice.rd_ptr = rd.pointer; - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - row_slice.rd_val.copy_from_slice(rd.data_slice()); - row_slice.rs1_val.copy_from_slice(rs1.data_slice()); - row_slice.rs2_val.copy_from_slice(rs2.data_slice()); + const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + self.bitwise_lookup_chip.request_range( + (record.rs1_val >> MSL_SHIFT) << limb_shift_bits, + (record.rs2_val >> MSL_SHIFT) << limb_shift_bits, + ); + self.bitwise_lookup_chip.request_range( + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + (record.rd_val >> MSL_SHIFT) << limb_shift_bits, + ); - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.rs1_read_aux); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.rs2_read_aux); - aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); + let mut timestamp = record.from_timestamp + + 2 + + (BLOCKS_PER_READ1 + BLOCKS_PER_READ2 + BLOCKS_PER_WRITE) as u32; + let mut timestamp_mm = || { + timestamp -= 1; + timestamp + }; - for (i, r) in read_record.reads1.iter().enumerate() { - let record = memory.record_by_id(*r); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads1_aux[i]); - } + // Writing everything in reverse order + cols.writes_aux + .iter_mut() + .rev() + .zip(record.writes_aux.iter().rev()) + .for_each(|(col, record)| { + col.set_prev_data(record.prev_data.map(F::from_canonical_u8)); + mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut()); + }); + + cols.reads2_aux + .iter_mut() + .rev() + .zip(record.reads2_aux.iter().rev()) + .for_each(|(col, record)| { + mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut()); + }); + + cols.reads1_aux + .iter_mut() + .rev() + .zip(record.reads1_aux.iter().rev()) + .for_each(|(col, record)| { + mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut()); + }); + + mem_helper.fill( + record.rd_read_aux.prev_timestamp, + timestamp_mm(), + cols.rd_read_aux.as_mut(), + ); + mem_helper.fill( + record.rs2_read_aux.prev_timestamp, + timestamp_mm(), + cols.rs2_read_aux.as_mut(), + ); + mem_helper.fill( + record.rs1_read_aux.prev_timestamp, + timestamp_mm(), + cols.rs1_read_aux.as_mut(), + ); - for (i, r) in read_record.reads2.iter().enumerate() { - let record = memory.record_by_id(*r); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads2_aux[i]); - } + cols.rd_val = record.rd_val.to_le_bytes().map(F::from_canonical_u8); + cols.rs2_val = record.rs2_val.to_le_bytes().map(F::from_canonical_u8); + cols.rs1_val = record.rs1_val.to_le_bytes().map(F::from_canonical_u8); + cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); + cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); - for (i, w) in write_record.writes.iter().enumerate() { - let record = memory.record_by_id(*w); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); - } - // Range checks: - let need_range_check = [ - &read_record.rs1, - &read_record.rs2, - &read_record.rd, - &read_record.rd, - ] - .map(|record| { - memory - .record_by_id(*record) - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }); - debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits; - for pair in need_range_check.chunks_exact(2) { - bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + cols.from_state.timestamp = F::from_canonical_u32(timestamp); + cols.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 8b20385104..9f6bbb6824 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -21,15 +21,16 @@ derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true eyre.workspace = true + # for div_rem: num-bigint.workspace = true num-integer.workspace = true serde = { workspace = true, features = ["derive", "std"] } -serde-big-array.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } +test-case.workspace = true [features] default = ["parallel", "jemalloc"] diff --git a/extensions/rv32im/circuit/src/adapters/alu.rs b/extensions/rv32im/circuit/src/adapters/alu.rs index b61e2a224a..188e6ec966 100644 --- a/extensions/rv32im/circuit/src/adapters/alu.rs +++ b/extensions/rv32im/circuit/src/adapters/alu.rs @@ -1,25 +1,23 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -32,60 +30,10 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; - -use super::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; - -/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e. -/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c -/// is an immediate). -pub struct Rv32BaseAluAdapterChip { - pub air: Rv32BaseAluAdapterAir, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} -impl Rv32BaseAluAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - Self { - air: Rv32BaseAluAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32BaseAluReadRecord { - /// Read register value from address space d=1 - pub rs1: RecordId, - /// Either - /// - read rs2 register value or - /// - if `rs2_is_imm` is true, this is None - pub rs2: Option, - /// immediate value of rs2 or 0 - pub rs2_imm: F, -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32BaseAluWriteRecord { - pub from_state: ExecutionState, - /// Write to destination register - pub rd: (RecordId, [F; 4]), -} +use super::{ + tracing_read, tracing_read_imm, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; #[repr(C)] #[derive(AlignedBorrow)] @@ -101,7 +49,9 @@ pub struct Rv32BaseAluAdapterCols { pub writes_aux: MemoryWriteAuxCols, } -#[allow(dead_code)] +/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e. +/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c +/// is an immediate). #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32BaseAluAdapterAir { pub(super) execution_bridge: ExecutionBridge, @@ -213,129 +163,166 @@ impl VmAdapterAir for Rv32BaseAluAdapterAir { } } -impl VmAdapterChip for Rv32BaseAluAdapterChip { - type ReadRecord = Rv32BaseAluReadRecord; - type WriteRecord = Rv32BaseAluWriteRecord; - type Air = Rv32BaseAluAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; +#[derive(derive_new::new)] +pub struct Rv32BaseAluAdapterStep { + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + +// Intermediate type that should not be copied or cloned and should be directly written to +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32BaseAluAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rd_ptr: u32, + pub rs1_ptr: u32, + /// Pointer if rs2 was a read, immediate value otherwise + pub rs2: u32, + /// 1 if rs2 was a read, 0 if an immediate + pub rs2_as: u8, + + pub reads_aux: [MemoryReadAuxRecord; 2], + pub writes_aux: MemoryWriteBytesAuxRecord, +} - fn preprocess( - &mut self, - memory: &mut MemoryController, +impl AdapterTraceStep + for Rv32BaseAluAdapterStep +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + type RecordMut<'a> = &'a mut Rv32BaseAluAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut &mut Rv32BaseAluAdapterRecord) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + // @dev cannot get rid of double &mut due to trait + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; + record: &mut &mut Rv32BaseAluAdapterRecord, + ) -> Self::ReadData { + let &Instruction { b, c, d, e, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert!( - e.as_canonical_u32() == RV32_IMM_AS || e.as_canonical_u32() == RV32_REGISTER_AS + e.as_canonical_u32() == RV32_REGISTER_AS || e.as_canonical_u32() == RV32_IMM_AS ); - let rs1 = memory.read::(d, b); - let (rs2, rs2_data, rs2_imm) = if e.is_zero() { - let c_u32 = c.as_canonical_u32(); - debug_assert_eq!(c_u32 >> 24, 0); - memory.increment_timestamp(); - ( - None, - [ - c_u32 as u8, - (c_u32 >> 8) as u8, - (c_u32 >> 16) as u8, - (c_u32 >> 16) as u8, - ] - .map(F::from_canonical_u8), - c, + record.rs1_ptr = b.as_canonical_u32(); + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + record.rs1_ptr, + &mut record.reads_aux[0].prev_timestamp, + ); + + let rs2 = if e.as_canonical_u32() == RV32_REGISTER_AS { + record.rs2_as = RV32_REGISTER_AS as u8; + record.rs2 = c.as_canonical_u32(); + + tracing_read( + memory, + RV32_REGISTER_AS, + record.rs2, + &mut record.reads_aux[1].prev_timestamp, ) } else { - let rs2_read = memory.read::(e, c); - (Some(rs2_read.0), rs2_read.1, F::ZERO) + record.rs2_as = RV32_IMM_AS as u8; + + tracing_read_imm(memory, c.as_canonical_u32(), &mut record.rs2) }; - Ok(( - [rs1.1, rs2_data], - Self::ReadRecord { - rs1: rs1.0, - rs2, - rs2_imm, - }, - )) + [rs1, rs2] } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = instruction; - let rd = memory.write(*d, *a, output.writes[0]); - - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 3, - "timestamp delta is {}, expected 3", - timestamp_delta - ); + data: Self::WriteData, + record: &mut &mut Rv32BaseAluAdapterRecord, + ) { + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd }, - )) + record.rd_ptr = a.as_canonical_u32(); + tracing_write( + memory, + RV32_REGISTER_AS, + record.rd_ptr, + data[0], + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, + ); } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - let row_slice: &mut Rv32BaseAluAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); - - let rd = memory.record_by_id(write_record.rd.0); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.rd_ptr = rd.pointer; - - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = read_record.rs2.map(|rs2| memory.record_by_id(rs2)); - row_slice.rs1_ptr = rs1.pointer; - - if let Some(rs2) = rs2 { - row_slice.rs2 = rs2.pointer; - row_slice.rs2_as = rs2.address_space; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); +impl AdapterTraceFiller + for Rv32BaseAluAdapterStep +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + // SAFETY: the following is highly unsafe. We are going to cast `adapter_row` to a record + // buffer, and then do an _overlapping_ write to the `adapter_row` as a row of field + // elements. This requires: + // - Cols struct should be repr(C) and we write in reverse order (to ensure non-overlapping) + // - Do not overwrite any reference in `record` before it has already been used or moved + // - alignment of `F` must be >= alignment of Record (AlignedBytesBorrow will panic + // otherwise) + let record: &Rv32BaseAluAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); + + // We must assign in reverse + // TODO[jpw]: is there a way to not hardcode? + const TIMESTAMP_DELTA: u32 = 2; + let mut timestamp = record.from_timestamp + TIMESTAMP_DELTA; + + adapter_row + .writes_aux + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + timestamp, + adapter_row.writes_aux.as_mut(), + ); + timestamp -= 1; + + if record.rs2_as != 0 { + mem_helper.fill( + record.reads_aux[1].prev_timestamp, + timestamp, + adapter_row.reads_aux[1].as_mut(), + ); } else { - row_slice.rs2 = read_record.rs2_imm; - row_slice.rs2_as = F::ZERO; - let rs2_imm = row_slice.rs2.as_canonical_u32(); + mem_helper.fill_zero(adapter_row.reads_aux[1].as_mut()); + let rs2_imm = record.rs2; let mask = (1 << RV32_CELL_BITS) - 1; self.bitwise_lookup_chip .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask); - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - // row_slice.reads_aux[1] is disabled } - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); - } + timestamp -= 1; + + mem_helper.fill( + record.reads_aux[0].prev_timestamp, + timestamp, + adapter_row.reads_aux[0].as_mut(), + ); - fn air(&self) -> &Self::Air { - &self.air + adapter_row.rs2_as = F::from_canonical_u8(record.rs2_as); + adapter_row.rs2 = F::from_canonical_u32(record.rs2); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr); + adapter_row.from_state.timestamp = F::from_canonical_u32(timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/src/adapters/branch.rs b/extensions/rv32im/circuit/src/adapters/branch.rs index 3e26f37f4c..21d55fcff2 100644 --- a/extensions/rv32im/circuit/src/adapters/branch.rs +++ b/extensions/rv32im/circuit/src/adapters/branch.rs @@ -1,22 +1,17 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -26,48 +21,9 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -/// Reads instructions of the form OP a, b, c, d, e where if(\[a:4\]_d op \[b:4\]_e) pc += c. -/// Operands d and e can only be 1. -#[derive(Debug)] -pub struct Rv32BranchAdapterChip { - pub air: Rv32BranchAdapterAir, - _marker: PhantomData, -} - -impl Rv32BranchAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32BranchAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32BranchReadRecord { - /// Read register value from address space d = 1 - pub rs1: RecordId, - /// Read register value from address space e = 1 - pub rs2: RecordId, -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32BranchWriteRecord { - pub from_state: ExecutionState, -} +use crate::adapters::tracing_read; #[repr(C)] #[derive(AlignedBorrow)] @@ -149,80 +105,102 @@ impl VmAdapterAir for Rv32BranchAdapterAir { } } -impl VmAdapterChip for Rv32BranchAdapterChip { - type ReadRecord = Rv32BranchReadRecord; - type WriteRecord = Rv32BranchWriteRecord; - type Air = Rv32BranchAdapterAir; - type Interface = BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32BranchAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub reads_aux: [MemoryReadAuxRecord; 2], +} - fn preprocess( - &mut self, - memory: &mut MemoryController, +/// Reads instructions of the form OP a, b, c, d, e where if(\[a:4\]_d op \[b:4\]_e) pc += c. +/// Operands d and e can only be 1. +#[derive(derive_new::new)] +pub struct Rv32BranchAdapterStep; + +impl AdapterTraceStep for Rv32BranchAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = (); + type RecordMut<'a> = &'a mut Rv32BranchAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut &mut Rv32BranchAdapterRecord) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, d, e, .. } = *instruction; + record: &mut &mut Rv32BranchAdapterRecord, + ) -> Self::ReadData { + let &Instruction { a, b, d, e, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_REGISTER_AS); - let rs1 = memory.read::(d, a); - let rs2 = memory.read::(e, b); - - Ok(( - [rs1.1, rs2.1], - Self::ReadRecord { - rs1: rs1.0, - rs2: rs2.0, - }, - )) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 2, - "timestamp delta is {}, expected 2", - timestamp_delta + record.rs1_ptr = a.as_canonical_u32(); + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut record.reads_aux[0].prev_timestamp, + ); + record.rs2_ptr = b.as_canonical_u32(); + let rs2 = tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut record.reads_aux[1].prev_timestamp, ); - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state }, - )) + [rs1, rs2] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + _memory: &mut TracingMemory, + _instruction: &Instruction, + _data: Self::WriteData, + _record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32BranchAdapterCols<_> = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); + // This function is intentionally left empty } +} +impl AdapterTraceFiller for Rv32BranchAdapterStep { + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32BranchAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32BranchAdapterCols = adapter_row.borrow_mut(); + + // We must assign in reverse + let timestamp = record.from_timestamp; + + mem_helper.fill( + record.reads_aux[1].prev_timestamp, + timestamp + 1, + adapter_row.reads_aux[1].as_mut(), + ); + + mem_helper.fill( + record.reads_aux[0].prev_timestamp, + timestamp, + adapter_row.reads_aux[0].as_mut(), + ); - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + adapter_row.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); } } diff --git a/extensions/rv32im/circuit/src/adapters/jalr.rs b/extensions/rv32im/circuit/src/adapters/jalr.rs index f7dbf623b8..58a83c3794 100644 --- a/extensions/rv32im/circuit/src/adapters/jalr.rs +++ b/extensions/rv32im/circuit/src/adapters/jalr.rs @@ -1,23 +1,20 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, Result, SignedImmInstruction, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, SignedImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -27,44 +24,9 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd) -#[derive(Debug)] -pub struct Rv32JalrAdapterChip { - pub air: Rv32JalrAdapterAir, - _marker: PhantomData, -} - -impl Rv32JalrAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32JalrAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32JalrReadRecord { - pub rs1: RecordId, -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32JalrWriteRecord { - pub from_state: ExecutionState, - pub rd_id: Option, -} +use crate::adapters::{tracing_read, tracing_write}; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -179,84 +141,120 @@ impl VmAdapterAir for Rv32JalrAdapterAir { } } -impl VmAdapterChip for Rv32JalrAdapterChip { - type ReadRecord = Rv32JalrReadRecord; - type WriteRecord = Rv32JalrWriteRecord; - type Air = Rv32JalrAdapterAir; - type Interface = BasicAdapterInterface< - F, - SignedImmInstruction, - 1, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, d, .. } = *instruction; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32JalrAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs1_ptr: u32, + // Will use u32::MAX to indicate no write + pub rd_ptr: u32, - let rs1 = memory.read::(d, b); + pub reads_aux: MemoryReadAuxRecord, + pub writes_aux: MemoryWriteBytesAuxRecord, +} - Ok(([rs1.1], Rv32JalrReadRecord { rs1: rs1.0 })) +// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd) +#[derive(derive_new::new)] +pub struct Rv32JalrAdapterStep; + +impl AdapterTraceStep for Rv32JalrAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [u8; RV32_REGISTER_NUM_LIMBS]; + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32JalrAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { - a, d, f: enabled, .. - } = *instruction; - let rd_id = if enabled != F::ZERO { - let (record_id, _) = memory.write(d, a, output.writes[0]); - Some(record_id) - } else { - memory.increment_timestamp(); - None - }; + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + record.rs1_ptr = b.as_canonical_u32(); + tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut record.reads_aux.prev_timestamp, + ) } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32JalrAdapterCols<_> = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1); - adapter_cols.rs1_ptr = rs1.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); - if let Some(id) = write_record.rd_id { - let rd = memory.record_by_id(id); - adapter_cols.rd_ptr = rd.pointer; - adapter_cols.needs_write = F::ONE; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols); + let &Instruction { + a, d, f: enabled, .. + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + if enabled.is_one() { + record.rd_ptr = a.as_canonical_u32(); + + tracing_write( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + data, + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, + ); + } else { + record.rd_ptr = u32::MAX; + memory.increment_timestamp(); } } +} +impl AdapterTraceFiller for Rv32JalrAdapterStep { + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32JalrAdapterRecord = unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32JalrAdapterCols = adapter_row.borrow_mut(); + + // We must assign in reverse + adapter_row.needs_write = F::from_bool(record.rd_ptr != u32::MAX); + + if record.rd_ptr != u32::MAX { + adapter_row + .rd_aux_cols + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + record.from_timestamp + 1, + adapter_row.rd_aux_cols.as_mut(), + ); + adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr); + } else { + adapter_row.rd_ptr = F::ZERO; + } - fn air(&self) -> &Self::Air { - &self.air + mem_helper.fill( + record.reads_aux.prev_timestamp, + record.from_timestamp, + adapter_row.rs1_aux_cols.as_mut(), + ); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/src/adapters/loadstore.rs b/extensions/rv32im/circuit/src/adapters/loadstore.rs index b92680a0c7..23b1ba6307 100644 --- a/extensions/rv32im/circuit/src/adapters/loadstore.rs +++ b/extensions/rv32im/circuit/src/adapters/loadstore.rs @@ -1,27 +1,29 @@ use std::{ - array, borrow::{Borrow, BorrowMut}, marker::PhantomData, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface, }, system::{ memory::{ offline_checker::{ - MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols, + MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, + MemoryWriteAuxCols, }, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, - program::ProgramBus, + native_adapter::util::{memory_read_native, timed_write_native}, }, }; use openvm_circuit_primitives::{ utils::{not, select}, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -36,10 +38,9 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use super::{compose, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::RV32_CELL_BITS; +use super::RV32_REGISTER_NUM_LIMBS; +use crate::adapters::{memory_read, timed_write, tracing_read, RV32_CELL_BITS}; /// LoadStore Adapter handles all memory and register operations, so it must be aware /// of the instruction type, specifically whether it is a load or store @@ -64,22 +65,6 @@ pub struct LoadStoreInstruction { pub store_shift_amount: T, } -/// The LoadStoreAdapter separates Runtime and Air AdapterInterfaces. -/// This is necessary because `prev_data` should be owned by the core chip and sent to the adapter, -/// and it must have an AB::Var type in AIR as to satisfy the memory_bridge interface. -/// This is achieved by having different types for reads and writes in Air AdapterInterface. -/// This method ensures that there are no modifications to the global interfaces. -/// -/// Here 2 reads represent read_data and prev_data, -/// The second element of the tuple in Reads is the shift amount needed to be passed to the core -/// chip Getting the intermediate pointer is completely internal to the adapter and shouldn't be a -/// part of the AdapterInterface -pub struct Rv32LoadStoreAdapterRuntimeInterface(PhantomData); -impl VmAdapterInterface for Rv32LoadStoreAdapterRuntimeInterface { - type Reads = ([[T; RV32_REGISTER_NUM_LIMBS]; 2], T); - type Writes = [[T; RV32_REGISTER_NUM_LIMBS]; 1]; - type ProcessedInstruction = (); -} pub struct Rv32LoadStoreAdapterAirInterface(PhantomData); /// Using AB::Var for prev_data and AB::Expr for read_data @@ -92,65 +77,6 @@ impl VmAdapterInterface for Rv32LoadStoreAdapt type ProcessedInstruction = LoadStoreInstruction; } -/// This chip reads rs1 and gets a intermediate memory pointer address with rs1 + imm. -/// In case of Loads, reads from the shifted intermediate pointer and writes to rd. -/// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer. -pub struct Rv32LoadStoreAdapterChip { - pub air: Rv32LoadStoreAdapterAir, - pub range_checker_chip: SharedVariableRangeCheckerChip, - _marker: PhantomData, -} - -impl Rv32LoadStoreAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - pointer_max_bits: usize, - range_checker_chip: SharedVariableRangeCheckerChip, - ) -> Self { - assert!(range_checker_chip.range_max_bits() >= 15); - Self { - air: Rv32LoadStoreAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - range_bus: range_checker_chip.bus(), - pointer_max_bits, - }, - range_checker_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32LoadStoreReadRecord { - pub rs1_record: RecordId, - /// This will be a read from a register in case of Stores and a read from RISC-V memory in case - /// of Loads. - pub read: RecordId, - pub rs1_ptr: F, - pub imm: F, - pub imm_sign: F, - pub mem_as: F, - pub mem_ptr_limbs: [u32; 2], - pub shift_amount: u32, -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32LoadStoreWriteRecord { - /// This will be a write to a register in case of Load and a write to RISC-V memory in case of - /// Stores. For better struct packing, `RecordId(usize::MAX)` is used to indicate that - /// there is no write. - pub write_id: RecordId, - pub from_state: ExecutionState, - pub rd_rs2_ptr: F, -} - #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32LoadStoreAdapterCols { @@ -366,22 +292,64 @@ impl VmAdapterAir for Rv32LoadStoreAdapterAir { } } -impl VmAdapterChip for Rv32LoadStoreAdapterChip { - type ReadRecord = Rv32LoadStoreReadRecord; - type WriteRecord = Rv32LoadStoreWriteRecord; - type Air = Rv32LoadStoreAdapterAir; - type Interface = Rv32LoadStoreAdapterRuntimeInterface; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32LoadStoreAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + pub rs1_ptr: u32, + pub rs1_val: u32, + pub rs1_aux_record: MemoryReadAuxRecord, + + pub rd_rs2_ptr: u32, + pub read_data_aux: MemoryReadAuxRecord, + pub imm: u16, + pub imm_sign: bool, + + pub mem_as: u8, + + pub write_prev_timestamp: u32, +} + +/// This chip reads rs1 and gets a intermediate memory pointer address with rs1 + imm. +/// In case of Loads, reads from the shifted intermediate pointer and writes to rd. +/// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer. +#[derive(derive_new::new)] +pub struct Rv32LoadStoreAdapterStep { + pointer_max_bits: usize, + pub range_checker_chip: SharedVariableRangeCheckerChip, +} + +impl AdapterTraceStep for Rv32LoadStoreAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = ( + ( + [u32; RV32_REGISTER_NUM_LIMBS], + [u8; RV32_REGISTER_NUM_LIMBS], + ), + u8, + ); + type WriteData = [u32; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32LoadStoreAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; + } - #[allow(clippy::type_complexity)] - fn preprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { opcode, a, b, @@ -390,154 +358,188 @@ impl VmAdapterChip for Rv32LoadStoreAdapterChip { e, g, .. - } = *instruction; + } = instruction; + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert!(e.as_canonical_u32() != RV32_IMM_AS); + debug_assert!(e.as_canonical_u32() != RV32_REGISTER_AS); let local_opcode = Rv32LoadStoreOpcode::from_usize( opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let rs1_record = memory.read::(d, b); - let rs1_val = compose(rs1_record.1); - let imm = c.as_canonical_u32(); - let imm_sign = g.as_canonical_u32(); - let imm_extended = imm + imm_sign * 0xffff0000; + record.rs1_ptr = b.as_canonical_u32(); + record.rs1_val = u32::from_le_bytes(tracing_read( + memory, + RV32_REGISTER_AS, + record.rs1_ptr, + &mut record.rs1_aux_record.prev_timestamp, + )); + + record.mem_as = e.as_canonical_u32() as u8; + record.imm = c.as_canonical_u32() as u16; + record.imm_sign = g.is_one(); + let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000; + + let ptr_val = record.rs1_val.wrapping_add(imm_extended); + let shift_amount = ptr_val & 3; + let ptr_val = ptr_val - shift_amount; - let ptr_val = rs1_val.wrapping_add(imm_extended); - let shift_amount = ptr_val % 4; assert!( - ptr_val < (1 << self.air.pointer_max_bits), - "ptr_val: {ptr_val} = rs1_val: {rs1_val} + imm_extended: {imm_extended} >= 2 ** {}", - self.air.pointer_max_bits + ptr_val < (1 << self.pointer_max_bits), + "ptr_val: {ptr_val} = rs1_val: {} + imm_extended: {imm_extended} >= 2 ** {}", + record.rs1_val, + self.pointer_max_bits ); - let mem_ptr_limbs = array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff)); - - let ptr_val = ptr_val - shift_amount; - let read_record = match local_opcode { - LOADW | LOADB | LOADH | LOADBU | LOADHU => { - memory.read::(e, F::from_canonical_u32(ptr_val)) - } - STOREW | STOREH | STOREB => memory.read::(d, a), + let read_data = match local_opcode { + LOADW | LOADB | LOADH | LOADBU | LOADHU => tracing_read( + memory, + e.as_canonical_u32(), + ptr_val, + &mut record.read_data_aux.prev_timestamp, + ), + STOREW | STOREH | STOREB => tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut record.read_data_aux.prev_timestamp, + ), }; // We need to keep values of some cells to keep them unchanged when writing to those cells let prev_data = match local_opcode { - STOREW | STOREH | STOREB => array::from_fn(|i| { - memory.unsafe_read_cell(e, F::from_canonical_usize(ptr_val as usize + i)) - }), + STOREW | STOREH | STOREB => { + if e.as_canonical_u32() == 4 { + memory_read_native(memory.data(), ptr_val).map(|x: F| x.as_canonical_u32()) + } else { + memory_read(memory.data(), e.as_canonical_u32(), ptr_val).map(u32::from) + } + } LOADW | LOADB | LOADH | LOADBU | LOADHU => { - array::from_fn(|i| memory.unsafe_read_cell(d, a + F::from_canonical_usize(i))) + memory_read(memory.data(), d.as_canonical_u32(), a.as_canonical_u32()) + .map(u32::from) } }; - Ok(( - ( - [prev_data, read_record.1], - F::from_canonical_u32(shift_amount), - ), - Self::ReadRecord { - rs1_record: rs1_record.0, - rs1_ptr: b, - read: read_record.0, - imm: c, - imm_sign: g, - shift_amount, - mem_ptr_limbs, - mem_as: e, - }, - )) + ((prev_data, read_data), shift_amount as u8) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, + ) { + let &Instruction { opcode, a, d, e, f: enabled, .. - } = *instruction; + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert!(e.as_canonical_u32() != RV32_IMM_AS); + debug_assert!(e.as_canonical_u32() != RV32_REGISTER_AS); let local_opcode = Rv32LoadStoreOpcode::from_usize( opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let write_id = if enabled != F::ZERO { - let (record_id, _) = match local_opcode { + if enabled != F::ZERO { + record.rd_rs2_ptr = a.as_canonical_u32(); + + record.write_prev_timestamp = match local_opcode { STOREW | STOREH | STOREB => { - let ptr = read_record.mem_ptr_limbs[0] - + read_record.mem_ptr_limbs[1] * (1 << (RV32_CELL_BITS * 2)); - memory.write(e, F::from_canonical_u32(ptr & 0xfffffffc), output.writes[0]) + let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000; + let ptr = record.rs1_val.wrapping_add(imm_extended) & !3; + + if record.mem_as == 4 { + timed_write_native(memory, ptr, data.map(F::from_canonical_u32)).0 + } else { + timed_write(memory, record.mem_as as u32, ptr, data.map(|x| x as u8)).0 + } + } + LOADW | LOADB | LOADH | LOADBU | LOADHU => { + timed_write( + memory, + RV32_REGISTER_AS, + record.rd_rs2_ptr, + data.map(|x| x as u8), + ) + .0 } - LOADW | LOADB | LOADH | LOADBU | LOADHU => memory.write(d, a, output.writes[0]), }; - record_id } else { + record.rd_rs2_ptr = u32::MAX; memory.increment_timestamp(); - // RecordId will never get to usize::MAX, so it can be used as a flag for no write - RecordId(usize::MAX) }; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - write_id, - rd_rs2_ptr: a, - }, - )) } +} - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - self.range_checker_chip.add_count( - (read_record.mem_ptr_limbs[0] - read_record.shift_amount) / 4, - RV32_CELL_BITS * 2 - 2, +impl AdapterTraceFiller for Rv32LoadStoreAdapterStep { + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + debug_assert!(self.range_checker_chip.range_max_bits() >= 15); + + let record: &Rv32LoadStoreAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); + + let needs_write = record.rd_rs2_ptr != u32::MAX; + // Writing in reverse order + adapter_row.needs_write = F::from_bool(needs_write); + + if needs_write { + mem_helper.fill( + record.write_prev_timestamp, + record.from_timestamp + 2, + &mut adapter_row.write_base_aux, + ); + } else { + mem_helper.fill_zero(&mut adapter_row.write_base_aux); + } + + adapter_row.mem_as = F::from_canonical_u8(record.mem_as); + let ptr = record + .rs1_val + .wrapping_add(record.imm as u32 + record.imm_sign as u32 * 0xffff0000); + + let ptr_limbs = [ptr & 0xffff, ptr >> 16]; + self.range_checker_chip + .add_count(ptr_limbs[0] >> 2, RV32_CELL_BITS * 2 - 2); + self.range_checker_chip + .add_count(ptr_limbs[1], self.pointer_max_bits - 16); + adapter_row.mem_ptr_limbs = ptr_limbs.map(F::from_canonical_u32); + + adapter_row.imm_sign = F::from_bool(record.imm_sign); + adapter_row.imm = F::from_canonical_u16(record.imm); + + mem_helper.fill( + record.read_data_aux.prev_timestamp, + record.from_timestamp + 1, + adapter_row.read_data_aux.as_mut(), ); - self.range_checker_chip.add_count( - read_record.mem_ptr_limbs[1], - self.air.pointer_max_bits - RV32_CELL_BITS * 2, + adapter_row.rd_rs2_ptr = if record.rd_rs2_ptr != u32::MAX { + F::from_canonical_u32(record.rd_rs2_ptr) + } else { + F::ZERO + }; + + mem_helper.fill( + record.rs1_aux_record.prev_timestamp, + record.from_timestamp, + adapter_row.rs1_aux_cols.as_mut(), ); - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32LoadStoreAdapterCols<_> = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1_record); - adapter_cols.rs1_data.copy_from_slice(rs1.data_slice()); - aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); - adapter_cols.rs1_ptr = read_record.rs1_ptr; - adapter_cols.rd_rs2_ptr = write_record.rd_rs2_ptr; - let read = memory.record_by_id(read_record.read); - aux_cols_factory.generate_read_aux(read, &mut adapter_cols.read_data_aux); - adapter_cols.imm = read_record.imm; - adapter_cols.imm_sign = read_record.imm_sign; - adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs.map(F::from_canonical_u32); - adapter_cols.mem_as = read_record.mem_as; - if write_record.write_id.0 != usize::MAX { - let write = memory.record_by_id(write_record.write_id); - aux_cols_factory.generate_base_aux(write, &mut adapter_cols.write_base_aux); - adapter_cols.needs_write = F::ONE; - } - } + adapter_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_canonical_u8); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/src/adapters/mod.rs b/extensions/rv32im/circuit/src/adapters/mod.rs index ab15671b74..ffc49bac06 100644 --- a/extensions/rv32im/circuit/src/adapters/mod.rs +++ b/extensions/rv32im/circuit/src/adapters/mod.rs @@ -1,6 +1,13 @@ use std::ops::Mul; -use openvm_circuit::system::memory::{MemoryController, RecordId}; +use openvm_circuit::{ + arch::{execution_mode::E1ExecutionCtx, VmStateMut}, + system::memory::{ + merkle::public_values::PUBLIC_VALUES_AS, + online::{GuestMemory, TracingMemory}, + }, +}; +use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; mod alu; @@ -46,25 +53,185 @@ pub fn decompose(value: u32) -> [F; RV32_REGISTER_NUM_LIMBS] { }) } -/// Read register value as [RV32_REGISTER_NUM_LIMBS] limbs from memory. -/// Returns the read record and the register value as u32. -/// Does not make any range check calls. -pub fn read_rv32_register( - memory: &mut MemoryController, - address_space: F, - pointer: F, -) -> (RecordId, u32) { - debug_assert_eq!(address_space, F::ONE); - let record = memory.read::(address_space, pointer); - let val = compose(record.1); - (record.0, val) +#[inline(always)] +pub fn imm_to_bytes(imm: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] { + debug_assert_eq!(imm >> 24, 0); + let mut imm_le = imm.to_le_bytes(); + imm_le[3] = imm_le[2]; + imm_le } -/// Peeks at the value of a register without updating the memory state or incrementing the -/// timestamp. -pub fn unsafe_read_rv32_register(memory: &MemoryController, pointer: F) -> u32 { - let data = memory.unsafe_read::(F::ONE, pointer); - compose(data) +#[inline(always)] +pub fn memory_read(memory: &GuestMemory, address_space: u32, ptr: u32) -> [u8; N] { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS, + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +#[inline(always)] +pub fn memory_write( + memory: &mut GuestMemory, + address_space: u32, + ptr: u32, + data: [u8; N], +) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.write::(address_space, ptr, data) } +} + +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:4]_{address_space})` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +pub fn timed_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +#[inline(always)] +pub fn timed_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: [u8; N], +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.write::(address_space, ptr, data) } +} + +/// Reads register value at `reg_ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + prev_timestamp: &mut u32, +) -> [u8; N] +where + F: PrimeField32, +{ + let (t_prev, data) = timed_read(memory, address_space, ptr); + *prev_timestamp = t_prev; + data +} + +#[inline(always)] +pub fn tracing_read_imm( + memory: &mut TracingMemory, + imm: u32, + imm_mut: &mut u32, +) -> [u8; RV32_REGISTER_NUM_LIMBS] +where + F: PrimeField32, +{ + *imm_mut = imm; + debug_assert_eq!(imm >> 24, 0); // highest byte should be zero to prevent overflow + + memory.increment_timestamp(); + + let mut imm_le = imm.to_le_bytes(); + // Important: we set the highest byte equal to the second highest byte, using the assumption + // that imm is at most 24 bits + imm_le[3] = imm_le[2]; + imm_le +} + +/// Writes `reg_ptr, reg_val` into memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: [u8; N], + prev_timestamp: &mut u32, + prev_data: &mut [u8; N], +) where + F: PrimeField32, +{ + let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data); + *prev_timestamp = t_prev; + *prev_data = data_prev; +} + +#[inline(always)] +pub fn memory_read_from_state( + state: &mut VmStateMut, + address_space: u32, + ptr: u32, +) -> [u8; N] +where + Ctx: E1ExecutionCtx, +{ + state.ctx.on_memory_operation(address_space, ptr, N as u32); + + memory_read(state.memory, address_space, ptr) +} + +#[inline(always)] +pub fn memory_write_from_state( + state: &mut VmStateMut, + address_space: u32, + ptr: u32, + data: [u8; N], +) where + Ctx: E1ExecutionCtx, +{ + state.ctx.on_memory_operation(address_space, ptr, N as u32); + + memory_write(state.memory, address_space, ptr, data) +} + +#[inline(always)] +pub fn read_rv32_register_from_state( + state: &mut VmStateMut, + ptr: u32, +) -> u32 +where + Ctx: E1ExecutionCtx, +{ + u32::from_le_bytes(memory_read_from_state(state, RV32_REGISTER_AS, ptr)) +} + +#[inline(always)] +pub fn read_rv32_register(memory: &GuestMemory, ptr: u32) -> u32 { + u32::from_le_bytes(memory_read(memory, RV32_REGISTER_AS, ptr)) } pub fn abstract_compose>( @@ -76,3 +243,8 @@ pub fn abstract_compose>( acc + limb * T::from_canonical_u32(1 << (i * RV32_CELL_BITS)) }) } + +// TEMP[jpw] +pub fn tmp_convert_to_u8s(data: [F; N]) -> [u8; N] { + data.map(|x| x.as_canonical_u32() as u8) +} diff --git a/extensions/rv32im/circuit/src/adapters/mul.rs b/extensions/rv32im/circuit/src/adapters/mul.rs index a82e83acaa..4914e019aa 100644 --- a/extensions/rv32im/circuit/src/adapters/mul.rs +++ b/extensions/rv32im/circuit/src/adapters/mul.rs @@ -1,22 +1,20 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; +use openvm_circuit_primitives::AlignedBytesBorrow; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -26,49 +24,9 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; - -use super::RV32_REGISTER_NUM_LIMBS; - -/// Reads instructions of the form OP a, b, c, d where \[a:4\]_d = \[b:4\]_d op \[c:4\]_d. -/// Operand d can only be 1, and there is no immediate support. -#[derive(Debug)] -pub struct Rv32MultAdapterChip { - pub air: Rv32MultAdapterAir, - _marker: PhantomData, -} - -impl Rv32MultAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32MultAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32MultReadRecord { - /// Reads from operand registers - pub rs1: RecordId, - pub rs2: RecordId, -} -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32MultWriteRecord { - pub from_state: ExecutionState, - /// Write to destination register - pub rd_id: RecordId, -} +use super::{tracing_write, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::tracing_read; #[repr(C)] #[derive(AlignedBorrow)] @@ -81,6 +39,8 @@ pub struct Rv32MultAdapterCols { pub writes_aux: MemoryWriteAuxCols, } +/// Reads instructions of the form OP a, b, c, d where \[a:4\]_d = \[b:4\]_d op \[c:4\]_d. +/// Operand d can only be 1, and there is no immediate support. #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32MultAdapterAir { pub(super) execution_bridge: ExecutionBridge, @@ -167,92 +127,125 @@ impl VmAdapterAir for Rv32MultAdapterAir { } } -impl VmAdapterChip for Rv32MultAdapterChip { - type ReadRecord = Rv32MultReadRecord; - type WriteRecord = Rv32MultWriteRecord; - type Air = Rv32MultAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, .. } = *instruction; +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32MultAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + pub rd_ptr: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, - let rs1 = memory.read::(d, b); - let rs2 = memory.read::(d, c); + pub reads_aux: [MemoryReadAuxRecord; 2], + pub writes_aux: MemoryWriteBytesAuxRecord, +} - Ok(( - [rs1.1, rs2.1], - Self::ReadRecord { - rs1: rs1.0, - rs2: rs2.0, - }, - )) +#[derive(derive_new::new)] +pub struct Rv32MultAdapterStep; + +impl AdapterTraceStep for Rv32MultAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + type RecordMut<'a> = &'a mut Rv32MultAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); - - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 3, - "timestamp delta is {}, expected 3", - timestamp_delta + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + let &Instruction { b, c, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + record.rs1_ptr = b.as_canonical_u32(); + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut record.reads_aux[0].prev_timestamp, + ); + record.rs2_ptr = c.as_canonical_u32(); + let rs2 = tracing_read( + memory, + RV32_REGISTER_AS, + c.as_canonical_u32(), + &mut record.reads_aux[1].prev_timestamp, ); - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + [rs1, rs2] } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32MultAdapterCols<_> = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - let rd = memory.record_by_id(write_record.rd_id); - row_slice.rd_ptr = rd.pointer; - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + record.rd_ptr = a.as_canonical_u32(); + tracing_write( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + data[0], + &mut record.writes_aux.prev_timestamp, + &mut record.writes_aux.prev_data, + ) } +} + +impl AdapterTraceFiller for Rv32MultAdapterStep { + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32MultAdapterRecord = unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32MultAdapterCols = adapter_row.borrow_mut(); + + let timestamp = record.from_timestamp; + + adapter_row + .writes_aux + .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.writes_aux.prev_timestamp, + timestamp + 2, + adapter_row.writes_aux.as_mut(), + ); + + mem_helper.fill( + record.reads_aux[1].prev_timestamp, + timestamp + 1, + adapter_row.reads_aux[1].as_mut(), + ); + + mem_helper.fill( + record.reads_aux[0].prev_timestamp, + timestamp, + adapter_row.reads_aux[0].as_mut(), + ); + + adapter_row.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); + adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr); - fn air(&self) -> &Self::Air { - &self.air + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } diff --git a/extensions/rv32im/circuit/src/adapters/rdwrite.rs b/extensions/rv32im/circuit/src/adapters/rdwrite.rs index abd4d8eb17..e736e39d95 100644 --- a/extensions/rv32im/circuit/src/adapters/rdwrite.rs +++ b/extensions/rv32im/circuit/src/adapters/rdwrite.rs @@ -1,23 +1,17 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow}; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, @@ -27,59 +21,9 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 -#[derive(Debug)] -pub struct Rv32RdWriteAdapterChip { - pub air: Rv32RdWriteAdapterAir, - _marker: PhantomData, -} - -/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 -#[derive(Debug)] -pub struct Rv32CondRdWriteAdapterChip { - /// Do not use the inner air directly, use `air` instead. - inner: Rv32RdWriteAdapterChip, - pub air: Rv32CondRdWriteAdapterAir, -} - -impl Rv32RdWriteAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32RdWriteAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -impl Rv32CondRdWriteAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - let inner = Rv32RdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge); - let air = Rv32CondRdWriteAdapterAir { inner: inner.air }; - Self { inner, air } - } -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32RdWriteWriteRecord { - pub from_state: ExecutionState, - pub rd_id: Option, -} +use crate::adapters::tracing_write; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -92,16 +36,18 @@ pub struct Rv32RdWriteAdapterCols { #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32CondRdWriteAdapterCols { - inner: Rv32RdWriteAdapterCols, + pub inner: Rv32RdWriteAdapterCols, pub needs_write: T, } +/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32RdWriteAdapterAir { pub(super) memory_bridge: MemoryBridge, pub(super) execution_bridge: ExecutionBridge, } +/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32CondRdWriteAdapterAir { inner: Rv32RdWriteAdapterAir, @@ -241,131 +187,176 @@ impl VmAdapterAir for Rv32CondRdWriteAdapterAir { } } -impl VmAdapterChip for Rv32RdWriteAdapterChip { - type ReadRecord = (); - type WriteRecord = Rv32RdWriteWriteRecord; - type Air = Rv32RdWriteAdapterAir; - type Interface = BasicAdapterInterface, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>; - - fn preprocess( - &mut self, - _memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let d = instruction.d; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); +/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Rv32RdWriteAdapterRecord { + pub from_pc: u32, + pub from_timestamp: u32, + + // Will use u32::MAX to indicate no write + pub rd_ptr: u32, + pub rd_aux_record: MemoryWriteBytesAuxRecord, +} - Ok(([], ())) +#[derive(derive_new::new)] +pub struct Rv32RdWriteAdapterStep; + +impl AdapterTraceStep for Rv32RdWriteAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = (); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - rd_id: Some(rd_id), - }, - )) + #[inline(always)] + fn read( + &self, + _memory: &mut TracingMemory, + _instruction: &Instruction, + _record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + // Rv32RdWriteAdapter doesn't read anything } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32RdWriteAdapterCols = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rd = memory.record_by_id(write_record.rd_id.unwrap()); - adapter_cols.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols); + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + record.rd_ptr = a.as_canonical_u32(); + tracing_write( + memory, + RV32_REGISTER_AS, + record.rd_ptr, + data, + &mut record.rd_aux_record.prev_timestamp, + &mut record.rd_aux_record.prev_data, + ); } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterTraceFiller for Rv32RdWriteAdapterStep { + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32RdWriteAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_row: &mut Rv32RdWriteAdapterCols = adapter_row.borrow_mut(); + + adapter_row + .rd_aux_cols + .set_prev_data(record.rd_aux_record.prev_data.map(F::from_canonical_u8)); + mem_helper.fill( + record.rd_aux_record.prev_timestamp, + record.from_timestamp, + adapter_row.rd_aux_cols.as_mut(), + ); + adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr); + adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc); } } -impl VmAdapterChip for Rv32CondRdWriteAdapterChip { - type ReadRecord = (); - type WriteRecord = Rv32RdWriteWriteRecord; - type Air = Rv32CondRdWriteAdapterAir; - type Interface = BasicAdapterInterface, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>; +/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 +#[derive(derive_new::new)] +pub struct Rv32CondRdWriteAdapterStep { + inner: Rv32RdWriteAdapterStep, +} - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - self.inner.preprocess(memory, instruction) +impl AdapterTraceStep for Rv32CondRdWriteAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = (); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) { + record.from_pc = pc; + record.from_timestamp = memory.timestamp; } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let rd_id = if instruction.f != F::ZERO { - let (rd_id, _) = memory.write(d, a, output.writes[0]); - Some(rd_id) - } else { - memory.increment_timestamp(); - None - }; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + record: &mut Self::RecordMut<'_>, + ) -> Self::ReadData { + >::read( + &self.inner, + memory, + instruction, + record, + ) } - fn generate_trace_row( + #[inline(always)] + fn write( &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + memory: &mut TracingMemory, + instruction: &Instruction, + data: Self::WriteData, + record: &mut Self::RecordMut<'_>, ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32CondRdWriteAdapterCols = row_slice.borrow_mut(); - adapter_cols.inner.from_state = write_record.from_state.map(F::from_canonical_u32); - if let Some(rd_id) = write_record.rd_id { - let rd = memory.record_by_id(rd_id); - adapter_cols.inner.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.inner.rd_aux_cols); - adapter_cols.needs_write = F::ONE; + let Instruction { f: enabled, .. } = instruction; + + if enabled.is_one() { + >::write( + &self.inner, + memory, + instruction, + data, + record, + ); + } else { + memory.increment_timestamp(); + record.rd_ptr = u32::MAX; } } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterTraceFiller for Rv32CondRdWriteAdapterStep { + #[inline(always)] + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, mut adapter_row: &mut [F]) { + let record: &Rv32RdWriteAdapterRecord = + unsafe { get_record_from_slice(&mut adapter_row, ()) }; + let adapter_cols: &mut Rv32CondRdWriteAdapterCols = adapter_row.borrow_mut(); + + adapter_cols.needs_write = F::from_bool(record.rd_ptr != u32::MAX); + + if record.rd_ptr != u32::MAX { + unsafe { + >::fill_trace_row( + &self.inner, + mem_helper, + adapter_row + .split_at_mut_unchecked(size_of::>()) + .0, + ) + }; + } else { + adapter_cols.inner.rd_ptr = F::ZERO; + mem_helper.fill_zero(adapter_cols.inner.rd_aux_cols.as_mut()); + adapter_cols.inner.from_state.timestamp = F::from_canonical_u32(record.from_timestamp); + adapter_cols.inner.from_state.pc = F::from_canonical_u32(record.from_pc); + } } } diff --git a/extensions/rv32im/circuit/src/auipc/core.rs b/extensions/rv32im/circuit/src/auipc/core.rs index 8ec9e274f6..ec97371374 100644 --- a/extensions/rv32im/circuit/src/auipc/core.rs +++ b/extensions/rv32im/circuit/src/auipc/core.rs @@ -1,17 +1,30 @@ use std::{ - array, + array::{self, from_fn}, borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + ImmInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_REGISTER_AS, + LocalOpcode, +}; use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -19,12 +32,9 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; - #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32AuipcCoreCols { @@ -36,7 +46,7 @@ pub struct Rv32AuipcCoreCols { pub rd_data: [T; RV32_REGISTER_NUM_LIMBS], } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy, derive_new::new)] pub struct Rv32AuipcCoreAir { pub bus: BitwiseOperationLookupBus, } @@ -186,116 +196,202 @@ where } #[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32AuipcCoreRecord { - pub imm_limbs: [F; RV32_REGISTER_NUM_LIMBS - 1], - pub pc_limbs: [F; RV32_REGISTER_NUM_LIMBS - 2], - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS], +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Rv32AuipcCoreRecord { + pub from_pc: u32, + pub imm: u32, } -pub struct Rv32AuipcCoreChip { - pub air: Rv32AuipcCoreAir, +#[derive(derive_new::new)] +pub struct Rv32AuipcStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl Rv32AuipcCoreChip { - pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip) -> Self { - Self { - air: Rv32AuipcCoreAir { - bus: bitwise_lookup_chip.bus(), - }, - bitwise_lookup_chip, - } - } -} - -impl> VmCoreChip for Rv32AuipcCoreChip +impl TraceStep for Rv32AuipcStep where - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + AdapterTraceStep, { - type Record = Rv32AuipcCoreRecord; - type Air = Rv32AuipcCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut Rv32AuipcCoreRecord); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", AUIPC) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - _reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = Rv32AuipcOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32AuipcOpcode::CLASS_OFFSET), - ); - let imm = instruction.c.as_canonical_u32(); - let rd_data = run_auipc(local_opcode, from_pc, imm); - let rd_data_field = rd_data.map(F::from_canonical_u32); - - let output = AdapterRuntimeContext::without_pc([rd_data_field]); - - let imm_limbs = array::from_fn(|i| (imm >> (i * RV32_CELL_BITS)) & RV32_LIMB_MAX); - let pc_limbs: [u32; RV32_REGISTER_NUM_LIMBS] = - array::from_fn(|i| (from_pc >> (i * RV32_CELL_BITS)) & RV32_LIMB_MAX); + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); - for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) { + A::start(*state.pc, state.memory, &mut adapter_record); + + core_record.from_pc = *state.pc; + core_record.imm = instruction.c.as_canonical_u32(); + + let rd = run_auipc(*state.pc, core_record.imm); + + self.adapter + .write(state.memory, instruction, rd, &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for Rv32AuipcStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &Rv32AuipcCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + + let core_row: &mut Rv32AuipcCoreCols = core_row.borrow_mut(); + + let imm_limbs = record.imm.to_le_bytes(); + let pc_limbs = record.from_pc.to_le_bytes(); + let rd_data = run_auipc(record.from_pc, record.imm); + debug_assert_eq!(imm_limbs[3], 0); + + // range checks: + // hardcoding for performance: first 3 limbs of imm_limbs, last 3 limbs of pc_limbs where + // most significant limb of pc_limbs is shifted up + self.bitwise_lookup_chip + .request_range(imm_limbs[0] as u32, imm_limbs[1] as u32); + self.bitwise_lookup_chip + .request_range(imm_limbs[2] as u32, pc_limbs[1] as u32); + let msl_shift = RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - PC_BITS; + self.bitwise_lookup_chip + .request_range(pc_limbs[2] as u32, (pc_limbs[3] as u32) << msl_shift); + for pair in rd_data.chunks_exact(2) { self.bitwise_lookup_chip - .request_range(rd_data[i * 2], rd_data[i * 2 + 1]); + .request_range(pair[0] as u32, pair[1] as u32); } + // Writing in reverse order + core_row.rd_data = rd_data.map(F::from_canonical_u8); + // only the middle 2 limbs: + core_row.pc_limbs = from_fn(|i| F::from_canonical_u8(pc_limbs[i + 1])); + core_row.imm_limbs = from_fn(|i| F::from_canonical_u8(imm_limbs[i])); - let mut need_range_check: Vec = Vec::new(); - for limb in imm_limbs { - need_range_check.push(limb); - } + core_row.is_valid = F::ONE; + } +} - for (i, limb) in pc_limbs.iter().enumerate().skip(1) { - if i == pc_limbs.len() - 1 { - need_range_check.push((*limb) << (pc_limbs.len() * RV32_CELL_BITS - PC_BITS)); - } else { - need_range_check.push(*limb); - } - } +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct AuiPcPreCompute { + imm: u32, + a: u8, +} - for pair in need_range_check.chunks(2) { - self.bitwise_lookup_chip.request_range(pair[0], pair[1]); - } +impl StepExecutorE1 for Rv32AuipcStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } - Ok(( - output, - Self::Record { - imm_limbs: imm_limbs.map(F::from_canonical_u32), - pc_limbs: array::from_fn(|i| F::from_canonical_u32(pc_limbs[i + 1])), - rd_data: rd_data.map(F::from_canonical_u32), - }, - )) + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let data: &mut AuiPcPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(|pre_compute, vm_state| { + let pre_compute: &AuiPcPreCompute = pre_compute.borrow(); + unsafe { + execute_e1_impl(pre_compute, vm_state); + } + }) } +} + +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &AuiPcPreCompute, + vm_state: &mut VmSegmentState, +) { + let rd = run_auipc(vm_state.pc, pre_compute.imm); + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32AuipcOpcode::from_usize(opcode - Rv32AuipcOpcode::CLASS_OFFSET) - ) +impl StepExecutorE2 for Rv32AuipcStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut Rv32AuipcCoreCols = row_slice.borrow_mut(); - core_cols.imm_limbs = record.imm_limbs; - core_cols.pc_limbs = record.pc_limbs; - core_cols.rd_data = record.rd_data; - core_cols.is_valid = F::ONE; + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(|pre_compute, vm_state| { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + unsafe { + execute_e1_impl(&pre_compute.data, vm_state); + } + }) } +} - fn air(&self) -> &Self::Air { - &self.air +impl Rv32AuipcStep { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut AuiPcPreCompute, + ) -> Result<()> { + let Instruction { a, c: imm, d, .. } = inst; + if d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(InvalidInstruction(pc)); + } + let imm = imm.as_canonical_u32(); + let data: &mut AuiPcPreCompute = data.borrow_mut(); + *data = AuiPcPreCompute { + imm, + a: a.as_canonical_u32() as u8, + }; + Ok(()) } } // returns rd_data -pub(super) fn run_auipc( - _opcode: Rv32AuipcOpcode, - pc: u32, - imm: u32, -) -> [u32; RV32_REGISTER_NUM_LIMBS] { +#[inline(always)] +pub(super) fn run_auipc(pc: u32, imm: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] { let rd = pc.wrapping_add(imm << RV32_CELL_BITS); - array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & RV32_LIMB_MAX) + rd.to_le_bytes() } diff --git a/extensions/rv32im/circuit/src/auipc/mod.rs b/extensions/rv32im/circuit/src/auipc/mod.rs index 6e2234bfbd..f18fa3ad43 100644 --- a/extensions/rv32im/circuit/src/auipc/mod.rs +++ b/extensions/rv32im/circuit/src/auipc/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use crate::adapters::Rv32RdWriteAdapterChip; +use crate::adapters::{Rv32RdWriteAdapterAir, Rv32RdWriteAdapterStep}; mod core; pub use core::*; @@ -8,4 +8,7 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32AuipcChip = VmChipWrapper, Rv32AuipcCoreChip>; +pub type Rv32AuipcAir = VmAirWrapper; +pub type Rv32AuipcStepWithAdapter = Rv32AuipcStep; +pub type Rv32AuipcChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/auipc/tests.rs b/extensions/rv32im/circuit/src/auipc/tests.rs index 2c8a399198..af1a47711c 100644 --- a/extensions/rv32im/circuit/src/auipc/tests.rs +++ b/extensions/rv32im/circuit/src/auipc/tests.rs @@ -1,34 +1,65 @@ use std::borrow::BorrowMut; -use openvm_circuit::arch::{testing::VmChipTestBuilder, VmAdapterChip}; +use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + DenseRecordArena, EmptyAdapterCoreLayout, InstructionExecutor, NewVmChipWrapper, VmAirWrapper, +}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *}; use openvm_stark_backend::{ - interaction::BusIndex, p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{run_auipc, Rv32AuipcChip, Rv32AuipcCoreChip, Rv32AuipcCoreCols}; -use crate::adapters::{Rv32RdWriteAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::{run_auipc, Rv32AuipcChip, Rv32AuipcCoreAir, Rv32AuipcCoreCols, Rv32AuipcStep}; +use crate::{ + adapters::{ + Rv32RdWriteAdapterAir, Rv32RdWriteAdapterRecord, Rv32RdWriteAdapterStep, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, + }, + test_utils::get_verification_error, + Rv32AuipcAir, Rv32AuipcCoreRecord, Rv32AuipcStepWithAdapter, +}; const IMM_BITS: usize = 24; -const BITWISE_OP_LOOKUP_BUS: BusIndex = 9; - +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; -fn set_and_execute( +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32AuipcChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Rv32AuipcChip::::new( + VmAirWrapper::new( + Rv32RdWriteAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + Rv32AuipcCoreAir::new(bitwise_bus), + ), + Rv32AuipcStep::new(Rv32RdWriteAdapterStep::new(), bitwise_chip.clone()), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + +fn set_and_execute>( tester: &mut VmChipTestBuilder, - chip: &mut Rv32AuipcChip, + chip: &mut E, rng: &mut StdRng, opcode: Rv32AuipcOpcode, imm: Option, @@ -43,10 +74,8 @@ fn set_and_execute( initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))), ); let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); - - let rd_data = run_auipc(opcode, initial_pc, imm as u32); - - assert_eq!(rd_data.map(F::from_canonical_u32), tester.read::<4>(1, a)); + let rd_data = run_auipc(initial_pc, imm as u32); + assert_eq!(rd_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } /////////////////////////////////////////////////////////////////////////////////////// @@ -59,17 +88,8 @@ fn set_and_execute( #[test] fn rand_auipc_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32AuipcCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32AuipcChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&tester); let num_tests: usize = 100; for _ in 0..num_tests { @@ -84,32 +104,26 @@ fn rand_auipc_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct AuipcPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub imm_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, + pub pc_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 2]>, +} + fn run_negative_auipc_test( opcode: Rv32AuipcOpcode, initial_imm: Option, initial_pc: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - imm_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, - pc_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 2]>, - expected_error: VerificationError, + prank_vals: AuipcPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let core = Rv32AuipcCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32AuipcChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&tester); set_and_execute( &mut tester, @@ -120,39 +134,32 @@ fn run_negative_auipc_test( initial_pc, ); - let tester = tester.build(); - - let auipc_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let auipc_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = auipc_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (_, core_row) = trace_row.split_at_mut(adapter_width); - let core_cols: &mut Rv32AuipcCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(data) = imm_limbs { + if let Some(data) = prank_vals.imm_limbs { core_cols.imm_limbs = data.map(F::from_canonical_u32); } - - if let Some(data) = pc_limbs { + if let Some(data) = prank_vals.pc_limbs { core_cols.pc_limbs = data.map(F::from_canonical_u32); } - *auipc_trace = RowMajorMatrix::new(trace_row, auipc_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; + disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) + .build() + .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -161,47 +168,53 @@ fn invalid_limb_negative_tests() { AUIPC, Some(9722891), None, - None, - Some([107, 46, 81]), - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + imm_limbs: Some([107, 46, 81]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, Some(0), Some(2110400), - Some([194, 51, 32, 240]), - None, - Some([51, 32]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([194, 51, 32, 240]), + pc_limbs: Some([51, 32]), + ..Default::default() + }, + true, ); run_negative_auipc_test( AUIPC, None, None, - None, - None, - Some([206, 166]), - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + pc_limbs: Some([206, 166]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, None, None, - Some([30, 92, 82, 132]), - None, - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + rd_data: Some([30, 92, 82, 132]), + ..Default::default() + }, + false, ); - run_negative_auipc_test( AUIPC, None, Some(876487877), - Some([197, 202, 49, 70]), - Some([166, 243, 17]), - Some([36, 62]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([197, 202, 49, 70]), + imm_limbs: Some([166, 243, 17]), + pc_limbs: Some([36, 62]), + }, + true, ); } @@ -211,37 +224,42 @@ fn overflow_negative_tests() { AUIPC, Some(256264), None, - None, - Some([3592, 219, 3]), - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + imm_limbs: Some([3592, 219, 3]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, None, None, - None, - None, - Some([0, 0]), - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + pc_limbs: Some([0, 0]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, Some(255), None, - None, - Some([F::NEG_ONE.as_canonical_u32(), 1, 0]), - None, - VerificationError::ChallengePhaseError, + AuipcPrankValues { + imm_limbs: Some([F::NEG_ONE.as_canonical_u32(), 1, 0]), + ..Default::default() + }, + true, ); run_negative_auipc_test( AUIPC, Some(0), Some(255), - Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), - Some([0, 0, 0]), - Some([1, 0]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), + imm_limbs: Some([0, 0, 0]), + pc_limbs: Some([1, 0]), + }, + true, ); } @@ -252,32 +270,72 @@ fn overflow_negative_tests() { /////////////////////////////////////////////////////////////////////////////////////// #[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); +fn run_auipc_sanity_test() { + let initial_pc = 234567890; + let imm = 11302451; + let rd_data = run_auipc(initial_pc, imm); + + assert_eq!(rd_data, [210, 107, 113, 186]); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// DENSE TESTS +/// +/// Ensure that the chip works as expected with dense records. +/// We first execute some instructions with a [DenseRecordArena] and transfer the records +/// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +/// all the constraints pass. +/////////////////////////////////////////////////////////////////////////////////////// +type Rv32AuipcChipDense = + NewVmChipWrapper; + +fn create_test_chip_dense(tester: &mut VmChipTestBuilder) -> Rv32AuipcChipDense { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), + let mut chip = Rv32AuipcChipDense::new( + Rv32AuipcAir::new( + Rv32RdWriteAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + Rv32AuipcCoreAir::new(bitwise_bus), + ), + Rv32AuipcStep::new(Rv32RdWriteAdapterStep::new(), bitwise_chip.clone()), + tester.memory_helper(), ); - let inner = Rv32AuipcCoreChip::new(bitwise_chip); - let mut chip = Rv32AuipcChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, AUIPC, None, None); - } + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + chip } #[test] -fn run_auipc_sanity_test() { - let opcode = AUIPC; - let initial_pc = 234567890; - let imm = 11302451; - let rd_data = run_auipc(opcode, initial_pc, imm); +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_chip, bitwise_chip) = create_test_chip(&tester); - assert_eq!(rd_data, [210, 107, 113, 186]); + { + let mut dense_chip = create_test_chip_dense(&mut tester); + + let num_ops: usize = 100; + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut dense_chip, &mut rng, AUIPC, None, None); + } + + type Record<'a> = ( + &'a mut Rv32RdWriteAdapterRecord, + &'a mut Rv32AuipcCoreRecord, + ); + + let mut record_interpreter = dense_chip.arena.get_record_seeker::(); + record_interpreter.transfer_to_matrix_arena( + &mut sparse_chip.arena, + EmptyAdapterCoreLayout::::new(), + ); + } + + let tester = tester + .build() + .load(sparse_chip) + .load(bitwise_chip) + .finalize(); + tester.simple_test().expect("Verification failed"); } diff --git a/extensions/rv32im/circuit/src/base_alu/core.rs b/extensions/rv32im/circuit/src/base_alu/core.rs index a87418cc91..9be0af9261 100644 --- a/extensions/rv32im/circuit/src/base_alu/core.rs +++ b/extensions/rv32im/circuit/src/base_alu/core.rs @@ -1,18 +1,32 @@ use std::{ array, borrow::{Borrow, BorrowMut}, + iter::zip, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + MinimalInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::BaseAluOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,12 +34,12 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; +use crate::adapters::imm_to_bytes; + #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct BaseAluCoreCols { pub a: [T; NUM_LIMBS], pub b: [T; NUM_LIMBS], @@ -38,10 +52,10 @@ pub struct BaseAluCoreCols { pub opcode_and_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct BaseAluCoreAir { pub bus: BitwiseOperationLookupBus, - offset: usize, + pub offset: usize, } impl BaseAir @@ -165,175 +179,395 @@ where } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct BaseAluCoreRecord { - pub opcode: BaseAluOpcode, - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], +#[repr(C, align(4))] +#[derive(AlignedBytesBorrow, Debug)] +pub struct BaseAluCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + // Use u8 instead of usize for better packing + pub local_opcode: u8, } -pub struct BaseAluCoreChip { - pub air: BaseAluCoreAir, +#[derive(derive_new::new)] +pub struct BaseAluStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, } -impl BaseAluCoreChip { - pub fn new( - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: BaseAluCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, - bitwise_lookup_chip, - } - } -} - -impl VmCoreChip - for BaseAluCoreChip +impl TraceStep + for BaseAluStep where F: PrimeField32, - I: VmAdapterInterface, - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, { - type Record = BaseAluCoreRecord; - type Air = BaseAluCoreAir; + /// Instructions that use one trace row per instruction have implicit layout + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut BaseAluCoreRecord); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", BaseAluOpcode::from_usize(opcode - self.offset)) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { let Instruction { opcode, .. } = instruction; - let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let a = run_alu::(local_opcode, &b, &c); + let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); - let output = AdapterRuntimeContext { - to_pc: None, - writes: [a.map(F::from_canonical_u32)].into(), - }; + A::start(*state.pc, state.memory, &mut adapter_record); + + [core_record.b, core_record.c] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let rd = run_alu::(local_opcode, &core_record.b, &core_record.c); + + core_record.local_opcode = local_opcode as u8; + + self.adapter + .write(state.memory, instruction, [rd].into(), &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for BaseAluStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &BaseAluCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut BaseAluCoreCols = core_row.borrow_mut(); + // SAFETY: the following is highly unsafe. We are going to cast `core_row` to a record + // buffer, and then do an _overlapping_ write to the `core_row` as a row of field elements. + // This requires: + // - Cols and Record structs should be repr(C) and we write in reverse order (to ensure + // non-overlapping) + // - Do not overwrite any reference in `record` before it has already been used or moved + // - alignment of `F` must be >= alignment of Record (AlignedBytesBorrow will panic + // otherwise) + + let local_opcode = BaseAluOpcode::from_usize(record.local_opcode as usize); + let a = run_alu::(local_opcode, &record.b, &record.c); + // PERF: needless conversion + core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND); + core_row.opcode_or_flag = F::from_bool(local_opcode == BaseAluOpcode::OR); + core_row.opcode_xor_flag = F::from_bool(local_opcode == BaseAluOpcode::XOR); + core_row.opcode_sub_flag = F::from_bool(local_opcode == BaseAluOpcode::SUB); + core_row.opcode_add_flag = F::from_bool(local_opcode == BaseAluOpcode::ADD); if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB { for a_val in a { - self.bitwise_lookup_chip.request_xor(a_val, a_val); + self.bitwise_lookup_chip + .request_xor(a_val as u32, a_val as u32); } } else { - for (b_val, c_val) in b.iter().zip(c.iter()) { - self.bitwise_lookup_chip.request_xor(*b_val, *c_val); + for (b_val, c_val) in zip(record.b, record.c) { + self.bitwise_lookup_chip + .request_xor(b_val as u32, c_val as u32); } } + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = a.map(F::from_canonical_u8); + } +} - let record = Self::Record { - opcode: local_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BaseAluPreCompute { + c: u32, + a: u8, + b: u8, +} + +impl StepExecutorE1 + for BaseAluStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut BaseAluPreCompute = data.borrow_mut(); + let is_imm = self.pre_compute_impl(pc, inst, data)?; + let opcode = inst.opcode; + + let fn_ptr = match ( + is_imm, + BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)), + ) { + (true, BaseAluOpcode::ADD) => execute_e1_impl::<_, _, true, AddOp>, + (false, BaseAluOpcode::ADD) => execute_e1_impl::<_, _, false, AddOp>, + (true, BaseAluOpcode::SUB) => execute_e1_impl::<_, _, true, SubOp>, + (false, BaseAluOpcode::SUB) => execute_e1_impl::<_, _, false, SubOp>, + (true, BaseAluOpcode::XOR) => execute_e1_impl::<_, _, true, XorOp>, + (false, BaseAluOpcode::XOR) => execute_e1_impl::<_, _, false, XorOp>, + (true, BaseAluOpcode::OR) => execute_e1_impl::<_, _, true, OrOp>, + (false, BaseAluOpcode::OR) => execute_e1_impl::<_, _, false, OrOp>, + (true, BaseAluOpcode::AND) => execute_e1_impl::<_, _, true, AndOp>, + (false, BaseAluOpcode::AND) => execute_e1_impl::<_, _, false, AndOp>, }; + Ok(fn_ptr) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BaseAluPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2 = if IS_IMM { + pre_compute.c.to_le_bytes() + } else { + vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c) + }; + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + let rd = ::compute(rs1, rs2); + let rd = rd.to_le_bytes(); + vm_state.vm_write::(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} - Ok((output, record)) +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &BaseAluPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl StepExecutorE2 + for BaseAluStep +where + F: PrimeField32, +{ + #[inline(always)] + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", BaseAluOpcode::from_usize(opcode - self.air.offset)) + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let is_imm = self.pre_compute_impl(pc, inst, &mut data.data)?; + let opcode = inst.opcode; + + let fn_ptr = match ( + is_imm, + BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)), + ) { + (true, BaseAluOpcode::ADD) => execute_e2_impl::<_, _, true, AddOp>, + (false, BaseAluOpcode::ADD) => execute_e2_impl::<_, _, false, AddOp>, + (true, BaseAluOpcode::SUB) => execute_e2_impl::<_, _, true, SubOp>, + (false, BaseAluOpcode::SUB) => execute_e2_impl::<_, _, false, SubOp>, + (true, BaseAluOpcode::XOR) => execute_e2_impl::<_, _, true, XorOp>, + (false, BaseAluOpcode::XOR) => execute_e2_impl::<_, _, false, XorOp>, + (true, BaseAluOpcode::OR) => execute_e2_impl::<_, _, true, OrOp>, + (false, BaseAluOpcode::OR) => execute_e2_impl::<_, _, false, OrOp>, + (true, BaseAluOpcode::AND) => execute_e2_impl::<_, _, true, AndOp>, + (false, BaseAluOpcode::AND) => execute_e2_impl::<_, _, false, AndOp>, + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.opcode_add_flag = F::from_bool(record.opcode == BaseAluOpcode::ADD); - row_slice.opcode_sub_flag = F::from_bool(record.opcode == BaseAluOpcode::SUB); - row_slice.opcode_xor_flag = F::from_bool(record.opcode == BaseAluOpcode::XOR); - row_slice.opcode_or_flag = F::from_bool(record.opcode == BaseAluOpcode::OR); - row_slice.opcode_and_flag = F::from_bool(record.opcode == BaseAluOpcode::AND); +impl BaseAluStep { + /// Return `is_imm`, true if `e` is RV32_IMM_AS. + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BaseAluPreCompute, + ) -> Result { + let Instruction { a, b, c, d, e, .. } = inst; + let e_u32 = e.as_canonical_u32(); + if (d.as_canonical_u32() != RV32_REGISTER_AS) + || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS) + { + return Err(InvalidInstruction(pc)); + } + let is_imm = e_u32 == RV32_IMM_AS; + let c_u32 = c.as_canonical_u32(); + *data = BaseAluPreCompute { + c: if is_imm { + u32::from_le_bytes(imm_to_bytes(c_u32)) + } else { + c_u32 + }, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + Ok(is_imm) } +} - fn air(&self) -> &Self::Air { - &self.air +trait AluOp { + fn compute(rs1: u32, rs2: u32) -> u32; +} +struct AddOp; +struct SubOp; +struct XorOp; +struct OrOp; +struct AndOp; +impl AluOp for AddOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1.wrapping_add(rs2) + } +} +impl AluOp for SubOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1.wrapping_sub(rs2) + } +} +impl AluOp for XorOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1 ^ rs2 + } +} +impl AluOp for OrOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1 | rs2 + } +} +impl AluOp for AndOp { + #[inline(always)] + fn compute(rs1: u32, rs2: u32) -> u32 { + rs1 & rs2 } } +#[inline(always)] pub(super) fn run_alu( opcode: BaseAluOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + debug_assert!(LIMB_BITS <= 8, "specialize for bytes"); match opcode { BaseAluOpcode::ADD => run_add::(x, y), BaseAluOpcode::SUB => run_subtract::(x, y), - BaseAluOpcode::XOR => run_xor::(x, y), - BaseAluOpcode::OR => run_or::(x, y), - BaseAluOpcode::AND => run_and::(x, y), + BaseAluOpcode::XOR => run_xor::(x, y), + BaseAluOpcode::OR => run_or::(x, y), + BaseAluOpcode::AND => run_and::(x, y), } } +#[inline(always)] fn run_add( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { - let mut z = [0u32; NUM_LIMBS]; - let mut carry = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + let mut z = [0u8; NUM_LIMBS]; + let mut carry = [0u8; NUM_LIMBS]; for i in 0..NUM_LIMBS { - z[i] = x[i] + y[i] + if i > 0 { carry[i - 1] } else { 0 }; - carry[i] = z[i] >> LIMB_BITS; - z[i] &= (1 << LIMB_BITS) - 1; + let mut overflow = + (x[i] as u16) + (y[i] as u16) + if i > 0 { carry[i - 1] as u16 } else { 0 }; + carry[i] = (overflow >> LIMB_BITS) as u8; + overflow &= (1u16 << LIMB_BITS) - 1; + z[i] = overflow as u8; } z } +#[inline(always)] fn run_subtract( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { - let mut z = [0u32; NUM_LIMBS]; - let mut carry = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + let mut z = [0u8; NUM_LIMBS]; + let mut carry = [0u8; NUM_LIMBS]; for i in 0..NUM_LIMBS { - let rhs = y[i] + if i > 0 { carry[i - 1] } else { 0 }; - if x[i] >= rhs { - z[i] = x[i] - rhs; + let rhs = y[i] as u16 + if i > 0 { carry[i - 1] as u16 } else { 0 }; + if x[i] as u16 >= rhs { + z[i] = x[i] - rhs as u8; carry[i] = 0; } else { - z[i] = x[i] + (1 << LIMB_BITS) - rhs; + z[i] = (x[i] as u16 + (1u16 << LIMB_BITS) - rhs) as u8; carry[i] = 1; } } z } -fn run_xor( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_xor(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] ^ y[i]) } -fn run_or( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_or(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] | y[i]) } -fn run_and( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_and(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] & y[i]) } diff --git a/extensions/rv32im/circuit/src/base_alu/mod.rs b/extensions/rv32im/circuit/src/base_alu/mod.rs index cbda8ce555..0ab855a96b 100644 --- a/extensions/rv32im/circuit/src/base_alu/mod.rs +++ b/extensions/rv32im/circuit/src/base_alu/mod.rs @@ -1,7 +1,8 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32BaseAluAdapterChip; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -9,8 +10,9 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32BaseAluChip = VmChipWrapper< - F, - Rv32BaseAluAdapterChip, - BaseAluCoreChip, ->; +pub type Rv32BaseAluAir = + VmAirWrapper>; +pub type Rv32BaseAluStep = + BaseAluStep, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>; +pub type Rv32BaseAluChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/base_alu/tests.rs b/extensions/rv32im/circuit/src/base_alu/tests.rs index 165cd12526..f8884b6648 100644 --- a/extensions/rv32im/circuit/src/base_alu/tests.rs +++ b/extensions/rv32im/circuit/src/base_alu/tests.rs @@ -1,45 +1,110 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, ExecutionState, - MinimalInstruction, Result, VmAdapterChip, VmAdapterInterface, VmChipWrapper, - }, - system::memory::{MemoryController, OfflineMemory}, - utils::generate_long_number, +use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + VmAirWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::BaseAluOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::BaseAluOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_alu, BaseAluCoreChip, Rv32BaseAluChip}; +use super::{core::run_alu, BaseAluCoreAir, Rv32BaseAluChip, Rv32BaseAluStep}; use crate::{ adapters::{ - Rv32BaseAluAdapterAir, Rv32BaseAluAdapterChip, Rv32BaseAluReadRecord, - Rv32BaseAluWriteRecord, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }, base_alu::BaseAluCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, }; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32BaseAluChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Rv32BaseAluChip::new( + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + BaseAluCoreAir::new(bitwise_bus, BaseAluOpcode::CLASS_OFFSET), + ), + Rv32BaseAluStep::new( + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), + bitwise_chip.clone(), + BaseAluOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + +fn set_and_execute( + tester: &mut VmChipTestBuilder, + chip: &mut Rv32BaseAluChip, + rng: &mut StdRng, + opcode: BaseAluOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) + } else { + generate_rv32_is_type_immediate(rng) + }; + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), + ) + }; + + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(chip, &instruction); + + let a = run_alu::(opcode, &b, &c) + .map(F::from_canonical_u8); + assert_eq!(a, tester.read::(1, rd)) +} + ////////////////////////////////////////////////////////////////////////////////////// // POSITIVE TESTS // @@ -47,135 +112,105 @@ type F = BabyBear; // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// -fn run_rv32_alu_rand_test(opcode: BaseAluOpcode, num_ops: usize) { +#[test_case(ADD, 100)] +#[test_case(SUB, 100)] +#[test_case(XOR, 100)] +#[test_case(OR, 100)] +#[test_case(AND, 100)] +fn rand_rv32_alu_test(opcode: BaseAluOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) - } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) - }; + // TODO(AG): make a more meaningful test for memory accesses + tester.write(2, 1024, [F::ONE; 4]); + tester.write(2, 1028, [F::ONE; 4]); + let sm = tester.read(2, 1024); + assert_eq!(sm, [F::ONE; 8]); - let (instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let a = run_alu::(opcode, &b, &c) - .map(F::from_canonical_u32); - assert_eq!(a, tester.read::(1, rd)) + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_alu_add_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::ADD, 100); -} +#[test_case(ADD, 100)] +#[test_case(SUB, 100)] +#[test_case(XOR, 100)] +#[test_case(OR, 100)] +#[test_case(AND, 100)] +fn rand_rv32_alu_test_persistent(opcode: BaseAluOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); -#[test] -fn rv32_alu_sub_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::SUB, 100); -} + let mut tester = VmChipTestBuilder::default_persistent(); + let (mut chip, bitwise_chip) = create_test_chip(&tester); -#[test] -fn rv32_alu_xor_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::XOR, 100); -} + // TODO(AG): make a more meaningful test for memory accesses + tester.write(2, 1024, [F::ONE; 4]); + tester.write(2, 1028, [F::ONE; 4]); + let sm = tester.read(2, 1024); + assert_eq!(sm, [F::ONE; 8]); -#[test] -fn rv32_alu_or_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::OR, 100); -} + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); + } -#[test] -fn rv32_alu_and_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::AND, 100); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BaseAluTestChip = - VmChipWrapper, BaseAluCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_alu_negative_test( +fn run_negative_alu_test( opcode: BaseAluOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + prank_opcode_flags: Option<[bool; 5]>, + is_imm: Option, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + &mut rng, + opcode, + Some(b), + is_imm, + Some(c), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - - if (opcode == BaseAluOpcode::ADD || opcode == BaseAluOpcode::SUB) - && a.iter().all(|&a_val| a_val < (1 << RV32_CELL_BITS)) - { - bitwise_chip.clear(); - for a_val in a { - bitwise_chip.request_xor(a_val, a_val); - } - } - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut BaseAluCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + if let Some(prank_c) = prank_c { + cols.c = prank_c.map(F::from_canonical_u32); + } + if let Some(prank_opcode_flags) = prank_opcode_flags { + cols.opcode_add_flag = F::from_bool(prank_opcode_flags[0]); + cols.opcode_and_flag = F::from_bool(prank_opcode_flags[1]); + cols.opcode_or_flag = F::from_bool(prank_opcode_flags[2]); + cols.opcode_sub_flag = F::from_bool(prank_opcode_flags[3]); + cols.opcode_xor_flag = F::from_bool(prank_opcode_flags[4]); + } + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -184,90 +219,135 @@ fn run_rv32_alu_negative_test( .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_alu_add_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::ADD, + run_negative_alu_test( + ADD, [246, 0, 0, 0], [250, 0, 0, 0], [250, 0, 0, 0], + None, + None, + None, false, ); } #[test] fn rv32_alu_add_out_of_range_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::ADD, + run_negative_alu_test( + ADD, [500, 0, 0, 0], [250, 0, 0, 0], [250, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_sub_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::SUB, + run_negative_alu_test( + SUB, [255, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], + None, + None, + None, false, ); } #[test] fn rv32_alu_sub_out_of_range_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::SUB, + run_negative_alu_test( + SUB, [F::NEG_ONE.as_canonical_u32(), 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_xor_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::XOR, + run_negative_alu_test( + XOR, [255, 255, 255, 255], [0, 0, 1, 0], [255, 255, 255, 255], + None, + None, + None, true, ); } #[test] fn rv32_alu_or_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::OR, + run_negative_alu_test( + OR, [255, 255, 255, 255], [255, 255, 255, 254], [0, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_and_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::AND, + run_negative_alu_test( + AND, [255, 255, 255, 255], [0, 0, 1, 0], [0, 0, 0, 0], + None, + None, + None, true, ); } +#[test] +fn rv32_alu_adapter_unconstrained_imm_limb_test() { + run_negative_alu_test( + ADD, + [255, 7, 0, 0], + [0, 0, 0, 0], + [255, 7, 0, 0], + Some([511, 6, 0, 0]), + None, + Some(true), + true, + ); +} + +#[test] +fn rv32_alu_adapter_unconstrained_rs2_read_test() { + run_negative_alu_test( + ADD, + [2, 2, 2, 2], + [1, 1, 1, 1], + [1, 1, 1, 1], + None, + Some([false, false, false, false, false]), + Some(false), + false, + ); +} + /////////////////////////////////////////////////////////////////////////////////////// /// SANITY TESTS /// @@ -276,10 +356,10 @@ fn rv32_alu_and_wrong_negative_test() { #[test] fn run_add_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; - let result = run_alu::(BaseAluOpcode::ADD, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; + let result = run_alu::(ADD, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -287,10 +367,10 @@ fn run_add_sanity_test() { #[test] fn run_sub_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; - let result = run_alu::(BaseAluOpcode::SUB, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; + let result = run_alu::(SUB, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -298,10 +378,10 @@ fn run_sub_sanity_test() { #[test] fn run_xor_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; - let result = run_alu::(BaseAluOpcode::XOR, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; + let result = run_alu::(XOR, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -309,10 +389,10 @@ fn run_xor_sanity_test() { #[test] fn run_or_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; - let result = run_alu::(BaseAluOpcode::OR, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; + let result = run_alu::(OR, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -320,195 +400,11 @@ fn run_or_sanity_test() { #[test] fn run_and_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; - let result = run_alu::(BaseAluOpcode::AND, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; + let result = run_alu::(AND, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } } - -////////////////////////////////////////////////////////////////////////////////////// -// ADAPTER TESTS -// -// Ensure that the adapter is correct. -////////////////////////////////////////////////////////////////////////////////////// - -// A pranking chip where `preprocess` can have `rs2` limbs that overflow. -struct Rv32BaseAluAdapterTestChip(Rv32BaseAluAdapterChip); - -impl VmAdapterChip for Rv32BaseAluAdapterTestChip { - type ReadRecord = Rv32BaseAluReadRecord; - type WriteRecord = Rv32BaseAluWriteRecord; - type Air = Rv32BaseAluAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; - - let rs1 = memory.read::(d, b); - let (rs2, rs2_data, rs2_imm) = if e.is_zero() { - let c_u32 = c.as_canonical_u32(); - memory.increment_timestamp(); - let mask1 = (1 << 9) - 1; - let mask2 = (1 << 3) - 2; - ( - None, - [ - (c_u32 & mask1) as u16, - ((c_u32 >> 8) & mask2) as u16, - (c_u32 >> 16) as u16, - (c_u32 >> 16) as u16, - ] - .map(F::from_canonical_u16), - c, - ) - } else { - let rs2_read = memory.read::(e, c); - (Some(rs2_read.0), rs2_read.1, F::ZERO) - }; - - Ok(( - [rs1.1, rs2_data], - Self::ReadRecord { - rs1: rs1.0, - rs2, - rs2_imm, - }, - )) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - self.0 - .postprocess(memory, instruction, from_state, output, _read_record) - } - - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - self.0 - .generate_trace_row(row_slice, read_record, write_record, memory) - } - - fn air(&self) -> &Self::Air { - self.0.air() - } -} - -#[test] -fn rv32_alu_adapter_unconstrained_imm_limb_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); - let mut chip = VmChipWrapper::new( - Rv32BaseAluAdapterTestChip(Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - )), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - let b = [0, 0, 0, 0]; - let (c_imm, c) = { - let imm = (1 << 11) - 1; - let fake_c = [(1 << 9) - 1, (1 << 3) - 2, 0, 0]; - (Some(imm), fake_c) - }; - - let (instruction, _rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - BaseAluOpcode::ADD.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - disable_debug_builder(); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test_with_expected_error(VerificationError::ChallengePhaseError); -} - -#[test] -fn rv32_alu_adapter_unconstrained_rs2_read_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - let b = [1, 1, 1, 1]; - let c = [1, 1, 1, 1]; - let (instruction, _rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - None, - BaseAluOpcode::ADD.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - - let modify_trace = |trace: &mut DenseMatrix| { - let mut values = trace.row_slice(0).to_vec(); - let mut dummy_values = values.clone(); - let cols: &mut BaseAluCoreCols = - dummy_values.split_at_mut(adapter_width).1.borrow_mut(); - cols.opcode_add_flag = F::ZERO; - values.extend(dummy_values); - *trace = RowMajorMatrix::new(values, trace_width); - }; - - disable_debug_builder(); - let tester = tester - .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) - .finalize(); - tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); -} diff --git a/extensions/rv32im/circuit/src/branch_eq/core.rs b/extensions/rv32im/circuit/src/branch_eq/core.rs index bb04d86ee5..18336fa5ed 100644 --- a/extensions/rv32im/circuit/src/branch_eq/core.rs +++ b/extensions/rv32im/circuit/src/branch_eq/core.rs @@ -1,15 +1,21 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, -}; +use std::borrow::{Borrow, BorrowMut}; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + ImmInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::utils::not; -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode, +}; use openvm_rv32im_transpiler::BranchEqualOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -17,8 +23,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -37,7 +41,7 @@ pub struct BranchEqualCoreCols { pub diff_inv_marker: [T; NUM_LIMBS], } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct BranchEqualCoreAir { offset: usize, pc_step: u32, @@ -135,117 +139,265 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BranchEqualCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - pub cmp_result: T, - pub imm: T, - pub diff_inv_val: T, - pub diff_idx: usize, - pub opcode: BranchEqualOpcode, +#[derive(AlignedBytesBorrow, Debug)] +pub struct BranchEqualCoreRecord { + pub a: [u8; NUM_LIMBS], + pub b: [u8; NUM_LIMBS], + pub imm: u32, + pub local_opcode: u8, } -#[derive(Debug)] -pub struct BranchEqualCoreChip { - pub air: BranchEqualCoreAir, +#[derive(derive_new::new)] +pub struct BranchEqualStep { + adapter: A, + pub offset: usize, + pub pc_step: u32, } -impl BranchEqualCoreChip { - pub fn new(offset: usize, pc_step: u32) -> Self { - Self { - air: BranchEqualCoreAir { offset, pc_step }, +impl TraceStep for BranchEqualStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep, WriteData = ()>, +{ + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut BranchEqualCoreRecord); + + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", BranchEqualOpcode::from_usize(opcode - self.offset)) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { opcode, c: imm, .. } = instruction; + + let branch_eq_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.a = rs1; + core_record.b = rs2; + core_record.imm = imm.as_canonical_u32(); + core_record.local_opcode = branch_eq_opcode as u8; + + if fast_run_eq(branch_eq_opcode, &rs1, &rs2) { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(self.pc_step); } + + Ok(()) } } -impl, const NUM_LIMBS: usize> VmCoreChip - for BranchEqualCoreChip +impl TraceFiller for BranchEqualStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: Default, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = BranchEqualCoreRecord; - type Air = BranchEqualCoreAir; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &BranchEqualCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut BranchEqualCoreCols = core_row.borrow_mut(); + + let (cmp_result, diff_idx, diff_inv_val) = run_eq::( + record.local_opcode == BranchEqualOpcode::BEQ as u8, + &record.a, + &record.b, + ); + core_row.diff_inv_marker = [F::ZERO; NUM_LIMBS]; + core_row.diff_inv_marker[diff_idx] = diff_inv_val; + + core_row.opcode_bne_flag = + F::from_bool(record.local_opcode == BranchEqualOpcode::BNE as u8); + core_row.opcode_beq_flag = + F::from_bool(record.local_opcode == BranchEqualOpcode::BEQ as u8); + + core_row.imm = F::from_canonical_u32(record.imm); + core_row.cmp_result = F::from_bool(cmp_result); - #[allow(clippy::type_complexity)] - fn execute_instruction( + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = record.a.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BranchEqualPreCompute { + imm: isize, + a: u8, + b: u8, +} + +impl StepExecutorE1 for BranchEqualStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute_e1( &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, c: imm, .. } = *instruction; - let branch_eq_opcode = - BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let x = data[0].map(|x| x.as_canonical_u32()); - let y = data[1].map(|y| y.as_canonical_u32()); - let (cmp_result, diff_idx, diff_inv_val) = run_eq::(branch_eq_opcode, &x, &y); - - let output = AdapterRuntimeContext { - to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()), - writes: Default::default(), - }; - let record = BranchEqualCoreRecord { - opcode: branch_eq_opcode, - a: data[0], - b: data[1], - cmp_result: F::from_bool(cmp_result), - imm, - diff_idx, - diff_inv_val, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let data: &mut BranchEqualPreCompute = data.borrow_mut(); + let is_bne = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = if is_bne { + execute_e1_impl::<_, _, true> + } else { + execute_e1_impl::<_, _, false> }; + Ok(fn_ptr) + } +} - Ok((output, record)) +impl StepExecutorE2 for BranchEqualStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - BranchEqualOpcode::from_usize(opcode - self.air.offset) - ) + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let is_bne = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = if is_bne { + execute_e2_impl::<_, _, true> + } else { + execute_e2_impl::<_, _, false> + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BranchEqualCoreCols<_, NUM_LIMBS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.cmp_result = record.cmp_result; - row_slice.imm = record.imm; - row_slice.opcode_beq_flag = F::from_bool(record.opcode == BranchEqualOpcode::BEQ); - row_slice.opcode_bne_flag = F::from_bool(record.opcode == BranchEqualOpcode::BNE); - row_slice.diff_inv_marker = array::from_fn(|i| { - if i == record.diff_idx { - record.diff_inv_val - } else { - F::ZERO - } - }); +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BranchEqualPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs2 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + if (rs1 == rs2) ^ IS_NE { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); } + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &BranchEqualPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} - fn air(&self) -> &Self::Air { - &self.air +impl BranchEqualStep { + /// Return `is_bne`, true if the local opcode is BNE. + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BranchEqualPreCompute, + ) -> Result { + let data: &mut BranchEqualPreCompute = data.borrow_mut(); + let &Instruction { + opcode, a, b, c, d, .. + } = inst; + let local_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + if d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(InvalidInstruction(pc)); + } + *data = BranchEqualPreCompute { + imm, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + Ok(local_opcode == BranchEqualOpcode::BNE) } } // Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) -pub(super) fn run_eq( +#[inline(always)] +pub(super) fn fast_run_eq( local_opcode: BranchEqualOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> (bool, usize, F) { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> bool { + match local_opcode { + BranchEqualOpcode::BEQ => x == y, + BranchEqualOpcode::BNE => x != y, + } +} + +// Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) +#[inline(always)] +pub(super) fn run_eq( + is_beq: bool, + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> (bool, usize, F) +where + F: PrimeField32, +{ for i in 0..NUM_LIMBS { if x[i] != y[i] { return ( - local_opcode == BranchEqualOpcode::BNE, + !is_beq, i, - (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])).inverse(), + (F::from_canonical_u8(x[i]) - F::from_canonical_u8(y[i])).inverse(), ); } } - (local_opcode == BranchEqualOpcode::BEQ, 0, F::ZERO) + (is_beq, 0, F::ZERO) } diff --git a/extensions/rv32im/circuit/src/branch_eq/mod.rs b/extensions/rv32im/circuit/src/branch_eq/mod.rs index 7d53946a73..35ece9a2c3 100644 --- a/extensions/rv32im/circuit/src/branch_eq/mod.rs +++ b/extensions/rv32im/circuit/src/branch_eq/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; use super::adapters::RV32_REGISTER_NUM_LIMBS; -use crate::adapters::Rv32BranchAdapterChip; +use crate::adapters::{Rv32BranchAdapterAir, Rv32BranchAdapterStep}; mod core; pub use core::*; @@ -9,5 +9,8 @@ pub use core::*; #[cfg(test)] mod tests; +pub type Rv32BranchEqualAir = + VmAirWrapper>; +pub type Rv32BranchEqualStep = BranchEqualStep; pub type Rv32BranchEqualChip = - VmChipWrapper, BranchEqualCoreChip>; + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/branch_eq/tests.rs b/extensions/rv32im/circuit/src/branch_eq/tests.rs index c16858b071..0087e29fa7 100644 --- a/extensions/rv32im/circuit/src/branch_eq/tests.rs +++ b/extensions/rv32im/circuit/src/branch_eq/tests.rs @@ -1,11 +1,14 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::arch::{ - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder}, - BasicAdapterInterface, ExecutionBridge, ImmInstruction, InstructionExecutor, VmAdapterChip, - VmChipWrapper, VmCoreChip, + testing::{memory::gen_pointer, VmChipTestBuilder}, + InstructionExecutor, VmAirWrapper, +}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + LocalOpcode, }; -use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::BranchEqualOpcode; use openvm_stark_backend::{ p3_air::BaseAir, @@ -15,42 +18,70 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::{ - core::{run_eq, BranchEqualCoreChip}, + core::{run_eq, BranchEqualStep}, BranchEqualCoreCols, Rv32BranchEqualChip, }; -use crate::adapters::{Rv32BranchAdapterChip, RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS}; +use crate::{ + adapters::{ + Rv32BranchAdapterAir, Rv32BranchAdapterStep, RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS, + }, + branch_eq::fast_run_eq, + test_utils::get_verification_error, + BranchEqualCoreAir, +}; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_IMM: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Rv32BranchEqualChip { + let mut chip = Rv32BranchEqualChip::::new( + VmAirWrapper::new( + Rv32BranchAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + BranchEqualCoreAir::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ), + BranchEqualStep::new( + Rv32BranchAdapterStep::new(), + BranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + chip +} #[allow(clippy::too_many_arguments)] -fn run_rv32_branch_eq_rand_execute>( +fn set_and_execute>( tester: &mut VmChipTestBuilder, chip: &mut E, - opcode: BranchEqualOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - imm: i32, rng: &mut StdRng, + opcode: BranchEqualOpcode, + a: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + imm: Option, ) { + let a = a.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let b = b.unwrap_or(if rng.gen_bool(0.5) { + a + } else { + array::from_fn(|_| rng.gen_range(0..=u8::MAX)) + }); + + let imm = imm.unwrap_or(rng.gen_range((-ABS_MAX_IMM)..ABS_MAX_IMM)); let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); - tester.write::(1, rs1, a.map(F::from_canonical_u32)); - tester.write::(1, rs2, b.map(F::from_canonical_u32)); + tester.write::(1, rs1, a.map(F::from_canonical_u8)); + tester.write::(1, rs2, b.map(F::from_canonical_u8)); + let initial_pc = rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1))); tester.execute_with_pc( chip, &Instruction::from_isize( @@ -61,10 +92,10 @@ fn run_rv32_branch_eq_rand_execute>( 1, 1, ), - rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1))), + initial_pc, ); - let (cmp_result, _, _) = run_eq::(opcode, &a, &b); + let cmp_result = fast_run_eq(opcode, &a, &b); let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; let pc_inc = if cmp_result { imm } else { 4 }; @@ -72,94 +103,71 @@ fn run_rv32_branch_eq_rand_execute>( assert_eq!(to_pc, from_pc + pc_inc); } -fn run_rv32_branch_eq_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(BranchEqualOpcode::BEQ, 100)] +#[test_case(BranchEqualOpcode::BNE, 100)] +fn rand_rv32_branch_eq_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchEqualChip::::new( - Rv32BranchAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&mut tester); for _ in 0..num_ops { - let a = array::from_fn(|_| rng.gen_range(0..F::ORDER_U32)); - let b = if rng.gen_bool(0.5) { - a - } else { - array::from_fn(|_| rng.gen_range(0..F::ORDER_U32)) - }; - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - run_rv32_branch_eq_rand_execute(&mut tester, &mut chip, opcode, a, b, imm, &mut rng); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); } let tester = tester.build().load(chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_beq_rand_test() { - run_rv32_branch_eq_rand_test(BranchEqualOpcode::BEQ, 100); -} - -#[test] -fn rv32_bne_rand_test() { - run_rv32_branch_eq_rand_test(BranchEqualOpcode::BNE, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BranchEqualTestChip = - VmChipWrapper, BranchEqualCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_beq_negative_test( +fn run_negative_branch_eq_test( opcode: BranchEqualOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, - diff_inv_marker: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + a: [u8; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: Option, + prank_diff_inv_marker: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + interaction_error: bool, ) { - let imm = 16u32; + let imm = 16i32; + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchEqualTestChip::::new( - TestAdapterChip::new( - vec![[a.map(F::from_canonical_u32), b.map(F::from_canonical_u32)].concat()], - vec![if cmp_result { Some(imm) } else { None }], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&mut tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, imm as usize, 1, 1]), + &mut rng, + opcode, + Some(a), + Some(b), + Some(imm), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut BranchEqualCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.cmp_result = F::from_bool(cmp_result); - if let Some(diff_inv_marker) = diff_inv_marker { + if let Some(cmp_result) = prank_cmp_result { + cols.cmp_result = F::from_bool(cmp_result); + } + if let Some(diff_inv_marker) = prank_diff_inv_marker { cols.diff_inv_marker = diff_inv_marker.map(F::from_canonical_u32); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -167,88 +175,96 @@ fn run_rv32_beq_negative_test( .build() .load_and_prank_trace(chip, modify_trace) .finalize(); - tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_beq_wrong_cmp_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 0, 7], - true, + Some(true), None, + false, ); - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 7, 0], - false, + Some(false), None, + false, ); } #[test] fn rv32_beq_zero_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 0, 7], - true, + Some(true), Some([0, 0, 0, 0]), + false, ); } #[test] fn rv32_beq_invalid_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 7, 0], - false, + Some(false), Some([0, 0, 1, 0]), + false, ); } #[test] fn rv32_bne_wrong_cmp_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 0, 7], - false, + Some(false), None, + false, ); - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 7, 0], - true, + Some(true), None, + false, ); } #[test] fn rv32_bne_zero_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 0, 7], - false, + Some(false), Some([0, 0, 0, 0]), + false, ); } #[test] fn rv32_bne_invalid_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 7, 0], - true, + Some(true), Some([0, 0, 1, 0]), + false, ); } @@ -259,66 +275,61 @@ fn rv32_bne_invalid_inv_marker_negative_test() { /////////////////////////////////////////////////////////////////////////////////////// #[test] -fn execute_pc_increment_sanity_test() { - let core = - BranchEqualCoreChip::::new(BranchEqualOpcode::CLASS_OFFSET, 4); - - let mut instruction = Instruction:: { - opcode: BranchEqualOpcode::BEQ.global_opcode(), - c: F::from_canonical_u8(8), - ..Default::default() - }; - let x: [F; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60].map(F::from_canonical_u32); - let y: [F; RV32_REGISTER_NUM_LIMBS] = [19, 32, 1804, 60].map(F::from_canonical_u32); - - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, y]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_none()); - - instruction.opcode = BranchEqualOpcode::BNE.global_opcode(); - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, y]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_some()); - assert_eq!(output.to_pc.unwrap(), 8); +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut chip = create_test_chip(&mut tester); + + let x = [19, 4, 179, 60]; + let y = [19, 32, 180, 60]; + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchEqualOpcode::BEQ, + Some(x), + Some(y), + Some(8), + ); + + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchEqualOpcode::BNE, + Some(x), + Some(y), + Some(8), + ); } #[test] fn run_eq_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60]; - let (cmp_result, _, diff_val) = - run_eq::(BranchEqualOpcode::BEQ, &x, &x); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 4, 17, 60]; + let (cmp_result, _, diff_val) = run_eq::(true, &x, &x); assert!(cmp_result); assert_eq!(diff_val, F::ZERO); - let (cmp_result, _, diff_val) = - run_eq::(BranchEqualOpcode::BNE, &x, &x); + let (cmp_result, _, diff_val) = run_eq::(false, &x, &x); assert!(!cmp_result); assert_eq!(diff_val, F::ZERO); } #[test] fn run_ne_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 32, 1804, 60]; - let (cmp_result, diff_idx, diff_val) = - run_eq::(BranchEqualOpcode::BEQ, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 4, 17, 60]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 32, 18, 60]; + let (cmp_result, diff_idx, diff_val) = run_eq::(true, &x, &y); assert!(!cmp_result); assert_eq!( - diff_val * (F::from_canonical_u32(x[diff_idx]) - F::from_canonical_u32(y[diff_idx])), + diff_val * (F::from_canonical_u8(x[diff_idx]) - F::from_canonical_u8(y[diff_idx])), F::ONE ); - let (cmp_result, diff_idx, diff_val) = - run_eq::(BranchEqualOpcode::BNE, &x, &y); + let (cmp_result, diff_idx, diff_val) = run_eq::(false, &x, &y); assert!(cmp_result); assert_eq!( - diff_val * (F::from_canonical_u32(x[diff_idx]) - F::from_canonical_u32(y[diff_idx])), + diff_val * (F::from_canonical_u8(x[diff_idx]) - F::from_canonical_u8(y[diff_idx])), F::ONE ); } diff --git a/extensions/rv32im/circuit/src/branch_lt/core.rs b/extensions/rv32im/circuit/src/branch_lt/core.rs index 3eebb02146..9ce03a932f 100644 --- a/extensions/rv32im/circuit/src/branch_lt/core.rs +++ b/extensions/rv32im/circuit/src/branch_lt/core.rs @@ -1,18 +1,25 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, -}; - -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + ImmInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS, LocalOpcode, +}; use openvm_rv32im_transpiler::BranchLessThanOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,8 +27,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -53,7 +58,7 @@ pub struct BranchLessThanCoreCols { pub bus: BitwiseOperationLookupBus, offset: usize, @@ -188,183 +193,357 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BranchLessThanCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - pub cmp_result: T, - pub cmp_lt: T, - pub imm: T, - pub a_msb_f: T, - pub b_msb_f: T, - pub diff_val: T, - pub diff_idx: usize, - pub opcode: BranchLessThanOpcode, +#[derive(AlignedBytesBorrow, Debug)] +pub struct BranchLessThanCoreRecord { + pub a: [u8; NUM_LIMBS], + pub b: [u8; NUM_LIMBS], + pub imm: u32, + pub local_opcode: u8, } -pub struct BranchLessThanCoreChip { - pub air: BranchLessThanCoreAir, +#[derive(derive_new::new)] +pub struct BranchLessThanStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, } -impl BranchLessThanCoreChip { - pub fn new( - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: BranchLessThanCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, - bitwise_lookup_chip, +impl TraceStep + for BranchLessThanStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep, WriteData = ()>, +{ + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = ( + A::RecordMut<'a>, + &'a mut BranchLessThanCoreRecord, + ); + + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + BranchLessThanOpcode::from_usize(opcode - self.offset) + ) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { opcode, c: imm, .. } = instruction; + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.a = rs1; + core_record.b = rs2; + core_record.imm = imm.as_canonical_u32(); + core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8; + + if run_cmp::(core_record.local_opcode, &rs1, &rs2).0 { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); } + + Ok(()) } } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for BranchLessThanCoreChip +impl TraceFiller + for BranchLessThanStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: Default, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = BranchLessThanCoreRecord; - type Air = BranchLessThanCoreAir; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + let record: &BranchLessThanCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + + self.adapter.fill_trace_row(mem_helper, adapter_row); + let core_row: &mut BranchLessThanCoreCols = core_row.borrow_mut(); + + let signed = record.local_opcode == BranchLessThanOpcode::BLT as u8 + || record.local_opcode == BranchLessThanOpcode::BGE as u8; + let ge_op = record.local_opcode == BranchLessThanOpcode::BGE as u8 + || record.local_opcode == BranchLessThanOpcode::BGEU as u8; - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, c: imm, .. } = *instruction; - let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let a = data[0].map(|x| x.as_canonical_u32()); - let b = data[1].map(|y| y.as_canonical_u32()); let (cmp_result, diff_idx, a_sign, b_sign) = - run_cmp::(blt_opcode, &a, &b); + run_cmp::(record.local_opcode, &record.a, &record.b); - let signed = matches!( - blt_opcode, - BranchLessThanOpcode::BLT | BranchLessThanOpcode::BGE - ); - let ge_opcode = matches!( - blt_opcode, - BranchLessThanOpcode::BGE | BranchLessThanOpcode::BGEU - ); - let cmp_lt = cmp_result ^ ge_opcode; + let cmp_lt = cmp_result ^ ge_op; // We range check (a_msb_f + 128) and (b_msb_f + 128) if signed, // a_msb_f and b_msb_f if not let (a_msb_f, a_msb_range) = if a_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - a[NUM_LIMBS - 1]), - a[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u32((1 << LIMB_BITS) - record.a[NUM_LIMBS - 1] as u32), + record.a[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(a[NUM_LIMBS - 1]), - a[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)), + F::from_canonical_u32(record.a[NUM_LIMBS - 1] as u32), + record.a[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)), ) }; let (b_msb_f, b_msb_range) = if b_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u32((1 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u32), + record.b[NUM_LIMBS - 1] as u32 - (1 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)), + F::from_canonical_u32(record.b[NUM_LIMBS - 1] as u32), + record.b[NUM_LIMBS - 1] as u32 + ((signed as u32) << (LIMB_BITS - 1)), ) }; - self.bitwise_lookup_chip - .request_range(a_msb_range, b_msb_range); - let diff_val = if diff_idx == NUM_LIMBS { - 0 + core_row.diff_val = if diff_idx == NUM_LIMBS { + F::ZERO } else if diff_idx == (NUM_LIMBS - 1) { if cmp_lt { b_msb_f - a_msb_f } else { a_msb_f - b_msb_f } - .as_canonical_u32() } else if cmp_lt { - b[diff_idx] - a[diff_idx] + F::from_canonical_u8(record.b[diff_idx] - record.a[diff_idx]) } else { - a[diff_idx] - b[diff_idx] + F::from_canonical_u8(record.a[diff_idx] - record.b[diff_idx]) }; + self.bitwise_lookup_chip + .request_range(a_msb_range, b_msb_range); + + core_row.diff_marker = [F::ZERO; NUM_LIMBS]; + if diff_idx != NUM_LIMBS { - self.bitwise_lookup_chip.request_range(diff_val - 1, 0); + self.bitwise_lookup_chip + .request_range(core_row.diff_val.as_canonical_u32() - 1, 0); + core_row.diff_marker[diff_idx] = F::ONE; } - let output = AdapterRuntimeContext { - to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()), - writes: Default::default(), - }; - let record = BranchLessThanCoreRecord { - opcode: blt_opcode, - a: data[0], - b: data[1], - cmp_result: F::from_bool(cmp_result), - cmp_lt: F::from_bool(cmp_lt), - imm, - a_msb_f, - b_msb_f, - diff_val: F::from_canonical_u32(diff_val), - diff_idx, - }; + core_row.cmp_lt = F::from_bool(cmp_lt); + core_row.b_msb_f = b_msb_f; + core_row.a_msb_f = a_msb_f; + core_row.opcode_bgeu_flag = + F::from_bool(record.local_opcode == BranchLessThanOpcode::BGEU as u8); + core_row.opcode_bge_flag = + F::from_bool(record.local_opcode == BranchLessThanOpcode::BGE as u8); + core_row.opcode_bltu_flag = + F::from_bool(record.local_opcode == BranchLessThanOpcode::BLTU as u8); + core_row.opcode_blt_flag = + F::from_bool(record.local_opcode == BranchLessThanOpcode::BLT as u8); + + core_row.imm = F::from_canonical_u32(record.imm); + core_row.cmp_result = F::from_bool(cmp_result); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = record.a.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct BranchLePreCompute { + imm: isize, + a: u8, + b: u8, +} - Ok((output, record)) +impl StepExecutorE1 + for BranchLessThanStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - BranchLessThanOpcode::from_usize(opcode - self.air.offset) - ) + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let data: &mut BranchLePreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + BranchLessThanOpcode::BLT => execute_e1_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => execute_e1_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => execute_e1_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => execute_e1_impl::<_, _, BgeuOp>, + }; + Ok(fn_ptr) + } +} +impl StepExecutorE2 + for BranchLessThanStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + BranchLessThanOpcode::BLT => execute_e2_impl::<_, _, BltOp>, + BranchLessThanOpcode::BLTU => execute_e2_impl::<_, _, BltuOp>, + BranchLessThanOpcode::BGE => execute_e2_impl::<_, _, BgeOp>, + BranchLessThanOpcode::BGEU => execute_e2_impl::<_, _, BgeuOp>, + }; + Ok(fn_ptr) } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &BranchLePreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + let rs2 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let jmp = ::compute(rs1, rs2); + if jmp { + vm_state.pc = (vm_state.pc as isize + pre_compute.imm) as u32; + } else { + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + }; + vm_state.instret += 1; +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = - row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.cmp_result = record.cmp_result; - row_slice.cmp_lt = record.cmp_lt; - row_slice.imm = record.imm; - row_slice.a_msb_f = record.a_msb_f; - row_slice.b_msb_f = record.b_msb_f; - row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx)); - row_slice.diff_val = record.diff_val; - row_slice.opcode_blt_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLT); - row_slice.opcode_bltu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLTU); - row_slice.opcode_bge_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGE); - row_slice.opcode_bgeu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGEU); +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &BranchLePreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl + BranchLessThanStep +{ + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut BranchLePreCompute, + ) -> Result { + let &Instruction { + opcode, a, b, c, d, .. + } = inst; + let local_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let c = c.as_canonical_u32(); + let imm = if F::ORDER_U32 - c < c { + -((F::ORDER_U32 - c) as isize) + } else { + c as isize + }; + if d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(InvalidInstruction(pc)); + } + *data = BranchLePreCompute { + imm, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + Ok(local_opcode) } +} - fn air(&self) -> &Self::Air { - &self.air +trait BranchLessThanOp { + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool; +} +struct BltOp; +struct BltuOp; +struct BgeOp; +struct BgeuOp; + +impl BranchLessThanOp for BltOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool { + let rs1 = i32::from_le_bytes(rs1); + let rs2 = i32::from_le_bytes(rs2); + rs1 < rs2 + } +} +impl BranchLessThanOp for BltuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool { + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + rs1 < rs2 + } +} +impl BranchLessThanOp for BgeOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool { + let rs1 = i32::from_le_bytes(rs1); + let rs2 = i32::from_le_bytes(rs2); + rs1 >= rs2 + } +} +impl BranchLessThanOp for BgeuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> bool { + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + rs1 >= rs2 } } // Returns (cmp_result, diff_idx, x_sign, y_sign) +#[inline(always)] pub(super) fn run_cmp( - local_opcode: BranchLessThanOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + local_opcode: u8, + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize, bool, bool) { - let signed = - local_opcode == BranchLessThanOpcode::BLT || local_opcode == BranchLessThanOpcode::BGE; - let ge_op = - local_opcode == BranchLessThanOpcode::BGE || local_opcode == BranchLessThanOpcode::BGEU; + let signed = local_opcode == BranchLessThanOpcode::BLT as u8 + || local_opcode == BranchLessThanOpcode::BGE as u8; + let ge_op = local_opcode == BranchLessThanOpcode::BGE as u8 + || local_opcode == BranchLessThanOpcode::BGEU as u8; let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && signed; for i in (0..NUM_LIMBS).rev() { diff --git a/extensions/rv32im/circuit/src/branch_lt/mod.rs b/extensions/rv32im/circuit/src/branch_lt/mod.rs index b0bf8fc417..1c68300fb5 100644 --- a/extensions/rv32im/circuit/src/branch_lt/mod.rs +++ b/extensions/rv32im/circuit/src/branch_lt/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32BranchAdapterChip; +use crate::adapters::{Rv32BranchAdapterAir, Rv32BranchAdapterStep}; mod core; pub use core::*; @@ -9,8 +9,11 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32BranchLessThanChip = VmChipWrapper< - F, - Rv32BranchAdapterChip, - BranchLessThanCoreChip, +pub type Rv32BranchLessThanAir = VmAirWrapper< + Rv32BranchAdapterAir, + BranchLessThanCoreAir, >; +pub type Rv32BranchLessThanStep = + BranchLessThanStep; +pub type Rv32BranchLessThanChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/branch_lt/tests.rs b/extensions/rv32im/circuit/src/branch_lt/tests.rs index 8c1d7f697a..016a829fe3 100644 --- a/extensions/rv32im/circuit/src/branch_lt/tests.rs +++ b/extensions/rv32im/circuit/src/branch_lt/tests.rs @@ -1,12 +1,11 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - BasicAdapterInterface, ExecutionBridge, ImmInstruction, InstructionExecutor, VmAdapterChip, - VmChipWrapper, VmCoreChip, + testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + InstructionExecutor, VmAirWrapper, }, - utils::{generate_long_number, i32_to_f}, + utils::i32_to_f, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -21,46 +20,76 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::{ - core::{run_cmp, BranchLessThanCoreChip}, + core::{run_cmp, BranchLessThanStep}, Rv32BranchLessThanChip, }; use crate::{ adapters::{ - Rv32BranchAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS, + Rv32BranchAdapterAir, Rv32BranchAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + RV_B_TYPE_IMM_BITS, }, branch_lt::BranchLessThanCoreCols, + test_utils::get_verification_error, + BranchLessThanCoreAir, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_IMM: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32BranchLessThanChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let mut chip = Rv32BranchLessThanChip::::new( + VmAirWrapper::new( + Rv32BranchAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + BranchLessThanCoreAir::new(bitwise_bus, BranchLessThanOpcode::CLASS_OFFSET), + ), + BranchLessThanStep::new( + Rv32BranchAdapterStep::new(), + bitwise_chip.clone(), + BranchLessThanOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} #[allow(clippy::too_many_arguments)] -fn run_rv32_branch_lt_rand_execute>( +fn set_and_execute>( tester: &mut VmChipTestBuilder, chip: &mut E, - opcode: BranchLessThanOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - imm: i32, rng: &mut StdRng, + opcode: BranchLessThanOpcode, + a: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + imm: Option, ) { + let a = a.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let b = b.unwrap_or(if rng.gen_bool(0.5) { + a + } else { + array::from_fn(|_| rng.gen_range(0..=u8::MAX)) + }); + + let imm = imm.unwrap_or(rng.gen_range((-ABS_MAX_IMM)..ABS_MAX_IMM)); let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); - tester.write::(1, rs1, a.map(F::from_canonical_u32)); - tester.write::(1, rs2, b.map(F::from_canonical_u32)); + tester.write::(1, rs1, a.map(F::from_canonical_u8)); + tester.write::(1, rs2, b.map(F::from_canonical_u8)); tester.execute_with_pc( chip, @@ -75,7 +104,8 @@ fn run_rv32_branch_lt_rand_execute>( rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1))), ); - let (cmp_result, _, _, _) = run_cmp::(opcode, &a, &b); + let (cmp_result, _, _, _) = + run_cmp::(opcode.local_usize() as u8, &a, &b); let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; let pc_inc = if cmp_result { imm } else { 4 }; @@ -83,93 +113,57 @@ fn run_rv32_branch_lt_rand_execute>( assert_eq!(to_pc, from_pc + pc_inc); } -fn run_rv32_branch_lt_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(BranchLessThanOpcode::BLT, 100)] +#[test_case(BranchLessThanOpcode::BLTU, 100)] +#[test_case(BranchLessThanOpcode::BGE, 100)] +#[test_case(BranchLessThanOpcode::BGEU, 100)] +fn rand_branch_lt_test(opcode: BranchLessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchLessThanChip::::new( - Rv32BranchAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - BranchLessThanCoreChip::new(bitwise_chip.clone(), BranchLessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); for _ in 0..num_ops { - let a = generate_long_number::(&mut rng); - let b = if rng.gen_bool(0.5) { - a - } else { - generate_long_number::(&mut rng) - }; - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - run_rv32_branch_lt_rand_execute(&mut tester, &mut chip, opcode, a, b, imm, &mut rng); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); } // Test special case where b = c - run_rv32_branch_lt_rand_execute( + set_and_execute( &mut tester, &mut chip, - opcode, - [101, 128, 202, 255], - [101, 128, 202, 255], - 24, &mut rng, + opcode, + Some([101, 128, 202, 255]), + Some([101, 128, 202, 255]), + Some(24), ); - run_rv32_branch_lt_rand_execute( + set_and_execute( &mut tester, &mut chip, - opcode, - [36, 0, 0, 0], - [36, 0, 0, 0], - 24, &mut rng, + opcode, + Some([36, 0, 0, 0]), + Some([36, 0, 0, 0]), + Some(24), ); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_blt_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BLT, 10); -} - -#[test] -fn rv32_bltu_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BLTU, 12); -} - -#[test] -fn rv32_bge_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BGE, 12); -} - -#[test] -fn rv32_bgeu_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BGEU, 12); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BranchLessThanTestChip = VmChipWrapper< - F, - TestAdapterChip, - BranchLessThanCoreChip, ->; - #[derive(Clone, Copy, Default, PartialEq)] struct BranchLessThanPrankValues { pub a_msb: Option, @@ -179,66 +173,31 @@ struct BranchLessThanPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_blt_negative_test( +fn run_negative_branch_lt_test( opcode: BranchLessThanOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, + a: [u8; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: bool, prank_vals: BranchLessThanPrankValues, interaction_error: bool, ) { - let imm = 16u32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32BranchLessThanTestChip::::new( - TestAdapterChip::new( - vec![[a.map(F::from_canonical_u32), b.map(F::from_canonical_u32)].concat()], - vec![if cmp_result { Some(imm) } else { None }], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BranchLessThanCoreChip::new(bitwise_chip.clone(), BranchLessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let imm = 16i32; + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, imm as usize, 1, 1]), + &mut rng, + opcode, + Some(a), + Some(b), + Some(imm), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); + let adapter_width = BaseAir::::width(&chip.air.adapter); let ge_opcode = opcode == BranchLessThanOpcode::BGE || opcode == BranchLessThanOpcode::BGEU; - let (_, _, a_sign, b_sign) = run_cmp::(opcode, &a, &b); - - if prank_vals != BranchLessThanPrankValues::default() { - debug_assert!(prank_vals.diff_val.is_some()); - let a_msb = prank_vals.a_msb.unwrap_or( - a[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if a_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let b_msb = prank_vals.b_msb.unwrap_or( - b[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if b_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let signed_offset = match opcode { - BranchLessThanOpcode::BLT | BranchLessThanOpcode::BGE => 1 << (RV32_CELL_BITS - 1), - _ => 0, - }; - - bitwise_chip.clear(); - bitwise_chip.request_range( - (a_msb + signed_offset) as u8 as u32, - (b_msb + signed_offset) as u8 as u32, - ); - - let diff_val = prank_vals - .diff_val - .unwrap() - .clamp(0, (1 << RV32_CELL_BITS) - 1); - if diff_val > 0 { - bitwise_chip.request_range(diff_val - 1, 0); - } - } let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); @@ -257,10 +216,10 @@ fn run_rv32_blt_negative_test( if let Some(diff_val) = prank_vals.diff_val { cols.diff_val = F::from_canonical_u32(diff_val); } - cols.cmp_result = F::from_bool(cmp_result); - cols.cmp_lt = F::from_bool(ge_opcode ^ cmp_result); + cols.cmp_result = F::from_bool(prank_cmp_result); + cols.cmp_lt = F::from_bool(ge_opcode ^ prank_cmp_result); - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -269,11 +228,7 @@ fn run_rv32_blt_negative_test( .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -281,10 +236,10 @@ fn rv32_blt_wrong_lt_cmp_negative_test() { let a = [145, 34, 25, 205]; let b = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -292,10 +247,10 @@ fn rv32_blt_wrong_ge_cmp_negative_test() { let a = [73, 35, 25, 205]; let b = [145, 34, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -303,10 +258,10 @@ fn rv32_blt_wrong_eq_cmp_negative_test() { let a = [73, 35, 25, 205]; let b = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -317,10 +272,10 @@ fn rv32_blt_fake_diff_val_negative_test() { diff_val: Some(F::NEG_ONE.as_canonical_u32()), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } #[test] @@ -332,10 +287,10 @@ fn rv32_blt_zero_diff_val_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } #[test] @@ -347,10 +302,10 @@ fn rv32_blt_fake_diff_marker_negative_test() { diff_val: Some(72), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -362,10 +317,10 @@ fn rv32_blt_zero_diff_marker_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -378,8 +333,8 @@ fn rv32_blt_signed_wrong_a_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); } #[test] @@ -392,8 +347,8 @@ fn rv32_blt_signed_wrong_a_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); } #[test] @@ -406,8 +361,8 @@ fn rv32_blt_signed_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); } #[test] @@ -420,8 +375,8 @@ fn rv32_blt_signed_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, true); } #[test] @@ -434,8 +389,8 @@ fn rv32_blt_unsigned_wrong_a_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -448,8 +403,8 @@ fn rv32_blt_unsigned_wrong_a_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, true); } #[test] @@ -462,8 +417,8 @@ fn rv32_blt_unsigned_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -476,8 +431,8 @@ fn rv32_blt_unsigned_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -487,51 +442,52 @@ fn rv32_blt_unsigned_wrong_b_msb_sign_negative_test() { /////////////////////////////////////////////////////////////////////////////////////// #[test] -fn execute_pc_increment_sanity_test() { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let core = BranchLessThanCoreChip::::new( - bitwise_chip, - BranchLessThanOpcode::CLASS_OFFSET, +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, _) = create_test_chip(&mut tester); + + let x = [145, 34, 25, 205]; + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchLessThanOpcode::BLT, + Some(x), + Some(x), + Some(8), ); - let mut instruction = Instruction:: { - opcode: BranchLessThanOpcode::BLT.global_opcode(), - c: F::from_canonical_u8(8), - ..Default::default() - }; - let x: [F; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205].map(F::from_canonical_u32); - - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, x]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_none()); - - instruction.opcode = BranchLessThanOpcode::BGE.global_opcode(); - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, x]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_some()); - assert_eq!(output.to_pc.unwrap(), 8); + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchLessThanOpcode::BGE, + Some(x), + Some(x), + Some(8), + ); } #[test] fn run_cmp_unsigned_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; - let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLTU, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::( + BranchLessThanOpcode::BLTU as u8, + &x, + &y, + ); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(!x_sign); // unsigned assert!(!y_sign); // unsigned - let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGEU, &x, &y); + let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::( + BranchLessThanOpcode::BGEU as u8, + &x, + &y, + ); assert!(!cmp_result); assert_eq!(diff_idx, 1); assert!(!x_sign); // unsigned @@ -540,17 +496,17 @@ fn run_cmp_unsigned_sanity_test() { #[test] fn run_cmp_same_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLT, &x, &y); + run_cmp::(BranchLessThanOpcode::BLT as u8, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(x_sign); // negative assert!(y_sign); // negative let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGE, &x, &y); + run_cmp::(BranchLessThanOpcode::BGE as u8, &x, &y); assert!(!cmp_result); assert_eq!(diff_idx, 1); assert!(x_sign); // negative @@ -559,17 +515,17 @@ fn run_cmp_same_sign_sanity_test() { #[test] fn run_cmp_diff_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLT, &x, &y); + run_cmp::(BranchLessThanOpcode::BLT as u8, &x, &y); assert!(!cmp_result); assert_eq!(diff_idx, 3); assert!(!x_sign); // positive assert!(y_sign); // negative let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGE, &x, &y); + run_cmp::(BranchLessThanOpcode::BGE as u8, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 3); assert!(!x_sign); // positive @@ -578,27 +534,33 @@ fn run_cmp_diff_sign_sanity_test() { #[test] fn run_cmp_eq_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLT, &x, &x); + run_cmp::(BranchLessThanOpcode::BLT as u8, &x, &x); assert!(!cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert_eq!(x_sign, y_sign); - let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BLTU, &x, &x); + let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::( + BranchLessThanOpcode::BLTU as u8, + &x, + &x, + ); assert!(!cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert_eq!(x_sign, y_sign); let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGE, &x, &x); + run_cmp::(BranchLessThanOpcode::BGE as u8, &x, &x); assert!(cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert_eq!(x_sign, y_sign); - let (cmp_result, diff_idx, x_sign, y_sign) = - run_cmp::(BranchLessThanOpcode::BGEU, &x, &x); + let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::( + BranchLessThanOpcode::BGEU as u8, + &x, + &x, + ); assert!(cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert_eq!(x_sign, y_sign); diff --git a/extensions/rv32im/circuit/src/divrem/core.rs b/extensions/rv32im/circuit/src/divrem/core.rs index b21c32345e..1f36d5e811 100644 --- a/extensions/rv32im/circuit/src/divrem/core.rs +++ b/extensions/rv32im/circuit/src/divrem/core.rs @@ -5,17 +5,30 @@ use std::{ use num_bigint::BigUint; use num_integer::Integer; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + MinimalInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, utils::{not, select}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::DivRemOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -23,8 +36,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -67,7 +78,7 @@ pub struct DivRemCoreCols { pub opcode_remu_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct DivRemCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_tuple_bus: RangeTupleCheckerBus<2>, @@ -342,14 +353,24 @@ where } } -pub struct DivRemCoreChip { - pub air: DivRemCoreAir, +#[derive(Debug, Eq, PartialEq)] +#[repr(u8)] +pub(super) enum DivRemCoreSpecialCase { + None, + ZeroDivisor, + SignedOverflow, +} + +pub struct DivRemStep { + adapter: A, + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl DivRemCoreChip { +impl DivRemStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize, @@ -369,11 +390,8 @@ impl DivRemCoreChip DivRemCoreChip { - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub q: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub r: [T; NUM_LIMBS], - pub zero_divisor: T, - pub r_zero: T, - pub b_sign: T, - pub c_sign: T, - pub q_sign: T, - pub sign_xor: T, - pub c_sum_inv: T, - pub r_sum_inv: T, - #[serde(with = "BigArray")] - pub r_prime: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub r_inv: [T; NUM_LIMBS], - pub lt_diff_val: T, - pub lt_diff_idx: usize, - pub opcode: DivRemOpcode, -} - -#[derive(Debug, Eq, PartialEq)] -#[repr(u8)] -pub(super) enum DivRemCoreSpecialCase { - None, - ZeroDivisor, - SignedOverflow, +#[derive(AlignedBytesBorrow, Debug)] +pub struct DivRemCoreRecords { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + pub local_opcode: u8, } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for DivRemCoreChip +impl TraceStep + for DivRemStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, { - type Record = DivRemCoreRecord; - type Air = DivRemCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut DivRemCoreRecords); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", DivRemOpcode::from_usize(opcode - self.offset)) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { let Instruction { opcode, .. } = instruction; - let divrem_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let is_div = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::DIVU; - let is_signed = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::REM; + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8; + + let is_signed = core_record.local_opcode == DivRemOpcode::DIV as u8 + || core_record.local_opcode == DivRemOpcode::REM as u8; + let is_div = core_record.local_opcode == DivRemOpcode::DIV as u8 + || core_record.local_opcode == DivRemOpcode::DIVU as u8; + + [core_record.b, core_record.c] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let b = core_record.b.map(u32::from); + let c = core_record.c.map(u32::from); + let (q, r, _, _, _, _) = run_divrem::(is_signed, &b, &c); + + let rd = if is_div { + q.map(|x| x as u8) + } else { + r.map(|x| x as u8) + }; - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (q, r, b_sign, c_sign, q_sign, case) = - run_divrem::(is_signed, &b, &c); + self.adapter + .write(state.memory, instruction, [rd].into(), &mut adapter_record); - let carries = run_mul_carries::(is_signed, &c, &q, &r, q_sign); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for DivRemStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &DivRemCoreRecords = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut DivRemCoreCols = core_row.borrow_mut(); + + let opcode = DivRemOpcode::from_usize(record.local_opcode as usize); + let is_signed = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::REM; + + let (q, r, b_sign, c_sign, q_sign, case) = run_divrem::( + is_signed, + &record.b.map(u32::from), + &record.c.map(u32::from), + ); + + let carries = run_mul_carries::( + is_signed, + &record.c.map(u32::from), + &q, + &r, + q_sign, + ); for i in 0..NUM_LIMBS { self.range_tuple_chip.add_count(&[q[i], carries[i]]); self.range_tuple_chip @@ -464,94 +518,244 @@ where let b_sign_mask = if b_sign { 1 << (LIMB_BITS - 1) } else { 0 }; let c_sign_mask = if c_sign { 1 << (LIMB_BITS - 1) } else { 0 }; self.bitwise_lookup_chip.request_range( - (b[NUM_LIMBS - 1] - b_sign_mask) << 1, - (c[NUM_LIMBS - 1] - c_sign_mask) << 1, + (record.b[NUM_LIMBS - 1] as u32 - b_sign_mask) << 1, + (record.c[NUM_LIMBS - 1] as u32 - c_sign_mask) << 1, ); } - let c_sum_f = data[1].iter().fold(F::ZERO, |acc, c| acc + *c); - let c_sum_inv_f = c_sum_f.try_inverse().unwrap_or(F::ZERO); + // Write in a reverse order + core_row.opcode_remu_flag = F::from_bool(opcode == DivRemOpcode::REMU); + core_row.opcode_rem_flag = F::from_bool(opcode == DivRemOpcode::REM); + core_row.opcode_divu_flag = F::from_bool(opcode == DivRemOpcode::DIVU); + core_row.opcode_div_flag = F::from_bool(opcode == DivRemOpcode::DIV); - let r_sum_f = r - .iter() - .fold(F::ZERO, |acc, r| acc + F::from_canonical_u32(*r)); - let r_sum_inv_f = r_sum_f.try_inverse().unwrap_or(F::ZERO); - - let (lt_diff_idx, lt_diff_val) = if case == DivRemCoreSpecialCase::None && !r_zero { - let idx = run_sltu_diff_idx(&c, &r_prime, c_sign); + core_row.lt_diff = F::ZERO; + core_row.lt_marker = [F::ZERO; NUM_LIMBS]; + if case == DivRemCoreSpecialCase::None && !r_zero { + let idx = run_sltu_diff_idx(&record.c.map(u32::from), &r_prime, c_sign); let val = if c_sign { - r_prime[idx] - c[idx] + r_prime[idx] - record.c[idx] as u32 } else { - c[idx] - r_prime[idx] + record.c[idx] as u32 - r_prime[idx] }; self.bitwise_lookup_chip.request_range(val - 1, 0); - (idx, val) - } else { - (NUM_LIMBS, 0) - }; + core_row.lt_diff = F::from_canonical_u32(val); + core_row.lt_marker[idx] = F::ONE; + } let r_prime_f = r_prime.map(F::from_canonical_u32); - let output = AdapterRuntimeContext::without_pc([ - (if is_div { &q } else { &r }).map(F::from_canonical_u32) - ]); - let record = DivRemCoreRecord { - opcode: divrem_opcode, - b: data[0], - c: data[1], - q: q.map(F::from_canonical_u32), - r: r.map(F::from_canonical_u32), - zero_divisor: F::from_bool(case == DivRemCoreSpecialCase::ZeroDivisor), - r_zero: F::from_bool(r_zero), - b_sign: F::from_bool(b_sign), - c_sign: F::from_bool(c_sign), - q_sign: F::from_bool(q_sign), - sign_xor: F::from_bool(sign_xor), - c_sum_inv: c_sum_inv_f, - r_sum_inv: r_sum_inv_f, - r_prime: r_prime_f, - r_inv: r_prime_f.map(|r| (r - F::from_canonical_u32(256)).inverse()), - lt_diff_val: F::from_canonical_u32(lt_diff_val), - lt_diff_idx, + core_row.r_inv = r_prime_f.map(|r| (r - F::from_canonical_u32(256)).inverse()); + core_row.r_prime = r_prime_f; + + let r_sum_f = r + .iter() + .fold(F::ZERO, |acc, r| acc + F::from_canonical_u32(*r)); + core_row.r_sum_inv = r_sum_f.try_inverse().unwrap_or(F::ZERO); + + let c_sum_f = F::from_canonical_u32(record.c.iter().fold(0, |acc, c| acc + *c as u32)); + core_row.c_sum_inv = c_sum_f.try_inverse().unwrap_or(F::ZERO); + + core_row.sign_xor = F::from_bool(sign_xor); + core_row.q_sign = F::from_bool(q_sign); + core_row.c_sign = F::from_bool(c_sign); + core_row.b_sign = F::from_bool(b_sign); + + core_row.r_zero = F::from_bool(r_zero); + core_row.zero_divisor = F::from_bool(case == DivRemCoreSpecialCase::ZeroDivisor); + + core_row.r = r.map(F::from_canonical_u32); + core_row.q = q.map(F::from_canonical_u32); + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct DivRemPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl StepExecutorE1 + for DivRemStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let data: &mut DivRemPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = match local_opcode { + DivRemOpcode::DIV => execute_e1_impl::<_, _, DivOp>, + DivRemOpcode::DIVU => execute_e1_impl::<_, _, DivuOp>, + DivRemOpcode::REM => execute_e1_impl::<_, _, RemOp>, + DivRemOpcode::REMU => execute_e1_impl::<_, _, RemuOp>, }; + Ok(fn_ptr) + } +} - Ok((output, record)) +impl StepExecutorE2 + for DivRemStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", DivRemOpcode::from_usize(opcode - self.air.offset)) + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = match local_opcode { + DivRemOpcode::DIV => execute_e2_impl::<_, _, DivOp>, + DivRemOpcode::DIVU => execute_e2_impl::<_, _, DivuOp>, + DivRemOpcode::REM => execute_e2_impl::<_, _, RemOp>, + DivRemOpcode::REMU => execute_e2_impl::<_, _, RemuOp>, + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut DivRemCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.q = record.q; - row_slice.r = record.r; - row_slice.zero_divisor = record.zero_divisor; - row_slice.r_zero = record.r_zero; - row_slice.b_sign = record.b_sign; - row_slice.c_sign = record.c_sign; - row_slice.q_sign = record.q_sign; - row_slice.sign_xor = record.sign_xor; - row_slice.c_sum_inv = record.c_sum_inv; - row_slice.r_sum_inv = record.r_sum_inv; - row_slice.r_prime = record.r_prime; - row_slice.r_inv = record.r_inv; - row_slice.lt_marker = array::from_fn(|i| F::from_bool(i == record.lt_diff_idx)); - row_slice.lt_diff = record.lt_diff_val; - row_slice.opcode_div_flag = F::from_bool(record.opcode == DivRemOpcode::DIV); - row_slice.opcode_divu_flag = F::from_bool(record.opcode == DivRemOpcode::DIVU); - row_slice.opcode_rem_flag = F::from_bool(record.opcode == DivRemOpcode::REM); - row_slice.opcode_remu_flag = F::from_bool(record.opcode == DivRemOpcode::REMU); +unsafe fn execute_e12_impl( + pre_compute: &DivRemPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c as u32); + let result = ::compute(rs1, rs2); + vm_state.vm_write::(RV32_REGISTER_AS, pre_compute.a as u32, &result); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &DivRemPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl DivRemStep { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut DivRemPreCompute, + ) -> Result { + let &Instruction { + opcode, a, b, c, d, .. + } = inst; + let local_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + if d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(InvalidInstruction(pc)); + } + let pre_compute: &mut DivRemPreCompute = data.borrow_mut(); + *pre_compute = DivRemPreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + Ok(local_opcode) + } +} + +trait DivRemOp { + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4]; +} +struct DivOp; +struct DivuOp; +struct RemOp; +struct RemuOp; +impl DivRemOp for DivOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1_i32 = i32::from_le_bytes(rs1); + let rs2_i32 = i32::from_le_bytes(rs2); + match (rs1_i32, rs2_i32) { + (_, 0) => [u8::MAX; 4], + (i32::MIN, -1) => rs1, + _ => (rs1_i32 / rs2_i32).to_le_bytes(), + } + } +} +impl DivRemOp for DivuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + if rs2 == [0; 4] { + [u8::MAX; 4] + } else { + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + (rs1 / rs2).to_le_bytes() + } } +} - fn air(&self) -> &Self::Air { - &self.air +impl DivRemOp for RemOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1_i32 = i32::from_le_bytes(rs1); + let rs2_i32 = i32::from_le_bytes(rs2); + match (rs1_i32, rs2_i32) { + (_, 0) => rs1, + (i32::MIN, -1) => [0; 4], + _ => (rs1_i32 % rs2_i32).to_le_bytes(), + } + } +} + +impl DivRemOp for RemuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + if rs2 == [0; 4] { + rs1 + } else { + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + (rs1 % rs2).to_le_bytes() + } } } // Returns (quotient, remainder, x_sign, y_sign, q_sign, case) where case = 0 for normal, 1 // for zero divisor, and 2 for signed overflow +#[inline(always)] pub(super) fn run_divrem( signed: bool, x: &[u32; NUM_LIMBS], @@ -628,6 +832,7 @@ pub(super) fn run_divrem( (q, r, x_sign, y_sign, q_sign, DivRemCoreSpecialCase::None) } +#[inline(always)] pub(super) fn run_sltu_diff_idx( x: &[u32; NUM_LIMBS], y: &[u32; NUM_LIMBS], @@ -644,6 +849,7 @@ pub(super) fn run_sltu_diff_idx( } // returns carries of d * q + r +#[inline(always)] pub(super) fn run_mul_carries( signed: bool, d: &[u32; NUM_LIMBS], @@ -684,6 +890,7 @@ pub(super) fn run_mul_carries( carry } +#[inline(always)] fn limbs_to_biguint( x: &[u32; NUM_LIMBS], ) -> BigUint { @@ -696,6 +903,7 @@ fn limbs_to_biguint( res } +#[inline(always)] fn biguint_to_limbs( x: &BigUint, ) -> [u32; NUM_LIMBS] { @@ -711,6 +919,7 @@ fn biguint_to_limbs( res } +#[inline(always)] fn negate( x: &[u32; NUM_LIMBS], ) -> [u32; NUM_LIMBS] { diff --git a/extensions/rv32im/circuit/src/divrem/mod.rs b/extensions/rv32im/circuit/src/divrem/mod.rs index 979ab38dc3..a5a2b7607c 100644 --- a/extensions/rv32im/circuit/src/divrem/mod.rs +++ b/extensions/rv32im/circuit/src/divrem/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep}; mod core; pub use core::*; @@ -8,8 +9,8 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32DivRemChip = VmChipWrapper< - F, - Rv32MultAdapterChip, - DivRemCoreChip, ->; +pub type Rv32DivRemAir = + VmAirWrapper>; +pub type Rv32DivRemStep = DivRemStep; +pub type Rv32DivRemChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/divrem/tests.rs b/extensions/rv32im/circuit/src/divrem/tests.rs index 41d8a9cc46..83b824c3ea 100644 --- a/extensions/rv32im/circuit/src/divrem/tests.rs +++ b/extensions/rv32im/circuit/src/divrem/tests.rs @@ -3,10 +3,9 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ testing::{ - memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, - RANGE_TUPLE_CHECKER_BUS, + memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS, }, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, + InstructionExecutor, VmAirWrapper, }, utils::generate_long_number, }; @@ -15,7 +14,7 @@ use openvm_circuit_primitives::{ range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::DivRemOpcode; +use openvm_rv32im_transpiler::DivRemOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra}, @@ -24,29 +23,26 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::core::run_divrem; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, divrem::{ - run_mul_carries, run_sltu_diff_idx, DivRemCoreChip, DivRemCoreCols, DivRemCoreSpecialCase, + run_mul_carries, run_sltu_diff_idx, DivRemCoreCols, DivRemCoreSpecialCase, DivRemStep, Rv32DivRemChip, }, + test_utils::get_verification_error, + DivRemCoreAir, }; type F = BabyBear; - -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; fn limb_sra( x: [u32; NUM_LIMBS], @@ -57,15 +53,58 @@ fn limb_sra( array::from_fn(|i| if i + shift < NUM_LIMBS { x[i] } else { ext }) } +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32DivRemChip, + SharedBitwiseOperationLookupChip, + SharedRangeTupleCheckerChip<2>, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], + ); + + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); + + let mut chip = Rv32DivRemChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + DivRemCoreAir::new(bitwise_bus, range_tuple_bus, DivRemOpcode::CLASS_OFFSET), + ), + DivRemStep::new( + Rv32MultAdapterStep::new(), + bitwise_chip.clone(), + range_tuple_chip.clone(), + DivRemOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip, range_tuple_chip) +} + #[allow(clippy::too_many_arguments)] -fn run_rv32_divrem_rand_write_execute>( - opcode: DivRemOpcode, +fn set_and_execute>( tester: &mut VmChipTestBuilder, chip: &mut E, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], rng: &mut StdRng, + opcode: DivRemOpcode, + b: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, ) { + let b = b.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let c = c.unwrap_or(limb_sra::( + generate_long_number::(rng), + rng.gen_range(0..(RV32_REGISTER_NUM_LIMBS - 1)), + )); + let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); let rd = gen_pointer(rng, 4); @@ -73,8 +112,8 @@ fn run_rv32_divrem_rand_write_execute>( tester.write::(1, rs1, b.map(F::from_canonical_u32)); tester.write::(1, rs2, c.map(F::from_canonical_u32)); - let is_div = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::DIVU; - let is_signed = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::REM; + let is_div = opcode == DIV || opcode == DIVU; + let is_signed = opcode == DIV || opcode == REM; let (q, r, _, _, _, _) = run_divrem::(is_signed, &b, &c); @@ -89,136 +128,101 @@ fn run_rv32_divrem_rand_write_execute>( ); } -fn run_rv32_divrem_rand_test(opcode: DivRemOpcode, num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(DIV, 100)] +#[test_case(DIVU, 100)] +#[test_case(REM, 100)] +#[test_case(REMU, 100)] +fn rand_divrem_test(opcode: DivRemOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32DivRemChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - DivRemCoreChip::new( - bitwise_chip.clone(), - range_tuple_checker.clone(), - DivRemOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip, range_tuple_chip) = create_test_chip(&mut tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let leading_zeros = rng.gen_range(0..(RV32_REGISTER_NUM_LIMBS - 1)); - let c = limb_sra::( - generate_long_number::(&mut rng), - leading_zeros, - ); - run_rv32_divrem_rand_write_execute(opcode, &mut tester, &mut chip, b, c, &mut rng); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None); } // Test special cases in addition to random cases (i.e. zero divisor with b > 0, // zero divisor with b < 0, r = 0 (3 cases), and signed overflow). - run_rv32_divrem_rand_write_execute( - opcode, + set_and_execute( &mut tester, &mut chip, - [98, 188, 163, 127], - [0, 0, 0, 0], &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([98, 188, 163, 127]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [98, 188, 163, 229], - [0, 0, 0, 0], &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([98, 188, 163, 229]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [0, 0, 0, 128], - [0, 1, 0, 0], &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([0, 0, 0, 128]), + Some([0, 1, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [0, 0, 0, 127], - [0, 1, 0, 0], &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([0, 0, 0, 127]), + Some([0, 1, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [0, 0, 0, 0], - [0, 0, 0, 0], &mut rng, + opcode, + Some([0, 0, 0, 0]), + Some([0, 0, 0, 0]), ); - run_rv32_divrem_rand_write_execute( + set_and_execute( + &mut tester, + &mut chip, + &mut rng, opcode, + Some([0, 0, 0, 0]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [0, 0, 0, 128], - [255, 255, 255, 255], &mut rng, + opcode, + Some([0, 0, 0, 128]), + Some([255, 255, 255, 255]), ); let tester = tester .build() .load(chip) .load(bitwise_chip) - .load(range_tuple_checker) + .load(range_tuple_chip) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_div_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::DIV, 100); -} - -#[test] -fn rv32_divu_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::DIVU, 100); -} - -#[test] -fn rv32_rem_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::REM, 100); -} - -#[test] -fn rv32_remu_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::REMU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32DivRemTestChip = - VmChipWrapper, DivRemCoreChip>; - #[derive(Default, Clone, Copy)] struct DivRemPrankValues { pub q: Option<[u32; NUM_LIMBS]>, @@ -229,84 +233,20 @@ struct DivRemPrankValues { pub r_zero: Option, } -fn run_rv32_divrem_negative_test( - signed: bool, +fn run_negative_divrem_test( + opcode: DivRemOpcode, b: [u32; RV32_REGISTER_NUM_LIMBS], c: [u32; RV32_REGISTER_NUM_LIMBS], - prank_vals: &DivRemPrankValues, + prank_vals: DivRemPrankValues, interaction_error: bool, ) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32DivRemTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat(); 2], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - DivRemCoreChip::new( - bitwise_chip.clone(), - range_tuple_chip.clone(), - DivRemOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); - - let (div_opcode, rem_opcode) = if signed { - (DivRemOpcode::DIV, DivRemOpcode::REM) - } else { - (DivRemOpcode::DIVU, DivRemOpcode::REMU) - }; - tester.execute( - &mut chip, - &Instruction::from_usize(div_opcode.global_opcode(), [0, 0, 0, 1, 1]), - ); - tester.execute( - &mut chip, - &Instruction::from_usize(rem_opcode.global_opcode(), [0, 0, 0, 1, 1]), - ); - - let (q, r, b_sign, c_sign, q_sign, case) = - run_divrem::(signed, &b, &c); - let q = prank_vals.q.unwrap_or(q); - let r = prank_vals.r.unwrap_or(r); - let carries = - run_mul_carries::(signed, &c, &q, &r, q_sign); - - range_tuple_chip.clear(); - for i in 0..RV32_REGISTER_NUM_LIMBS { - range_tuple_chip.add_count(&[q[i], carries[i]]); - range_tuple_chip.add_count(&[r[i], carries[i + RV32_REGISTER_NUM_LIMBS]]); - } - - if let Some(diff_val) = prank_vals.diff_val { - bitwise_chip.clear(); - if signed { - let b_sign_mask = if b_sign { 1 << (RV32_CELL_BITS - 1) } else { 0 }; - let c_sign_mask = if c_sign { 1 << (RV32_CELL_BITS - 1) } else { 0 }; - bitwise_chip.request_range( - (b[RV32_REGISTER_NUM_LIMBS - 1] - b_sign_mask) << 1, - (c[RV32_REGISTER_NUM_LIMBS - 1] - c_sign_mask) << 1, - ); - } - if case == DivRemCoreSpecialCase::None { - bitwise_chip.request_range(diff_val - 1, 0); - } - } + let (mut chip, bitwise_chip, range_tuple_chip) = create_test_chip(&mut tester); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, Some(b), Some(c)); + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut DivRemCoreCols = @@ -338,7 +278,7 @@ fn run_rv32_divrem_negative_test( cols.r_zero = F::from_bool(r_zero); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -348,11 +288,7 @@ fn run_rv32_divrem_negative_test( .load(bitwise_chip) .load(range_tuple_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -363,7 +299,8 @@ fn rv32_divrem_unsigned_wrong_q_negative_test() { q: Some([245, 168, 7, 0]), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -376,7 +313,8 @@ fn rv32_divrem_unsigned_wrong_r_negative_test() { diff_val: Some(31), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -387,7 +325,8 @@ fn rv32_divrem_unsigned_high_mult_negative_test() { q: Some([128, 0, 0, 1]), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -400,7 +339,8 @@ fn rv32_divrem_unsigned_zero_divisor_wrong_r_negative_test() { diff_val: Some(255), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -411,7 +351,8 @@ fn rv32_divrem_signed_wrong_q_negative_test() { q: Some([74, 61, 255, 255]), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -424,7 +365,8 @@ fn rv32_divrem_signed_wrong_r_negative_test() { diff_val: Some(20), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -435,7 +377,8 @@ fn rv32_divrem_signed_high_mult_negative_test() { q: Some([1, 0, 0, 1]), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -449,7 +392,8 @@ fn rv32_divrem_signed_r_wrong_sign_negative_test() { diff_val: Some(192), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -463,7 +407,8 @@ fn rv32_divrem_signed_r_wrong_prime_negative_test() { diff_val: Some(36), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -476,7 +421,8 @@ fn rv32_divrem_signed_zero_divisor_wrong_r_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -491,8 +437,10 @@ fn rv32_divrem_false_zero_divisor_flag_negative_test() { zero_divisor: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -507,8 +455,10 @@ fn rv32_divrem_false_r_zero_flag_negative_test() { r_zero: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -519,8 +469,10 @@ fn rv32_divrem_unset_zero_divisor_flag_negative_test() { zero_divisor: Some(false), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -532,8 +484,10 @@ fn rv32_divrem_wrong_r_zero_flag_negative_test() { r_zero: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -544,8 +498,10 @@ fn rv32_divrem_unset_r_zero_flag_negative_test() { r_zero: Some(false), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } /////////////////////////////////////////////////////////////////////////////////////// diff --git a/extensions/rv32im/circuit/src/extension.rs b/extensions/rv32im/circuit/src/extension.rs index f8dd2fbf54..4993a87762 100644 --- a/extensions/rv32im/circuit/src/extension.rs +++ b/extensions/rv32im/circuit/src/extension.rs @@ -1,12 +1,12 @@ use derive_more::derive::From; use openvm_circuit::{ arch::{ - InitFileGenerator, SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, - VmInventoryError, + ExecutionBridge, InitFileGenerator, SystemConfig, SystemPort, VmAirWrapper, VmExtension, + VmInventory, VmInventoryBuilder, VmInventoryError, }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, @@ -134,7 +134,9 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { // ============ Executor and Periphery Enums for Extension ============ /// RISC-V 32-bit Base (RV32I) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, InsExecutorE2, From, AnyEnum, +)] pub enum Rv32IExecutor { // Rv32 (for standard 32-bit integers): BaseAlu(Rv32BaseAluChip), @@ -150,7 +152,9 @@ pub enum Rv32IExecutor { } /// RISC-V 32-bit Multiplication Extension (RV32M) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, InsExecutorE2, From, AnyEnum, +)] pub enum Rv32MExecutor { Multiplication(Rv32MultiplicationChip), MultiplicationHigh(Rv32MulHChip), @@ -158,7 +162,9 @@ pub enum Rv32MExecutor { } /// RISC-V 32-bit Io Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, InsExecutorE2, From, AnyEnum, +)] pub enum Rv32IoExecutor { HintStore(Rv32HintStoreChip), } @@ -204,7 +210,6 @@ impl VmExtension for Rv32I { } = builder.system_port(); let range_checker = builder.system_base().range_checker_chip.clone(); - let offline_memory = builder.system_base().offline_memory(); let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; let bitwise_lu_chip = if let Some(&chip) = builder @@ -220,14 +225,20 @@ impl VmExtension for Rv32I { }; let base_alu_chip = Rv32BaseAluChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + ), + BaseAluCoreAir::new(bitwise_lu_chip.bus(), BaseAluOpcode::CLASS_OFFSET), + ), + Rv32BaseAluStep::new( + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + BaseAluOpcode::CLASS_OFFSET, ), - BaseAluCoreChip::new(bitwise_lu_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( base_alu_chip, @@ -235,43 +246,61 @@ impl VmExtension for Rv32I { )?; let lt_chip = Rv32LessThanChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + ), + LessThanCoreAir::new(bitwise_lu_chip.bus(), LessThanOpcode::CLASS_OFFSET), + ), + LessThanStep::new( + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + LessThanOpcode::CLASS_OFFSET, ), - LessThanCoreChip::new(bitwise_lu_chip.clone(), LessThanOpcode::CLASS_OFFSET), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor(lt_chip, LessThanOpcode::iter().map(|x| x.global_opcode()))?; let shift_chip = Rv32ShiftChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + ), + ShiftCoreAir::new( + bitwise_lu_chip.bus(), + range_checker.bus(), + ShiftOpcode::CLASS_OFFSET, + ), ), - ShiftCoreChip::new( + ShiftStep::new( + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), range_checker.clone(), ShiftOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor(shift_chip, ShiftOpcode::iter().map(|x| x.global_opcode()))?; let load_store_chip = Rv32LoadStoreChip::new( - Rv32LoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - pointer_max_bits, - range_checker.clone(), + VmAirWrapper::new( + Rv32LoadStoreAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + range_checker.bus(), + pointer_max_bits, + ), + LoadStoreCoreAir::new(Rv32LoadStoreOpcode::CLASS_OFFSET), + ), + LoadStoreStep::new( + Rv32LoadStoreAdapterStep::new(pointer_max_bits, range_checker.clone()), + Rv32LoadStoreOpcode::CLASS_OFFSET, ), - LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( load_store_chip, @@ -281,15 +310,20 @@ impl VmExtension for Rv32I { )?; let load_sign_extend_chip = Rv32LoadSignExtendChip::new( - Rv32LoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - pointer_max_bits, + VmAirWrapper::new( + Rv32LoadStoreAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + range_checker.bus(), + pointer_max_bits, + ), + LoadSignExtendCoreAir::new(range_checker.bus()), + ), + LoadSignExtendStep::new( + Rv32LoadStoreAdapterStep::new(pointer_max_bits, range_checker.clone()), range_checker.clone(), ), - LoadSignExtendCoreChip::new(range_checker.clone()), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( load_sign_extend_chip, @@ -297,49 +331,94 @@ impl VmExtension for Rv32I { )?; let beq_chip = Rv32BranchEqualChip::new( - Rv32BranchAdapterChip::new(execution_bus, program_bus, memory_bridge), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + VmAirWrapper::new( + Rv32BranchAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + BranchEqualCoreAir::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ), + BranchEqualStep::new( + Rv32BranchAdapterStep::new(), + BranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( beq_chip, BranchEqualOpcode::iter().map(|x| x.global_opcode()), )?; - let blt_chip = Rv32BranchLessThanChip::new( - Rv32BranchAdapterChip::new(execution_bus, program_bus, memory_bridge), - BranchLessThanCoreChip::new( + let blt_chip = Rv32BranchLessThanChip::::new( + VmAirWrapper::new( + Rv32BranchAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + BranchLessThanCoreAir::new( + bitwise_lu_chip.bus(), + BranchLessThanOpcode::CLASS_OFFSET, + ), + ), + BranchLessThanStep::new( + Rv32BranchAdapterStep::new(), bitwise_lu_chip.clone(), BranchLessThanOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( blt_chip, BranchLessThanOpcode::iter().map(|x| x.global_opcode()), )?; - let jal_lui_chip = Rv32JalLuiChip::new( - Rv32CondRdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32JalLuiCoreChip::new(bitwise_lu_chip.clone()), - offline_memory.clone(), + let jal_lui_chip = Rv32JalLuiChip::::new( + VmAirWrapper::new( + Rv32CondRdWriteAdapterAir::new(Rv32RdWriteAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + )), + Rv32JalLuiCoreAir::new(bitwise_lu_chip.bus()), + ), + Rv32JalLuiStep::new( + Rv32CondRdWriteAdapterStep::new(Rv32RdWriteAdapterStep::new()), + bitwise_lu_chip.clone(), + ), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( jal_lui_chip, Rv32JalLuiOpcode::iter().map(|x| x.global_opcode()), )?; - let jalr_chip = Rv32JalrChip::new( - Rv32JalrAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32JalrCoreChip::new(bitwise_lu_chip.clone(), range_checker.clone()), - offline_memory.clone(), + let jalr_chip = Rv32JalrChip::::new( + VmAirWrapper::new( + Rv32JalrAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + ), + Rv32JalrCoreAir::new(bitwise_lu_chip.bus(), range_checker.bus()), + ), + Rv32JalrStep::new( + Rv32JalrAdapterStep::new(), + bitwise_lu_chip.clone(), + range_checker.clone(), + ), + builder.system_base().memory_controller.helper(), ); inventory.add_executor(jalr_chip, Rv32JalrOpcode::iter().map(|x| x.global_opcode()))?; - let auipc_chip = Rv32AuipcChip::new( - Rv32RdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32AuipcCoreChip::new(bitwise_lu_chip.clone()), - offline_memory.clone(), + let auipc_chip = Rv32AuipcChip::::new( + VmAirWrapper::new( + Rv32RdWriteAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + ), + Rv32AuipcCoreAir::new(bitwise_lu_chip.bus()), + ), + Rv32AuipcStep::new(Rv32RdWriteAdapterStep::new(), bitwise_lu_chip.clone()), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( auipc_chip, @@ -352,7 +431,7 @@ impl VmExtension for Rv32I { PhantomDiscriminant(Rv32Phantom::HintInput as u16), )?; builder.add_phantom_sub_executor( - phantom::Rv32HintRandomSubEx::new(), + phantom::Rv32HintRandomSubEx, PhantomDiscriminant(Rv32Phantom::HintRandom as u16), )?; builder.add_phantom_sub_executor( @@ -382,7 +461,6 @@ impl VmExtension for Rv32M { program_bus, memory_bridge, } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() @@ -412,28 +490,59 @@ impl VmExtension for Rv32M { chip }; - let mul_chip = Rv32MultiplicationChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - MultiplicationCoreChip::new(range_tuple_checker.clone(), MulOpcode::CLASS_OFFSET), - offline_memory.clone(), + let mul_chip = Rv32MultiplicationChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + MultiplicationCoreAir::new(*range_tuple_checker.bus(), MulOpcode::CLASS_OFFSET), + ), + MultiplicationStep::new( + Rv32MultAdapterStep::new(), + range_tuple_checker.clone(), + MulOpcode::CLASS_OFFSET, + ), + builder.system_base().memory_controller.helper(), ); inventory.add_executor(mul_chip, MulOpcode::iter().map(|x| x.global_opcode()))?; - let mul_h_chip = Rv32MulHChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - MulHCoreChip::new(bitwise_lu_chip.clone(), range_tuple_checker.clone()), - offline_memory.clone(), + let mul_h_chip = Rv32MulHChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + MulHCoreAir::new(bitwise_lu_chip.bus(), *range_tuple_checker.bus()), + ), + MulHStep::new( + Rv32MultAdapterStep::new(), + bitwise_lu_chip.clone(), + range_tuple_checker.clone(), + ), + builder.system_base().memory_controller.helper(), ); inventory.add_executor(mul_h_chip, MulHOpcode::iter().map(|x| x.global_opcode()))?; - let div_rem_chip = Rv32DivRemChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - DivRemCoreChip::new( + let div_rem_chip = Rv32DivRemChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + DivRemCoreAir::new( + bitwise_lu_chip.bus(), + *range_tuple_checker.bus(), + DivRemOpcode::CLASS_OFFSET, + ), + ), + DivRemStep::new( + Rv32MultAdapterStep::new(), bitwise_lu_chip.clone(), range_tuple_checker.clone(), DivRemOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + builder.system_base().memory_controller.helper(), ); inventory.add_executor( div_rem_chip, @@ -458,7 +567,6 @@ impl VmExtension for Rv32Io { program_bus, memory_bridge, } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() @@ -472,16 +580,21 @@ impl VmExtension for Rv32Io { chip }; - let mut hintstore_chip = Rv32HintStoreChip::new( - execution_bus, - program_bus, - bitwise_lu_chip.clone(), - memory_bridge, - offline_memory.clone(), - builder.system_config().memory_config.pointer_max_bits, - Rv32HintStoreOpcode::CLASS_OFFSET, + let hintstore_chip = Rv32HintStoreChip::::new( + Rv32HintStoreAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + Rv32HintStoreOpcode::CLASS_OFFSET, + builder.system_config().memory_config.pointer_max_bits, + ), + Rv32HintStoreStep::new( + bitwise_lu_chip, + builder.system_config().memory_config.pointer_max_bits, + Rv32HintStoreOpcode::CLASS_OFFSET, + ), + builder.system_base().memory_controller.helper(), ); - hintstore_chip.set_streams(builder.streams().clone()); inventory.add_executor( hintstore_chip, @@ -497,34 +610,28 @@ mod phantom { use eyre::bail; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_instructions::PhantomDiscriminant; use openvm_stark_backend::p3_field::{Field, PrimeField32}; - use rand::{rngs::OsRng, Rng}; + use rand::{rngs::StdRng, Rng}; - use crate::adapters::unsafe_read_rv32_register; + use crate::adapters::{memory_read, read_rv32_register}; pub struct Rv32HintInputSubEx; - pub struct Rv32HintRandomSubEx { - rng: OsRng, - } - impl Rv32HintRandomSubEx { - pub fn new() -> Self { - Self { rng: OsRng } - } - } + pub struct Rv32HintRandomSubEx; pub struct Rv32PrintStrSubEx; pub struct Rv32HintLoadByKeySubEx; impl PhantomSubExecutor for Rv32HintInputSubEx { fn phantom_execute( - &mut self, - _: &MemoryController, + &self, + _: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let mut hint = match streams.input_stream.pop_front() { @@ -550,18 +657,19 @@ mod phantom { impl PhantomSubExecutor for Rv32HintRandomSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + rng: &mut StdRng, _: PhantomDiscriminant, - a: F, - _: F, + a: u32, + _: u32, _: u16, ) -> eyre::Result<()> { - let len = unsafe_read_rv32_register(memory, a) as usize; + let len = read_rv32_register(memory, a) as usize; streams.hint_stream.clear(); streams.hint_stream.extend( - std::iter::repeat_with(|| F::from_canonical_u8(self.rng.gen::())).take(len * 4), + std::iter::repeat_with(|| F::from_canonical_u8(rng.gen::())).take(len * 4), ); Ok(()) } @@ -569,23 +677,20 @@ mod phantom { impl PhantomSubExecutor for Rv32PrintStrSubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, _: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, _: u16, ) -> eyre::Result<()> { - let rd = unsafe_read_rv32_register(memory, a); - let rs1 = unsafe_read_rv32_register(memory, b); + let rd = read_rv32_register(memory, a); + let rs1 = read_rv32_register(memory, b); let bytes = (0..rs1) - .map(|i| -> eyre::Result { - let val = memory.unsafe_read_cell(F::TWO, F::from_canonical_u32(rd + i)); - let byte: u8 = val.as_canonical_u32().try_into()?; - Ok(byte) - }) - .collect::>>()?; + .map(|i| memory_read::<1>(memory, 2, rd + i)[0]) + .collect::>(); let peeked_str = String::from_utf8(bytes)?; print!("{peeked_str}"); Ok(()) @@ -594,22 +699,19 @@ mod phantom { impl PhantomSubExecutor for Rv32HintLoadByKeySubEx { fn phantom_execute( - &mut self, - memory: &MemoryController, + &self, + memory: &GuestMemory, streams: &mut Streams, + _: &mut StdRng, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, _: u16, ) -> eyre::Result<()> { - let ptr = unsafe_read_rv32_register(memory, a); - let len = unsafe_read_rv32_register(memory, b); + let ptr = read_rv32_register(memory, a); + let len = read_rv32_register(memory, b); let key: Vec = (0..len) - .map(|i| { - memory - .unsafe_read_cell(F::TWO, F::from_canonical_u32(ptr + i)) - .as_canonical_u32() as u8 - }) + .map(|i| memory_read::<1>(memory, 2, ptr + i)[0]) .collect(); if let Some(val) = streams.kv_store.get(&key) { let to_push = hint_load_by_key_decode::(val); diff --git a/extensions/rv32im/circuit/src/hintstore/mod.rs b/extensions/rv32im/circuit/src/hintstore/mod.rs index d566292207..3872ea8dd2 100644 --- a/extensions/rv32im/circuit/src/hintstore/mod.rs +++ b/extensions/rv32im/circuit/src/hintstore/mod.rs @@ -1,25 +1,29 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - sync::{Arc, Mutex, OnceLock}, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, Streams, + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, CustomBorrow, E2PreCompute, ExecuteFunc, ExecutionBridge, + ExecutionError, + ExecutionError::InvalidInstruction, + ExecutionState, MatrixRecordArena, MultiRowLayout, MultiRowMetadata, NewVmChipWrapper, + RecordArena, Result, SizedRecord, StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, + VmSegmentState, VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, + system::memory::{ + offline_checker::{ + MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols, + MemoryWriteBytesAuxRecord, }, - program::ProgramBus, + online::TracingMemory, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - utils::{next_power_of_two_or_zero, not}, + utils::not, }; -use openvm_circuit_primitives_derive::AlignedBorrow; +use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow}; use openvm_instructions::{ instruction::Instruction, program::DEFAULT_PC_STEP, @@ -31,18 +35,15 @@ use openvm_rv32im_transpiler::{ Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW}, }; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, - rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + p3_maybe_rayon::prelude::*, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use serde::{Deserialize, Serialize}; -use crate::adapters::{compose, decompose}; +use crate::adapters::{read_rv32_register, tracing_read, tracing_write}; #[cfg(test)] mod tests; @@ -70,7 +71,7 @@ pub struct Rv32HintStoreCols { pub num_words_aux_cols: MemoryReadAuxCols, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct Rv32HintStoreAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, @@ -182,7 +183,6 @@ impl Air for Rv32HintStoreAir { &local_cols.write_aux, ) .eval(builder, is_valid.clone()); - let expected_opcode = (local_cols.is_single * AB::F::from_canonical_usize(HINT_STOREW as usize + self.offset)) + (local_cols.is_buffer @@ -264,265 +264,498 @@ impl Air for Rv32HintStoreAir { } } -#[derive(Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32HintStoreRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - pub mem_ptr_read: RecordId, - pub mem_ptr: u32, +#[derive(Copy, Clone, Debug)] +pub struct Rv32HintStoreMetadata { + num_words: usize, +} + +impl MultiRowMetadata for Rv32HintStoreMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_words + } +} + +pub type Rv32HintStoreLayout = MultiRowLayout; + +// This is the part of the record that we keep only once per instruction +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32HintStoreRecordHeader { pub num_words: u32, - pub num_words_read: Option, - pub hints: Vec<([F; RV32_REGISTER_NUM_LIMBS], RecordId)>, + pub from_pc: u32, + pub timestamp: u32, + + pub mem_ptr_ptr: u32, + pub mem_ptr: u32, + pub mem_ptr_aux_record: MemoryReadAuxRecord, + + // will set `num_words_ptr` to `u32::MAX` in case of single hint + pub num_words_ptr: u32, + pub num_words_read: MemoryReadAuxRecord, +} + +// This is the part of the record that we keep `num_words` times per instruction +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32HintStoreVar { + pub data_write_aux: MemoryWriteBytesAuxRecord, + pub data: [u8; RV32_REGISTER_NUM_LIMBS], +} + +/// **SAFETY**: the order of the fields in `Rv32HintStoreRecord` and `Rv32HintStoreVar` is +/// important. The chip also assumes that the offset of the fields `write_aux` and `data` in +/// `Rv32HintStoreCols` is bigger than `size_of::()` +#[derive(Debug)] +pub struct Rv32HintStoreRecordMut<'a> { + pub inner: &'a mut Rv32HintStoreRecordHeader, + pub var: &'a mut [Rv32HintStoreVar], +} + +/// Custom borrowing that splits the buffer into a fixed `Rv32HintStoreRecord` header +/// followed by a slice of `Rv32HintStoreVar`'s of length `num_words` provided at runtime. +/// Uses `align_to_mut()` to make sure the slice is properly aligned to `Rv32HintStoreVar`. +/// Has debug assertions to make sure the above works as expected. +impl<'a> CustomBorrow<'a, Rv32HintStoreRecordMut<'a>, Rv32HintStoreLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: Rv32HintStoreLayout) -> Rv32HintStoreRecordMut<'a> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + + let (_, vars, _) = unsafe { rest.align_to_mut::() }; + Rv32HintStoreRecordMut { + inner: header_buf.borrow_mut(), + var: &mut vars[..layout.metadata.num_words], + } + } + + unsafe fn extract_layout(&self) -> Rv32HintStoreLayout { + let header: &Rv32HintStoreRecordHeader = self.borrow(); + MultiRowLayout::new(Rv32HintStoreMetadata { + num_words: header.num_words as usize, + }) + } +} + +impl SizedRecord for Rv32HintStoreRecordMut<'_> { + fn size(layout: &Rv32HintStoreLayout) -> usize { + let mut total_len = size_of::(); + // Align the pointer to the alignment of `Rv32HintStoreVar` + total_len = total_len.next_multiple_of(align_of::()); + total_len += size_of::() * layout.metadata.num_words; + total_len + } + + fn alignment(_layout: &Rv32HintStoreLayout) -> usize { + align_of::() + } } -pub struct Rv32HintStoreChip { - air: Rv32HintStoreAir, - pub records: Vec>, - pub height: usize, - offline_memory: Arc>>, - pub streams: OnceLock>>>, +pub struct Rv32HintStoreStep { + pointer_max_bits: usize, + offset: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl Rv32HintStoreChip { +impl Rv32HintStoreStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - memory_bridge: MemoryBridge, - offline_memory: Arc>>, pointer_max_bits: usize, offset: usize, ) -> Self { - let air = Rv32HintStoreAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_operation_lookup_bus: bitwise_lookup_chip.bus(), - offset, - pointer_max_bits, - }; Self { - records: vec![], - air, - height: 0, - offline_memory, - streams: OnceLock::new(), + pointer_max_bits, + offset, bitwise_lookup_chip, } } - pub fn set_streams(&mut self, streams: Arc>>) { - self.streams - .set(streams) - .map_err(|_| "streams have already been set.") - .unwrap(); - } } -impl InstructionExecutor for Rv32HintStoreChip { - fn execute( +impl TraceStep for Rv32HintStoreStep +where + F: PrimeField32, +{ + type RecordLayout = MultiRowLayout; + type RecordMut<'a> = Rv32HintStoreRecordMut<'a>; + + fn get_opcode_name(&self, opcode: usize) -> String { + if opcode == HINT_STOREW.global_opcode().as_usize() { + String::from("HINT_STOREW") + } else if opcode == HINT_BUFFER.global_opcode().as_usize() { + String::from("HINT_BUFFER") + } else { + unreachable!("unsupported opcode: {}", opcode) + } + } + + fn execute<'buf, RA>( &mut self, - memory: &mut MemoryController, + state: VmStateMut, CTX>, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { let &Instruction { - opcode, - a: num_words_ptr, - b: mem_ptr_ptr, - d, - e, - .. + opcode, a, b, d, e, .. } = instruction; + + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let local_opcode = - Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let (mem_ptr_read, mem_ptr_limbs) = memory.read::(d, mem_ptr_ptr); - let (num_words, num_words_read) = if local_opcode == HINT_STOREW { - memory.increment_timestamp(); - (1, None) + let local_opcode = Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + // We do untraced read of `num_words` in order to allocate the record first + let num_words = if local_opcode == HINT_STOREW { + 1 } else { - let (num_words_read, num_words_limbs) = - memory.read::(d, num_words_ptr); - (compose(num_words_limbs), Some(num_words_read)) + read_rv32_register(state.memory.data(), a) }; - debug_assert_ne!(num_words, 0); - debug_assert!(num_words <= (1 << self.air.pointer_max_bits)); - let mem_ptr = compose(mem_ptr_limbs); + let record = arena.alloc(MultiRowLayout::new(Rv32HintStoreMetadata { + num_words: num_words as usize, + })); - debug_assert!(mem_ptr <= (1 << self.air.pointer_max_bits)); + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp; + record.inner.mem_ptr_ptr = b; - let mut streams = self.streams.get().unwrap().lock().unwrap(); - if streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { - return Err(ExecutionError::HintOutOfBounds { pc: from_state.pc }); - } + record.inner.mem_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + b, + &mut record.inner.mem_ptr_aux_record.prev_timestamp, + )); + + debug_assert!(record.inner.mem_ptr <= (1 << self.pointer_max_bits)); + debug_assert_ne!(num_words, 0); + debug_assert!(num_words <= (1 << self.pointer_max_bits)); - let mut record = Rv32HintStoreRecord { - from_state, - instruction: instruction.clone(), - mem_ptr_read, - mem_ptr, - num_words, - num_words_read, - hints: vec![], + record.inner.num_words = num_words; + if local_opcode == HINT_STOREW { + state.memory.increment_timestamp(); + record.inner.num_words_ptr = u32::MAX; + } else { + record.inner.num_words_ptr = a; + tracing_read::<_, RV32_REGISTER_NUM_LIMBS>( + state.memory, + RV32_REGISTER_AS, + record.inner.num_words_ptr, + &mut record.inner.num_words_read.prev_timestamp, + ); }; - for word_index in 0..num_words { - if word_index != 0 { - memory.increment_timestamp(); - memory.increment_timestamp(); + if state.streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { + return Err(ExecutionError::HintOutOfBounds { pc: *state.pc }); + } + + for idx in 0..(num_words as usize) { + if idx != 0 { + state.memory.increment_timestamp(); + state.memory.increment_timestamp(); } - let data: [F; RV32_REGISTER_NUM_LIMBS] = - std::array::from_fn(|_| streams.hint_stream.pop_front().unwrap()); - let (write, _) = memory.write( - e, - F::from_canonical_u32(mem_ptr + (RV32_REGISTER_NUM_LIMBS as u32 * word_index)), + let data_f: [F; RV32_REGISTER_NUM_LIMBS] = + std::array::from_fn(|_| state.streams.hint_stream.pop_front().unwrap()); + let data: [u8; RV32_REGISTER_NUM_LIMBS] = + data_f.map(|byte| byte.as_canonical_u32() as u8); + + record.var[idx].data = data; + + tracing_write( + state.memory, + RV32_MEMORY_AS, + record.inner.mem_ptr + (RV32_REGISTER_NUM_LIMBS * idx) as u32, data, + &mut record.var[idx].data_write_aux.prev_timestamp, + &mut record.var[idx].data_write_aux.prev_data, ); - record.hints.push((data, write)); } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - self.height += record.hints.len(); - self.records.push(record); - - let next_state = ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }; - Ok(next_state) + Ok(()) } +} - fn get_opcode_name(&self, opcode: usize) -> String { - if opcode == HINT_STOREW.global_opcode().as_usize() { - String::from("HINT_STOREW") - } else if opcode == HINT_BUFFER.global_opcode().as_usize() { - String::from("HINT_BUFFER") - } else { - unreachable!("unsupported opcode: {}", opcode) +impl TraceFiller for Rv32HintStoreStep { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; } + + let width = trace.width; + let mut trace = &mut trace.values[..width * rows_used]; + let mut sizes = Vec::with_capacity(rows_used); + let mut chunks = Vec::with_capacity(rows_used); + + while !trace.is_empty() { + let record: &Rv32HintStoreRecordHeader = + unsafe { get_record_from_slice(&mut trace, ()) }; + let (chunk, rest) = trace.split_at_mut(width * record.num_words as usize); + sizes.push(record.num_words); + chunks.push(chunk); + trace = rest; + } + + let msl_rshift: u32 = ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; + let msl_lshift: u32 = + (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits) as u32; + + chunks + .par_iter_mut() + .zip(sizes.par_iter()) + .for_each(|(chunk, &num_words)| { + let record: Rv32HintStoreRecordMut = unsafe { + get_record_from_slice( + chunk, + MultiRowLayout::new(Rv32HintStoreMetadata { + num_words: num_words as usize, + }), + ) + }; + self.bitwise_lookup_chip.request_range( + (record.inner.mem_ptr >> msl_rshift) << msl_lshift, + (num_words >> msl_rshift) << msl_lshift, + ); + + let mut timestamp = record.inner.timestamp + num_words * 3; + let mut mem_ptr = record.inner.mem_ptr + num_words * RV32_REGISTER_NUM_LIMBS as u32; + + // Assuming that `num_words` is usually small (e.g. 1 for `HINT_STOREW`) + // it is better to do a serial pass of the rows per instruction (going from the last + // row to the first row) instead of a parallel pass, since need to + // copy the record to a new buffer in parallel case. + chunk + .rchunks_exact_mut(width) + .zip(record.var.iter().enumerate().rev()) + .for_each(|(row, (idx, var))| { + for pair in var.data.chunks_exact(2) { + self.bitwise_lookup_chip + .request_range(pair[0] as u32, pair[1] as u32); + } + + let cols: &mut Rv32HintStoreCols = row.borrow_mut(); + let is_single = record.inner.num_words_ptr == u32::MAX; + timestamp -= 3; + if idx == 0 && !is_single { + mem_helper.fill( + record.inner.num_words_read.prev_timestamp, + timestamp + 1, + cols.num_words_aux_cols.as_mut(), + ); + cols.num_words_ptr = F::from_canonical_u32(record.inner.num_words_ptr); + } else { + mem_helper.fill_zero(cols.num_words_aux_cols.as_mut()); + cols.num_words_ptr = F::ZERO; + } + + cols.is_buffer_start = F::from_bool(idx == 0 && !is_single); + + // Note: writing in reverse + cols.data = var.data.map(|x| F::from_canonical_u8(x)); + + cols.write_aux.set_prev_data( + var.data_write_aux + .prev_data + .map(|x| F::from_canonical_u8(x)), + ); + mem_helper.fill( + var.data_write_aux.prev_timestamp, + timestamp + 2, + cols.write_aux.as_mut(), + ); + + if idx == 0 { + mem_helper.fill( + record.inner.mem_ptr_aux_record.prev_timestamp, + timestamp, + cols.mem_ptr_aux_cols.as_mut(), + ); + } else { + mem_helper.fill_zero(cols.mem_ptr_aux_cols.as_mut()); + } + + mem_ptr -= RV32_REGISTER_NUM_LIMBS as u32; + cols.mem_ptr_limbs = mem_ptr.to_le_bytes().map(|x| F::from_canonical_u8(x)); + cols.mem_ptr_ptr = F::from_canonical_u32(record.inner.mem_ptr_ptr); + + cols.from_state.timestamp = F::from_canonical_u32(timestamp); + cols.from_state.pc = F::from_canonical_u32(record.inner.from_pc); + + cols.rem_words_limbs = (num_words - idx as u32) + .to_le_bytes() + .map(|x| F::from_canonical_u8(x)); + cols.is_buffer = F::from_bool(!is_single); + cols.is_single = F::from_bool(is_single); + }); + }) } } -impl ChipUsageGetter for Rv32HintStoreChip { - fn air_name(&self) -> String { - "Rv32HintStoreAir".to_string() - } +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct HintStorePreCompute { + c: u32, + a: u8, + b: u8, +} - fn current_trace_height(&self) -> usize { - self.height +impl StepExecutorE1 for Rv32HintStoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn trace_width(&self) -> usize { - Rv32HintStoreCols::::width() + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut HintStorePreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = match local_opcode { + HINT_STOREW => execute_e1_impl::<_, _, true>, + HINT_BUFFER => execute_e1_impl::<_, _, false>, + }; + Ok(fn_ptr) } } -impl Rv32HintStoreChip { - // returns number of used u32s - fn record_to_rows( - record: Rv32HintStoreRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - bitwise_lookup_chip: &SharedBitwiseOperationLookupChip, - pointer_max_bits: usize, - ) -> usize { - let width = Rv32HintStoreCols::::width(); - let cols: &mut Rv32HintStoreCols = slice[..width].borrow_mut(); - - cols.is_single = F::from_bool(record.num_words_read.is_none()); - cols.is_buffer = F::from_bool(record.num_words_read.is_some()); - cols.is_buffer_start = cols.is_buffer; - - cols.from_state = record.from_state.map(F::from_canonical_u32); - cols.mem_ptr_ptr = record.instruction.b; - aux_cols_factory.generate_read_aux( - memory.record_by_id(record.mem_ptr_read), - &mut cols.mem_ptr_aux_cols, - ); +impl StepExecutorE2 for Rv32HintStoreStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } - cols.num_words_ptr = record.instruction.a; - if let Some(num_words_read) = record.num_words_read { - aux_cols_factory.generate_read_aux( - memory.record_by_id(num_words_read), - &mut cols.num_words_aux_cols, - ); - } + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = match local_opcode { + HINT_STOREW => execute_e2_impl::<_, _, true>, + HINT_BUFFER => execute_e2_impl::<_, _, false>, + }; + Ok(fn_ptr) + } +} - let mut mem_ptr = record.mem_ptr; - let mut rem_words = record.num_words; - let mut used_u32s = 0; +/// Return the number of used rows. +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &HintStorePreCompute, + vm_state: &mut VmSegmentState, +) -> u32 { + let mem_ptr_limbs = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let mem_ptr = u32::from_le_bytes(mem_ptr_limbs); + + let num_words = if IS_HINT_STOREW { + 1 + } else { + let num_words_limbs = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.a as u32); + u32::from_le_bytes(num_words_limbs) + }; + debug_assert_ne!(num_words, 0); + + if vm_state.streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { + vm_state.exit_code = Err(ExecutionError::HintOutOfBounds { pc: vm_state.pc }); + return 0; + } - let mem_ptr_msl = mem_ptr >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); - let rem_words_msl = rem_words >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); - bitwise_lookup_chip.request_range( - mem_ptr_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits), - rem_words_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits), + for word_index in 0..num_words { + let data: [u8; RV32_REGISTER_NUM_LIMBS] = std::array::from_fn(|_| { + vm_state + .streams + .hint_stream + .pop_front() + .unwrap() + .as_canonical_u32() as u8 + }); + vm_state.vm_write( + RV32_MEMORY_AS, + mem_ptr + (RV32_REGISTER_NUM_LIMBS as u32 * word_index), + &data, ); - for (i, &(data, write)) in record.hints.iter().enumerate() { - for half in 0..(RV32_REGISTER_NUM_LIMBS / 2) { - bitwise_lookup_chip.request_range( - data[2 * half].as_canonical_u32(), - data[2 * half + 1].as_canonical_u32(), - ); - } - - let cols: &mut Rv32HintStoreCols = slice[used_u32s..used_u32s + width].borrow_mut(); - cols.from_state.timestamp = - F::from_canonical_u32(record.from_state.timestamp + (3 * i as u32)); - cols.data = data; - aux_cols_factory.generate_write_aux(memory.record_by_id(write), &mut cols.write_aux); - cols.rem_words_limbs = decompose(rem_words); - cols.mem_ptr_limbs = decompose(mem_ptr); - if i != 0 { - cols.is_buffer = F::ONE; - } - used_u32s += width; - mem_ptr += RV32_REGISTER_NUM_LIMBS as u32; - rem_words -= 1; - } - - used_u32s } - fn generate_trace(self) -> RowMajorMatrix { - let width = self.trace_width(); - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = F::zero_vec(width * height); + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + num_words +} - let memory = self.offline_memory.lock().unwrap(); +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &HintStorePreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} - let aux_cols_factory = memory.aux_cols_factory(); +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height_delta = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height_delta); +} - let mut used_u32s = 0; - for record in self.records { - used_u32s += Self::record_to_rows( - record, - &aux_cols_factory, - &mut flat_trace[used_u32s..], - &memory, - &self.bitwise_lookup_chip, - self.air.pointer_max_bits, - ); +impl Rv32HintStoreStep { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut HintStorePreCompute, + ) -> Result { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + if d.as_canonical_u32() != RV32_REGISTER_AS || e.as_canonical_u32() != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); } - // padding rows can just be all zeros - RowMajorMatrix::new(flat_trace, width) + *data = { + HintStorePreCompute { + c: c.as_canonical_u32(), + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + } + }; + Ok(Rv32HintStoreOpcode::from_usize( + opcode.local_opcode_idx(self.offset), + )) } } -impl Chip for Rv32HintStoreChip> -where - Val: PrimeField32, -{ - fn air(&self) -> Arc> { - Arc::new(self.air) - } - fn generate_air_proof_input(self) -> AirProofInput { - AirProofInput::simple_no_pis(self.generate_trace()) - } -} +pub type Rv32HintStoreChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/hintstore/tests.rs b/extensions/rv32im/circuit/src/hintstore/tests.rs index 204070762c..0a8afa05a0 100644 --- a/extensions/rv32im/circuit/src/hintstore/tests.rs +++ b/extensions/rv32im/circuit/src/hintstore/tests.rs @@ -1,12 +1,8 @@ -use std::{ - array, - borrow::BorrowMut, - sync::{Arc, Mutex}, -}; +use std::{array, borrow::BorrowMut}; use openvm_circuit::arch::{ testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - Streams, + DenseRecordArena, ExecutionBridge, InstructionExecutor, NewVmChipWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -24,31 +20,52 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{Rv32HintStoreChip, Rv32HintStoreCols}; -use crate::adapters::decompose; +use super::{Rv32HintStoreAir, Rv32HintStoreChip, Rv32HintStoreCols, Rv32HintStoreStep}; +use crate::{ + adapters::decompose, hintstore::Rv32HintStoreLayout, test_utils::get_verification_error, +}; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 4096; -fn set_and_execute( +fn create_test_chip( tester: &mut VmChipTestBuilder, - chip: &mut Rv32HintStoreChip, +) -> ( + Rv32HintStoreChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Rv32HintStoreChip::::new( + Rv32HintStoreAir::new( + ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), + tester.memory_bridge(), + bitwise_chip.bus(), + 0, + tester.address_bits(), + ), + Rv32HintStoreStep::new(bitwise_chip.clone(), tester.address_bits(), 0), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + +fn set_and_execute>( + tester: &mut VmChipTestBuilder, + chip: &mut E, rng: &mut StdRng, opcode: Rv32HintStoreOpcode, ) { - let mem_ptr = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - 2)), - ) << 2; + let mem_ptr = rng + .gen_range(0..(1 << (tester.memory_controller().mem_config().pointer_max_bits - 2))) + << 2; let b = gen_pointer(rng, 4); tester.write(1, b, decompose(mem_ptr)); @@ -56,13 +73,7 @@ fn set_and_execute( let read_data: [F; RV32_REGISTER_NUM_LIMBS] = array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..(1 << RV32_CELL_BITS)))); for data in read_data { - chip.streams - .get() - .unwrap() - .lock() - .unwrap() - .hint_stream - .push_back(data); + tester.streams.hint_stream.push_back(data); } tester.execute( @@ -80,20 +91,14 @@ fn set_and_execute_buffer( rng: &mut StdRng, opcode: Rv32HintStoreOpcode, ) { - let mem_ptr = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - 2)), - ) << 2; + let mem_ptr = rng + .gen_range(0..(1 << (tester.memory_controller().mem_config().pointer_max_bits - 2))) + << 2; let b = gen_pointer(rng, 4); tester.write(1, b, decompose(mem_ptr)); - let num_words = rng.gen_range(1..20); + let num_words = rng.gen_range(1..28); let a = gen_pointer(rng, 4); tester.write(1, a, decompose(num_words)); @@ -102,13 +107,7 @@ fn set_and_execute_buffer( .collect(); for i in 0..num_words { for datum in data[i as usize] { - chip.streams - .get() - .unwrap() - .lock() - .unwrap() - .hint_stream - .push_back(datum); + tester.streams.hint_stream.push_back(datum); } } @@ -131,30 +130,15 @@ fn set_and_execute_buffer( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// + #[test] fn rand_hintstore_test() { - setup_tracing(); let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, - ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); - - let num_tests: usize = 8; - for _ in 0..num_tests { + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); + let num_ops: usize = 100; + for _ in 0..num_ops { if rng.gen_bool(0.5) { set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); } else { @@ -162,7 +146,6 @@ fn rand_hintstore_test() { } } - drop(range_checker_chip); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } @@ -171,64 +154,44 @@ fn rand_hintstore_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// #[allow(clippy::too_many_arguments)] fn run_negative_hintstore_test( opcode: Rv32HintStoreOpcode, - data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - expected_error: VerificationError, + prank_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, - ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); set_and_execute(&mut tester, &mut chip, &mut rng, opcode); let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); let cols: &mut Rv32HintStoreCols = trace_row.as_mut_slice().borrow_mut(); - if let Some(data) = data { + if let Some(data) = prank_data { cols.data = data.map(F::from_canonical_u32); } *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn negative_hintstore_tests() { - run_negative_hintstore_test( - HINT_STOREW, - Some([92, 187, 45, 280]), - VerificationError::ChallengePhaseError, - ); + run_negative_hintstore_test(HINT_STOREW, Some([92, 187, 45, 280]), true); } + /////////////////////////////////////////////////////////////////////////////////////// /// SANITY TESTS /// @@ -239,22 +202,69 @@ fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); + let (mut chip, _) = create_test_chip(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); + } +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// DENSE TESTS +/// +/// Ensure that the chip works as expected with dense records. +/// We first execute some instructions with a [DenseRecordArena] and transfer the records +/// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +/// all the constraints pass. +/////////////////////////////////////////////////////////////////////////////////////// +type Rv32HintStoreChipDense = + NewVmChipWrapper; + +fn create_test_chip_dense(tester: &mut VmChipTestBuilder) -> Rv32HintStoreChipDense { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, + let mut chip = Rv32HintStoreChipDense::new( + Rv32HintStoreAir::new( + ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), + tester.memory_bridge(), + bitwise_chip.bus(), + 0, + tester.address_bits(), + ), + Rv32HintStoreStep::new(bitwise_chip.clone(), tester.address_bits(), 0), + tester.memory_helper(), ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + chip +} + +#[test] +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_chip, bitwise_chip) = create_test_chip(&mut tester); + + { + let mut dense_chip = create_test_chip_dense(&mut tester); + + let num_ops: usize = 100; + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut dense_chip, &mut rng, HINT_STOREW); + } + + let mut record_interpreter = dense_chip + .arena + .get_record_seeker::<_, Rv32HintStoreLayout>(); + record_interpreter.transfer_to_matrix_arena(&mut sparse_chip.arena); } + + let tester = tester + .build() + .load(sparse_chip) + .load(bitwise_chip) + .finalize(); + tester.simple_test().expect("Verification failed"); } diff --git a/extensions/rv32im/circuit/src/jal_lui/core.rs b/extensions/rv32im/circuit/src/jal_lui/core.rs index 2ba10e615e..11725f3257 100644 --- a/extensions/rv32im/circuit/src/jal_lui/core.rs +++ b/extensions/rv32im/circuit/src/jal_lui/core.rs @@ -1,19 +1,24 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, ImmInstruction, RecordArena, Result, + StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, VmAdapterInterface, VmCoreAir, + VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; - -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_primitives::{ + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_REGISTER_AS, LocalOpcode, }; use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *}; @@ -23,10 +28,11 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_J_TYPE_IMM_BITS}; +pub(super) const ADDITIONAL_BITS: u32 = 0b11000000; + #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32JalLuiCoreCols { @@ -36,7 +42,7 @@ pub struct Rv32JalLuiCoreCols { pub is_lui: T, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy, derive_new::new)] pub struct Rv32JalLuiCoreAir { pub bus: BitwiseOperationLookupBus, } @@ -141,134 +147,283 @@ where } #[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32JalLuiCoreRecord { - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS], - pub imm: F, +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32JalLuiStepRecord { + pub imm: u32, + pub rd_data: [u8; RV32_REGISTER_NUM_LIMBS], pub is_jal: bool, - pub is_lui: bool, } -pub struct Rv32JalLuiCoreChip { - pub air: Rv32JalLuiCoreAir, +#[derive(derive_new::new)] +pub struct Rv32JalLuiStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl Rv32JalLuiCoreChip { - pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip) -> Self { - Self { - air: Rv32JalLuiCoreAir { - bus: bitwise_lookup_chip.bus(), - }, - bitwise_lookup_chip, - } +impl TraceStep for Rv32JalLuiStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep, +{ + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut Rv32JalLuiStepRecord); + + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET) + ) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let &Instruction { opcode, c: imm, .. } = instruction; + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let is_jal = opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET) == JAL as usize; + let signed_imm = get_signed_imm(is_jal, imm); + + let (to_pc, rd_data) = run_jal_lui(is_jal, *state.pc, signed_imm); + + core_record.imm = imm.as_canonical_u32(); + core_record.rd_data = rd_data; + core_record.is_jal = is_jal; + + self.adapter + .write(state.memory, instruction, rd_data, &mut adapter_record); + + *state.pc = to_pc; + + Ok(()) } } -impl> VmCoreChip for Rv32JalLuiCoreChip +impl TraceFiller for Rv32JalLuiStep where - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = Rv32JalLuiCoreRecord; - type Air = Rv32JalLuiCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - _reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = Rv32JalLuiOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET), - ); - let imm = instruction.c; - - let signed_imm = match local_opcode { - JAL => { - // Note: signed_imm is a signed integer and imm is a field element - (imm + F::from_canonical_u32(1 << (RV_J_TYPE_IMM_BITS - 1))).as_canonical_u32() - as i32 - - (1 << (RV_J_TYPE_IMM_BITS - 1)) - } - LUI => imm.as_canonical_u32() as i32, - }; - let (to_pc, rd_data) = run_jal_lui(local_opcode, from_pc, signed_imm); + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &Rv32JalLuiStepRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut Rv32JalLuiCoreCols = core_row.borrow_mut(); - for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) { + for pair in record.rd_data.chunks_exact(2) { self.bitwise_lookup_chip - .request_range(rd_data[i * 2], rd_data[i * 2 + 1]); + .request_range(pair[0] as u32, pair[1] as u32); } - - if local_opcode == JAL { - let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); - let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x)); + if record.is_jal { self.bitwise_lookup_chip - .request_xor(rd_data[3], additional_bits); + .request_xor(record.rd_data[3] as u32, ADDITIONAL_BITS); } - let rd_data = rd_data.map(F::from_canonical_u32); + // Writing in reverse order + core_row.is_lui = F::from_bool(!record.is_jal); + core_row.is_jal = F::from_bool(record.is_jal); + core_row.rd_data = record.rd_data.map(F::from_canonical_u8); + core_row.imm = F::from_canonical_u32(record.imm); + } +} - let output = AdapterRuntimeContext { - to_pc: Some(to_pc), - writes: [rd_data].into(), +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct JalLuiPreCompute { + signed_imm: i32, + a: u8, +} + +impl StepExecutorE1 for Rv32JalLuiStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let data: &mut JalLuiPreCompute = data.borrow_mut(); + let (is_jal, enabled) = self.pre_compute_impl(inst, data)?; + let fn_ptr = match (is_jal, enabled) { + (true, true) => execute_e1_impl::<_, _, true, true>, + (true, false) => execute_e1_impl::<_, _, true, false>, + (false, true) => execute_e1_impl::<_, _, false, true>, + (false, false) => execute_e1_impl::<_, _, false, false>, }; + Ok(fn_ptr) + } +} - Ok(( - output, - Rv32JalLuiCoreRecord { - rd_data, - imm, - is_jal: local_opcode == JAL, - is_lui: local_opcode == LUI, - }, - )) +impl StepExecutorE2 for Rv32JalLuiStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET) - ) + fn pre_compute_e2( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (is_jal, enabled) = self.pre_compute_impl(inst, &mut data.data)?; + let fn_ptr = match (is_jal, enabled) { + (true, true) => execute_e2_impl::<_, _, true, true>, + (true, false) => execute_e2_impl::<_, _, true, false>, + (false, true) => execute_e2_impl::<_, _, false, true>, + (false, false) => execute_e2_impl::<_, _, false, false>, + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut Rv32JalLuiCoreCols = row_slice.borrow_mut(); - core_cols.rd_data = record.rd_data; - core_cols.imm = record.imm; - core_cols.is_jal = F::from_bool(record.is_jal); - core_cols.is_lui = F::from_bool(record.is_lui); +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_JAL: bool, + const ENABLED: bool, +>( + pre_compute: &JalLuiPreCompute, + vm_state: &mut VmSegmentState, +) { + let JalLuiPreCompute { a, signed_imm } = *pre_compute; + + let rd = if IS_JAL { + let rd_data = (vm_state.pc + DEFAULT_PC_STEP).to_le_bytes(); + let next_pc = vm_state.pc as i32 + signed_imm; + debug_assert!(next_pc >= 0); + vm_state.pc = next_pc as u32; + rd_data + } else { + let imm = signed_imm as u32; + let rd = imm << 12; + vm_state.pc += DEFAULT_PC_STEP; + rd.to_le_bytes() + }; + + if ENABLED { + vm_state.vm_write(RV32_REGISTER_AS, a as u32, &rd); } - fn air(&self) -> &Self::Air { - &self.air + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_JAL: bool, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &JalLuiPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const IS_JAL: bool, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32JalLuiStep { + /// Return (IS_JAL, ENABLED) + #[inline(always)] + fn pre_compute_impl( + &self, + inst: &Instruction, + data: &mut JalLuiPreCompute, + ) -> Result<(bool, bool)> { + let local_opcode = Rv32JalLuiOpcode::from_usize( + inst.opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET), + ); + let is_jal = local_opcode == JAL; + let imm_f = inst.c.as_canonical_u32(); + let signed_imm = if is_jal { + if imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)) { + imm_f as i32 + } else { + let neg_imm_f = F::ORDER_U32 - imm_f; + assert!(neg_imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1))); + -(neg_imm_f as i32) + } + } else { + imm_f as i32 + }; + + *data = JalLuiPreCompute { + signed_imm, + a: inst.a.as_canonical_u32() as u8, + }; + let enabled = !inst.f.is_zero(); + Ok((is_jal, enabled)) } } -// returns (to_pc, rd_data) -pub(super) fn run_jal_lui( - opcode: Rv32JalLuiOpcode, - pc: u32, - imm: i32, -) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) { - match opcode { - JAL => { - let rd_data = array::from_fn(|i| { - ((pc + DEFAULT_PC_STEP) >> (8 * i)) & ((1 << RV32_CELL_BITS) - 1) - }); - let next_pc = pc as i32 + imm; - assert!(next_pc >= 0); - (next_pc as u32, rd_data) - } - LUI => { - let imm = imm as u32; - let rd = imm << 12; - let rd_data = - array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & ((1 << RV32_CELL_BITS) - 1)); - (pc + DEFAULT_PC_STEP, rd_data) +// returns the canonical signed representation of the immediate +// `imm` can be "negative" as a field element +pub(super) fn get_signed_imm(is_jal: bool, imm: F) -> i32 { + let imm_f = imm.as_canonical_u32(); + if is_jal { + if imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)) { + imm_f as i32 + } else { + let neg_imm_f = F::ORDER_U32 - imm_f; + debug_assert!(neg_imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1))); + -(neg_imm_f as i32) } + } else { + imm_f as i32 + } +} + +// returns (to_pc, rd_data) +#[inline(always)] +pub(super) fn run_jal_lui(is_jal: bool, pc: u32, imm: i32) -> (u32, [u8; RV32_REGISTER_NUM_LIMBS]) { + if is_jal { + let rd_data = (pc + DEFAULT_PC_STEP).to_le_bytes(); + let next_pc = pc as i32 + imm; + debug_assert!(next_pc >= 0); + (next_pc as u32, rd_data) + } else { + let imm = imm as u32; + let rd = imm << 12; + (pc + DEFAULT_PC_STEP, rd.to_le_bytes()) } } diff --git a/extensions/rv32im/circuit/src/jal_lui/mod.rs b/extensions/rv32im/circuit/src/jal_lui/mod.rs index 779b710bea..0df873c3e6 100644 --- a/extensions/rv32im/circuit/src/jal_lui/mod.rs +++ b/extensions/rv32im/circuit/src/jal_lui/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use crate::adapters::Rv32CondRdWriteAdapterChip; +use crate::adapters::{Rv32CondRdWriteAdapterAir, Rv32CondRdWriteAdapterStep}; mod core; pub use core::*; @@ -8,4 +8,7 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32JalLuiChip = VmChipWrapper, Rv32JalLuiCoreChip>; +pub type Rv32JalLuiAir = VmAirWrapper; +pub type Rv32JalLuiStepWithAdapter = Rv32JalLuiStep; +pub type Rv32JalLuiChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/jal_lui/tests.rs b/extensions/rv32im/circuit/src/jal_lui/tests.rs index 35e258cbfb..e0f2c378ea 100644 --- a/extensions/rv32im/circuit/src/jal_lui/tests.rs +++ b/extensions/rv32im/circuit/src/jal_lui/tests.rs @@ -2,7 +2,7 @@ use std::borrow::BorrowMut; use openvm_circuit::arch::{ testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmAdapterChip, + VmAirWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -12,27 +12,61 @@ use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{run_jal_lui, Rv32JalLuiChip, Rv32JalLuiCoreChip}; +use super::{run_jal_lui, Rv32JalLuiChip, Rv32JalLuiCoreAir, Rv32JalLuiStep}; use crate::{ adapters::{ - Rv32CondRdWriteAdapterChip, Rv32CondRdWriteAdapterCols, RV32_CELL_BITS, - RV32_REGISTER_NUM_LIMBS, RV_IS_TYPE_IMM_BITS, + Rv32CondRdWriteAdapterAir, Rv32CondRdWriteAdapterCols, Rv32CondRdWriteAdapterStep, + Rv32RdWriteAdapterAir, Rv32RdWriteAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + RV_IS_TYPE_IMM_BITS, }, - jal_lui::Rv32JalLuiCoreCols, + jal_lui::{Rv32JalLuiCoreCols, ADDITIONAL_BITS}, + test_utils::get_verification_error, }; const IMM_BITS: usize = 20; const LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; +const MAX_INS_CAPACITY: usize = 128; + type F = BabyBear; +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32JalLuiChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Rv32JalLuiChip::::new( + VmAirWrapper::new( + Rv32CondRdWriteAdapterAir::new(Rv32RdWriteAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + )), + Rv32JalLuiCoreAir::new(bitwise_bus), + ), + Rv32JalLuiStep::new( + Rv32CondRdWriteAdapterStep::new(Rv32RdWriteAdapterStep::new()), + bitwise_chip.clone(), + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + fn set_and_execute( tester: &mut VmChipTestBuilder, chip: &mut Rv32JalLuiChip, @@ -67,11 +101,11 @@ fn set_and_execute( let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); let final_pc = tester.execution.last_to_pc().as_canonical_u32(); - let (next_pc, rd_data) = run_jal_lui(opcode, initial_pc, imm); + let (next_pc, rd_data) = run_jal_lui(opcode == JAL, initial_pc, imm); let rd_data = if needs_write { rd_data } else { [0; 4] }; assert_eq!(next_pc, final_pc); - assert_eq!(rd_data.map(F::from_canonical_u32), tester.read::<4>(1, a)); + assert_eq!(rd_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } /////////////////////////////////////////////////////////////////////////////////////// @@ -81,25 +115,15 @@ fn set_and_execute( /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_jal_lui_test() { +#[test_case(JAL, 100)] +#[test_case(LUI, 100)] +fn rand_jal_lui_test(opcode: Rv32JalLuiOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32JalLuiCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, JAL, None, None); - set_and_execute(&mut tester, &mut chip, &mut rng, LUI, None, None); + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None); } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); @@ -109,35 +133,29 @@ fn rand_jal_lui_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct JalLuiPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub imm: Option, + pub is_jal: Option, + pub is_lui: Option, + pub needs_write: Option, +} + #[allow(clippy::too_many_arguments)] fn run_negative_jal_lui_test( opcode: Rv32JalLuiOpcode, initial_imm: Option, initial_pc: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - imm: Option, - is_jal: Option, - is_lui: Option, - needs_write: Option, - expected_error: VerificationError, + prank_vals: JalLuiPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let core = Rv32JalLuiCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&tester); set_and_execute( &mut tester, @@ -148,51 +166,43 @@ fn run_negative_jal_lui_test( initial_pc, ); - let tester = tester.build(); - - let jal_lui_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let jal_lui_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = jal_lui_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (adapter_row, core_row) = trace_row.split_at_mut(adapter_width); - let adapter_cols: &mut Rv32CondRdWriteAdapterCols = adapter_row.borrow_mut(); let core_cols: &mut Rv32JalLuiCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(imm) = imm { + if let Some(imm) = prank_vals.imm { core_cols.imm = if imm < 0 { F::NEG_ONE * F::from_canonical_u32((-imm) as u32) } else { F::from_canonical_u32(imm as u32) }; } - if let Some(is_jal) = is_jal { + if let Some(is_jal) = prank_vals.is_jal { core_cols.is_jal = F::from_bool(is_jal); } - if let Some(is_lui) = is_lui { + if let Some(is_lui) = prank_vals.is_lui { core_cols.is_lui = F::from_bool(is_lui); } - - if let Some(needs_write) = needs_write { + if let Some(needs_write) = prank_vals.needs_write { adapter_cols.needs_write = F::from_bool(needs_write); } - *jal_lui_trace = RowMajorMatrix::new(trace_row, jal_lui_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) + .build() + .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -201,34 +211,35 @@ fn opcode_flag_negative_test() { JAL, None, None, - None, - None, - Some(false), - Some(true), - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + is_jal: Some(false), + is_lui: Some(true), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( JAL, None, None, - None, - None, - Some(false), - Some(false), - Some(false), - VerificationError::ChallengePhaseError, + JalLuiPrankValues { + is_jal: Some(false), + is_lui: Some(false), + needs_write: Some(false), + ..Default::default() + }, + true, ); run_negative_jal_lui_test( LUI, None, None, - None, - None, - Some(true), - Some(false), - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + is_jal: Some(true), + is_lui: Some(false), + ..Default::default() + }, + false, ); } @@ -238,67 +249,61 @@ fn overflow_negative_tests() { JAL, None, None, - Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - Some([0, LIMB_MAX, LIMB_MAX, LIMB_MAX + 1]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([0, LIMB_MAX, LIMB_MAX, LIMB_MAX + 1]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - None, - Some(-1), - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + imm: Some(-1), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - None, - Some(-28), - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + imm: Some(-28), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( JAL, None, Some(251), - Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), - None, - None, - None, - None, - VerificationError::ChallengePhaseError, + JalLuiPrankValues { + rd_data: Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), + ..Default::default() + }, + true, ); } @@ -307,25 +312,12 @@ fn overflow_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// + #[test] fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32JalLuiCoreChip::new(bitwise_chip); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, JAL, None, None); - set_and_execute(&mut tester, &mut chip, &mut rng, LUI, None, None); - } + let (mut chip, _) = create_test_chip(&tester); set_and_execute( &mut tester, @@ -347,20 +339,25 @@ fn execute_roundtrip_sanity_test() { #[test] fn run_jal_sanity_test() { - let opcode = JAL; let initial_pc = 28120; let imm = -2048; - let (next_pc, rd_data) = run_jal_lui(opcode, initial_pc, imm); + let (next_pc, rd_data) = run_jal_lui(true, initial_pc, imm); assert_eq!(next_pc, 26072); assert_eq!(rd_data, [220, 109, 0, 0]); } #[test] fn run_lui_sanity_test() { - let opcode = LUI; let initial_pc = 456789120; let imm = 853679; - let (next_pc, rd_data) = run_jal_lui(opcode, initial_pc, imm); + let (next_pc, rd_data) = run_jal_lui(false, initial_pc, imm); assert_eq!(next_pc, 456789124); assert_eq!(rd_data, [0, 240, 106, 208]); } + +#[test] +fn test_additional_bits() { + let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1u32 << x)); + assert_eq!(additional_bits, ADDITIONAL_BITS); +} diff --git a/extensions/rv32im/circuit/src/jalr/core.rs b/extensions/rv32im/circuit/src/jalr/core.rs index fd89c1e317..616bfd2c9b 100644 --- a/extensions/rv32im/circuit/src/jalr/core.rs +++ b/extensions/rv32im/circuit/src/jalr/core.rs @@ -3,18 +3,27 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, SignedImmInstruction, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + RecordArena, Result, SignedImmInstruction, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ instruction::Instruction, program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_REGISTER_AS, LocalOpcode, }; use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *}; @@ -24,11 +33,8 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use crate::adapters::{compose, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; - -const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; +use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -46,18 +52,7 @@ pub struct Rv32JalrCoreCols { pub imm_sign: T, } -#[repr(C)] -#[derive(Serialize, Deserialize)] -pub struct Rv32JalrCoreRecord { - pub imm: F, - pub rs1_data: [F; RV32_REGISTER_NUM_LIMBS], - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS - 1], - pub to_pc_least_sig_bit: F, - pub to_pc_limbs: [u32; 2], - pub imm_sign: F, -} - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct Rv32JalrCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_bus: VariableRangeCheckerBus, @@ -181,127 +176,278 @@ where } } -pub struct Rv32JalrCoreChip { - pub air: Rv32JalrCoreAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct Rv32JalrCoreRecord { + pub imm: u16, + pub from_pc: u32, + pub rs1_val: u32, + pub imm_sign: bool, +} + +pub struct Rv32JalrStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl Rv32JalrCoreChip { +impl Rv32JalrStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker_chip: SharedVariableRangeCheckerChip, ) -> Self { assert!(range_checker_chip.range_max_bits() >= 16); Self { - air: Rv32JalrCoreAir { - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - range_bus: range_checker_chip.bus(), - }, + adapter, bitwise_lookup_chip, range_checker_chip, } } } -impl> VmCoreChip for Rv32JalrCoreChip +impl TraceStep for Rv32JalrStep where - I::Reads: Into<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = [u8; RV32_REGISTER_NUM_LIMBS], + WriteData = [u8; RV32_REGISTER_NUM_LIMBS], + >, { - type Record = Rv32JalrCoreRecord; - type Air = Rv32JalrCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut Rv32JalrCoreRecord); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET) + ) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { let Instruction { opcode, c, g, .. } = *instruction; - let local_opcode = - Rv32JalrOpcode::from_usize(opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET)); - let imm = c.as_canonical_u32(); - let imm_sign = g.as_canonical_u32(); - let imm_extended = imm + imm_sign * 0xffff0000; + debug_assert_eq!( + opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET), + JALR as usize + ); + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); - let rs1 = reads.into()[0]; - let rs1_val = compose(rs1); + core_record.rs1_val = u32::from_le_bytes(self.adapter.read( + state.memory, + instruction, + &mut adapter_record, + )); - let (to_pc, rd_data) = run_jalr(local_opcode, from_pc, imm_extended, rs1_val); + core_record.imm = c.as_canonical_u32() as u16; + core_record.imm_sign = g.is_one(); + core_record.from_pc = *state.pc; + let (to_pc, rd_data) = run_jalr( + core_record.from_pc, + core_record.rs1_val, + core_record.imm, + core_record.imm_sign, + ); + + self.adapter + .write(state.memory, instruction, rd_data, &mut adapter_record); + + // RISC-V spec explicitly sets the least significant bit of `to_pc` to 0 + *state.pc = to_pc & !1; + + Ok(()) + } +} +impl TraceFiller for Rv32JalrStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &Rv32JalrCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) }; + + let core_row: &mut Rv32JalrCoreCols = core_row.borrow_mut(); + + let (to_pc, rd_data) = + run_jalr(record.from_pc, record.rs1_val, record.imm, record.imm_sign); + let to_pc_limbs = [(to_pc & ((1 << 16) - 1)) >> 1, to_pc >> 16]; + self.range_checker_chip.add_count(to_pc_limbs[0], 15); + self.range_checker_chip + .add_count(to_pc_limbs[1], PC_BITS - 16); self.bitwise_lookup_chip - .request_range(rd_data[0], rd_data[1]); + .request_range(rd_data[0] as u32, rd_data[1] as u32); + self.range_checker_chip - .add_count(rd_data[2], RV32_CELL_BITS); + .add_count(rd_data[2] as u32, RV32_CELL_BITS); self.range_checker_chip - .add_count(rd_data[3], PC_BITS - RV32_CELL_BITS * 3); - - let mask = (1 << 15) - 1; - let to_pc_least_sig_bit = rs1_val.wrapping_add(imm_extended) & 1; - - let to_pc_limbs = array::from_fn(|i| ((to_pc >> (1 + i * 15)) & mask)); + .add_count(rd_data[3] as u32, PC_BITS - RV32_CELL_BITS * 3); + + // Write in reverse order + core_row.imm_sign = F::from_bool(record.imm_sign); + core_row.to_pc_limbs = to_pc_limbs.map(F::from_canonical_u32); + core_row.to_pc_least_sig_bit = F::from_bool(to_pc & 1 == 1); + // fill_trace_row is called only on valid rows + core_row.is_valid = F::ONE; + core_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_canonical_u8); + core_row + .rd_data + .iter_mut() + .rev() + .zip(rd_data.iter().skip(1).rev()) + .for_each(|(dst, src)| { + *dst = F::from_canonical_u8(*src); + }); + core_row.imm = F::from_canonical_u16(record.imm); + } +} - let rd_data = rd_data.map(F::from_canonical_u32); +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct JalrPreCompute { + imm_extended: u32, + a: u8, + b: u8, +} - let output = AdapterRuntimeContext { - to_pc: Some(to_pc), - writes: [rd_data].into(), +impl StepExecutorE1 for Rv32JalrStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let data: &mut JalrPreCompute = data.borrow_mut(); + let enabled = self.pre_compute_impl(pc, inst, data)?; + let fn_ptr = if enabled { + execute_e1_impl::<_, _, true> + } else { + execute_e1_impl::<_, _, false> }; + Ok(fn_ptr) + } +} - Ok(( - output, - Rv32JalrCoreRecord { - imm: c, - rd_data: array::from_fn(|i| rd_data[i + 1]), - rs1_data: rs1, - to_pc_least_sig_bit: F::from_canonical_u32(to_pc_least_sig_bit), - to_pc_limbs, - imm_sign: g, - }, - )) +impl StepExecutorE2 for Rv32JalrStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET) - ) + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let enabled = self.pre_compute_impl(pc, inst, &mut data.data)?; + let fn_ptr = if enabled { + execute_e2_impl::<_, _, true> + } else { + execute_e2_impl::<_, _, false> + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - self.range_checker_chip.add_count(record.to_pc_limbs[0], 15); - self.range_checker_chip.add_count(record.to_pc_limbs[1], 14); - - let core_cols: &mut Rv32JalrCoreCols = row_slice.borrow_mut(); - core_cols.imm = record.imm; - core_cols.rd_data = record.rd_data; - core_cols.rs1_data = record.rs1_data; - core_cols.to_pc_least_sig_bit = record.to_pc_least_sig_bit; - core_cols.to_pc_limbs = record.to_pc_limbs.map(F::from_canonical_u32); - core_cols.imm_sign = record.imm_sign; - core_cols.is_valid = F::ONE; +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &JalrPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1 = u32::from_le_bytes(rs1); + let to_pc = rs1.wrapping_add(pre_compute.imm_extended); + let to_pc = to_pc - (to_pc & 1); + debug_assert!(to_pc < (1 << PC_BITS)); + let rd = (vm_state.pc + DEFAULT_PC_STEP).to_le_bytes(); + + if ENABLED { + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); } - fn air(&self) -> &Self::Air { - &self.air + vm_state.pc = to_pc; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &JalrPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl Rv32JalrStep { + /// Return true if enabled. + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut JalrPreCompute, + ) -> Result { + let imm_extended = inst.c.as_canonical_u32() + inst.g.as_canonical_u32() * 0xffff0000; + if inst.d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(InvalidInstruction(pc)); + } + *data = JalrPreCompute { + imm_extended, + a: inst.a.as_canonical_u32() as u8, + b: inst.b.as_canonical_u32() as u8, + }; + let enabled = !inst.f.is_zero(); + Ok(enabled) } } // returns (to_pc, rd_data) -pub(super) fn run_jalr( - _opcode: Rv32JalrOpcode, - pc: u32, - imm: u32, - rs1: u32, -) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) { - let to_pc = rs1.wrapping_add(imm); - let to_pc = to_pc - (to_pc & 1); +#[inline(always)] +pub(super) fn run_jalr(pc: u32, rs1: u32, imm: u16, imm_sign: bool) -> (u32, [u8; 4]) { + let to_pc = rs1.wrapping_add(imm as u32 + (imm_sign as u32 * 0xffff0000)); assert!(to_pc < (1 << PC_BITS)); - ( - to_pc, - array::from_fn(|i: usize| ((pc + DEFAULT_PC_STEP) >> (RV32_CELL_BITS * i)) & RV32_LIMB_MAX), - ) + (to_pc, pc.wrapping_add(DEFAULT_PC_STEP).to_le_bytes()) } diff --git a/extensions/rv32im/circuit/src/jalr/mod.rs b/extensions/rv32im/circuit/src/jalr/mod.rs index 1d85dcbe4a..458376e7bf 100644 --- a/extensions/rv32im/circuit/src/jalr/mod.rs +++ b/extensions/rv32im/circuit/src/jalr/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use crate::adapters::Rv32JalrAdapterChip; +use crate::adapters::{Rv32JalrAdapterAir, Rv32JalrAdapterStep}; mod core; pub use core::*; @@ -8,4 +8,7 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32JalrChip = VmChipWrapper, Rv32JalrCoreChip>; +pub type Rv32JalrAir = VmAirWrapper; +pub type Rv32JalrStepWithAdapter = Rv32JalrStep; +pub type Rv32JalrChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/jalr/tests.rs b/extensions/rv32im/circuit/src/jalr/tests.rs index e22d97967f..4efc3c8c69 100644 --- a/extensions/rv32im/circuit/src/jalr/tests.rs +++ b/extensions/rv32im/circuit/src/jalr/tests.rs @@ -2,7 +2,7 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::arch::{ testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmAdapterChip, + VmAirWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -12,26 +12,60 @@ use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use super::Rv32JalrCoreAir; use crate::{ - adapters::{compose, Rv32JalrAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - jalr::{run_jalr, Rv32JalrChip, Rv32JalrCoreChip, Rv32JalrCoreCols}, + adapters::{ + compose, Rv32JalrAdapterAir, Rv32JalrAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + jalr::{run_jalr, Rv32JalrChip, Rv32JalrCoreCols, Rv32JalrStep}, + test_utils::get_verification_error, }; const IMM_BITS: usize = 16; +const MAX_INS_CAPACITY: usize = 128; + type F = BabyBear; fn into_limbs(num: u32) -> [u32; 4] { array::from_fn(|i| (num >> (8 * i)) & 255) } +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32JalrChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_checker_chip = tester.memory_controller().range_checker.clone(); + + let mut chip = Rv32JalrChip::::new( + VmAirWrapper::new( + Rv32JalrAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + Rv32JalrCoreAir::new(bitwise_bus, range_checker_chip.bus()), + ), + Rv32JalrStep::new( + Rv32JalrAdapterStep::new(), + bitwise_chip.clone(), + range_checker_chip.clone(), + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + #[allow(clippy::too_many_arguments)] fn set_and_execute( tester: &mut VmChipTestBuilder, @@ -45,7 +79,7 @@ fn set_and_execute( ) { let imm = initial_imm.unwrap_or(rng.gen_range(0..(1 << IMM_BITS))); let imm_sign = initial_imm_sign.unwrap_or(rng.gen_range(0..2)); - let imm_ext = imm + imm_sign * (0xffffffff ^ ((1 << IMM_BITS) - 1)); + let imm_ext = imm + (imm_sign * 0xffff0000); let a = rng.gen_range(0..32) << 2; let b = rng.gen_range(1..32) << 2; let to_pc = rng.gen_range(0..(1 << PC_BITS)); @@ -55,6 +89,7 @@ fn set_and_execute( tester.write(1, b, rs1); + let initial_pc = initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))); tester.execute_with_pc( chip, &Instruction::from_usize( @@ -69,18 +104,17 @@ fn set_and_execute( imm_sign as usize, ], ), - initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))), + initial_pc, ); - let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); let final_pc = tester.execution.last_to_pc().as_canonical_u32(); let rs1 = compose(rs1); - let (next_pc, rd_data) = run_jalr(opcode, initial_pc, imm_ext, rs1); + let (next_pc, rd_data) = run_jalr(initial_pc, rs1, imm as u16, imm_sign == 1); let rd_data = if a == 0 { [0; 4] } else { rd_data }; - assert_eq!(next_pc, final_pc); - assert_eq!(rd_data.map(F::from_canonical_u32), tester.read::<4>(1, a)); + assert_eq!(next_pc & !1, final_pc); + assert_eq!(rd_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } /////////////////////////////////////////////////////////////////////////////////////// @@ -92,21 +126,11 @@ fn set_and_execute( #[test] fn rand_jalr_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let inner = Rv32JalrCoreChip::new(bitwise_chip.clone(), range_checker_chip.clone()); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 100; - for _ in 0..num_tests { + let num_ops = 100; + for _ in 0..num_ops { set_and_execute( &mut tester, &mut chip, @@ -119,7 +143,6 @@ fn rand_jalr_test() { ); } - drop(range_checker_chip); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } @@ -128,10 +151,18 @@ fn rand_jalr_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct JalrPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, + pub rs1_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub to_pc_least_sig_bit: Option, + pub to_pc_limbs: Option<[u32; 2]>, + pub imm_sign: Option, +} + #[allow(clippy::too_many_arguments)] fn run_negative_jalr_test( opcode: Rv32JalrOpcode, @@ -139,27 +170,13 @@ fn run_negative_jalr_test( initial_rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, initial_imm: Option, initial_imm_sign: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, - rs1_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - to_pc_least_sig_bit: Option, - to_pc_limbs: Option<[u32; 2]>, - imm_sign: Option, - expected_error: VerificationError, + prank_vals: JalrPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let inner = Rv32JalrCoreChip::new(bitwise_chip.clone(), range_checker_chip.clone()); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); set_and_execute( &mut tester, @@ -172,49 +189,38 @@ fn run_negative_jalr_test( initial_rs1, ); - let tester = tester.build(); - - let jalr_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let jalr_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = jalr_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (_, core_row) = trace_row.split_at_mut(adapter_width); - let core_cols: &mut Rv32JalrCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(data) = rs1_data { + if let Some(data) = prank_vals.rs1_data { core_cols.rs1_data = data.map(F::from_canonical_u32); } - - if let Some(data) = to_pc_least_sig_bit { + if let Some(data) = prank_vals.to_pc_least_sig_bit { core_cols.to_pc_least_sig_bit = F::from_canonical_u32(data); } - - if let Some(data) = to_pc_limbs { + if let Some(data) = prank_vals.to_pc_limbs { core_cols.to_pc_limbs = data.map(F::from_canonical_u32); } - - if let Some(data) = imm_sign { + if let Some(data) = prank_vals.imm_sign { core_cols.imm_sign = F::from_canonical_u32(data); } - *jalr_trace = RowMajorMatrix::new(trace_row, jalr_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) + .build() + .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -225,12 +231,11 @@ fn invalid_cols_negative_tests() { None, Some(15362), Some(0), - None, - None, - None, - None, - Some(1), - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + imm_sign: Some(1), + ..Default::default() + }, + false, ); run_negative_jalr_test( @@ -239,12 +244,11 @@ fn invalid_cols_negative_tests() { None, Some(15362), Some(1), - None, - None, - None, - None, - Some(0), - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + imm_sign: Some(0), + ..Default::default() + }, + false, ); run_negative_jalr_test( @@ -253,12 +257,11 @@ fn invalid_cols_negative_tests() { Some([23, 154, 67, 28]), Some(42512), Some(1), - None, - None, - Some(0), - None, - None, - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + to_pc_least_sig_bit: Some(0), + ..Default::default() + }, + false, ); } @@ -270,12 +273,11 @@ fn overflow_negative_tests() { None, None, None, - Some([1, 0, 0]), - None, - None, - None, - None, - VerificationError::ChallengePhaseError, + JalrPrankValues { + rd_data: Some([1, 0, 0]), + ..Default::default() + }, + true, ); run_negative_jalr_test( @@ -284,15 +286,14 @@ fn overflow_negative_tests() { Some([0, 0, 0, 0]), Some((1 << 15) - 2), Some(0), - None, - None, - None, - Some([ - (F::NEG_ONE * F::from_canonical_u32((1 << 14) + 1)).as_canonical_u32(), - 1, - ]), - None, - VerificationError::ChallengePhaseError, + JalrPrankValues { + to_pc_limbs: Some([ + (F::NEG_ONE * F::from_canonical_u32((1 << 14) + 1)).as_canonical_u32(), + 1, + ]), + ..Default::default() + }, + true, ); } @@ -301,44 +302,13 @@ fn overflow_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let inner = Rv32JalrCoreChip::new(bitwise_chip, range_checker_chip); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - JALR, - None, - None, - None, - None, - ); - } -} #[test] fn run_jalr_sanity_test() { - let opcode = JALR; let initial_pc = 789456120; let imm = -1235_i32 as u32; let rs1 = 736482910; - let (next_pc, rd_data) = run_jalr(opcode, initial_pc, imm, rs1); - assert_eq!(next_pc, 736481674); + let (next_pc, rd_data) = run_jalr(initial_pc, rs1, imm as u16, true); + assert_eq!(next_pc & !1, 736481674); assert_eq!(rd_data, [252, 36, 14, 47]); } diff --git a/extensions/rv32im/circuit/src/less_than/core.rs b/extensions/rv32im/circuit/src/less_than/core.rs index a605dc43de..b9e92205c7 100644 --- a/extensions/rv32im/circuit/src/less_than/core.rs +++ b/extensions/rv32im/circuit/src/less_than/core.rs @@ -3,16 +3,29 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + MinimalInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::LessThanOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,12 +33,12 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; +use crate::adapters::imm_to_bytes; + #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct LessThanCoreCols { pub b: [T; NUM_LIMBS], pub c: [T; NUM_LIMBS], @@ -45,7 +58,7 @@ pub struct LessThanCoreCols { pub diff_val: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct LessThanCoreAir { pub bus: BitwiseOperationLookupBus, offset: usize, @@ -164,162 +177,339 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct LessThanCoreRecord { - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - pub cmp_result: T, - pub b_msb_f: T, - pub c_msb_f: T, - pub diff_val: T, - pub diff_idx: usize, - pub opcode: LessThanOpcode, +#[derive(AlignedBytesBorrow, Debug)] +pub struct LessThanCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + pub local_opcode: u8, } -pub struct LessThanCoreChip { - pub air: LessThanCoreAir, +#[derive(derive_new::new)] +pub struct LessThanStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, } -impl LessThanCoreChip { - pub fn new( - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: LessThanCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, - bitwise_lookup_chip, - } +impl TraceStep + for LessThanStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, +{ + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = ( + A::RecordMut<'a>, + &'a mut LessThanCoreRecord, + ); + + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", LessThanOpcode::from_usize(opcode - self.offset)) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + debug_assert!(LIMB_BITS <= 8); + let Instruction { opcode, .. } = instruction; + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + A::start(*state.pc, state.memory, &mut adapter_record); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + core_record.b = rs1; + core_record.c = rs2; + core_record.local_opcode = opcode.local_opcode_idx(self.offset) as u8; + + let (cmp_result, _, _, _) = run_less_than::( + core_record.local_opcode == LessThanOpcode::SLT as u8, + &rs1, + &rs2, + ); + + let mut output = [0u8; NUM_LIMBS]; + output[0] = cmp_result as u8; + + self.adapter.write( + state.memory, + instruction, + [output].into(), + &mut adapter_record, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for LessThanCoreChip +impl TraceFiller + for LessThanStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + AdapterTraceFiller, { - type Record = LessThanCoreRecord; - type Air = LessThanCoreAir; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &LessThanCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = instruction; - let less_than_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); + let core_row: &mut LessThanCoreCols = core_row.borrow_mut(); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); + let is_slt = record.local_opcode == LessThanOpcode::SLT as u8; let (cmp_result, diff_idx, b_sign, c_sign) = - run_less_than::(less_than_opcode, &b, &c); + run_less_than::(is_slt, &record.b, &record.c); // We range check (b_msb_f + 128) and (c_msb_f + 128) if signed, // b_msb_f and c_msb_f if not let (b_msb_f, b_msb_range) = if b_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u16((1u16 << LIMB_BITS) - record.b[NUM_LIMBS - 1] as u16), + record.b[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] - + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)), + F::from_canonical_u8(record.b[NUM_LIMBS - 1]), + record.b[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)), ) }; let (c_msb_f, c_msb_range) = if c_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - c[NUM_LIMBS - 1]), - c[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u16((1u16 << LIMB_BITS) - record.c[NUM_LIMBS - 1] as u16), + record.c[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(c[NUM_LIMBS - 1]), - c[NUM_LIMBS - 1] - + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)), + F::from_canonical_u8(record.c[NUM_LIMBS - 1]), + record.c[NUM_LIMBS - 1] + ((is_slt as u8) << (LIMB_BITS - 1)), ) }; - self.bitwise_lookup_chip - .request_range(b_msb_range, c_msb_range); - let diff_val = if diff_idx == NUM_LIMBS { - 0 + core_row.diff_val = if diff_idx == NUM_LIMBS { + F::ZERO } else if diff_idx == (NUM_LIMBS - 1) { if cmp_result { c_msb_f - b_msb_f } else { b_msb_f - c_msb_f } - .as_canonical_u32() } else if cmp_result { - c[diff_idx] - b[diff_idx] + F::from_canonical_u8(record.c[diff_idx] - record.b[diff_idx]) } else { - b[diff_idx] - c[diff_idx] + F::from_canonical_u8(record.b[diff_idx] - record.c[diff_idx]) }; + self.bitwise_lookup_chip + .request_range(b_msb_range as u32, c_msb_range as u32); + + core_row.diff_marker = [F::ZERO; NUM_LIMBS]; if diff_idx != NUM_LIMBS { - self.bitwise_lookup_chip.request_range(diff_val - 1, 0); + self.bitwise_lookup_chip + .request_range(core_row.diff_val.as_canonical_u32() - 1, 0); + core_row.diff_marker[diff_idx] = F::ONE; } - let mut writes = [0u32; NUM_LIMBS]; - writes[0] = cmp_result as u32; - - let output = AdapterRuntimeContext::without_pc([writes.map(F::from_canonical_u32)]); - let record = LessThanCoreRecord { - opcode: less_than_opcode, - b: data[0], - c: data[1], - cmp_result: F::from_bool(cmp_result), - b_msb_f, - c_msb_f, - diff_val: F::from_canonical_u32(diff_val), - diff_idx, - }; + core_row.c_msb_f = c_msb_f; + core_row.b_msb_f = b_msb_f; + core_row.opcode_sltu_flag = F::from_bool(!is_slt); + core_row.opcode_slt_flag = F::from_bool(is_slt); + core_row.cmp_result = F::from_bool(cmp_result); + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + } +} - Ok((output, record)) +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct LessThanPreCompute { + c: u32, + a: u8, + b: u8, +} + +impl StepExecutorE1 + for LessThanStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", LessThanOpcode::from_usize(opcode - self.air.offset)) + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut LessThanPreCompute = data.borrow_mut(); + let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = match (is_imm, is_sltu) { + (true, true) => execute_e1_impl::<_, _, true, true>, + (true, false) => execute_e1_impl::<_, _, true, false>, + (false, true) => execute_e1_impl::<_, _, false, true>, + (false, false) => execute_e1_impl::<_, _, false, false>, + }; + Ok(fn_ptr) } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.cmp_result = record.cmp_result; - row_slice.b_msb_f = record.b_msb_f; - row_slice.c_msb_f = record.c_msb_f; - row_slice.diff_val = record.diff_val; - row_slice.opcode_slt_flag = F::from_bool(record.opcode == LessThanOpcode::SLT); - row_slice.opcode_sltu_flag = F::from_bool(record.opcode == LessThanOpcode::SLTU); - row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx)); +impl StepExecutorE2 + for LessThanStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn air(&self) -> &Self::Air { - &self.air + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (is_imm, is_sltu) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = match (is_imm, is_sltu) { + (true, true) => execute_e2_impl::<_, _, true, true>, + (true, false) => execute_e2_impl::<_, _, true, false>, + (false, true) => execute_e2_impl::<_, _, false, true>, + (false, false) => execute_e2_impl::<_, _, false, false>, + }; + Ok(fn_ptr) + } +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const E_IS_IMM: bool, + const IS_U32: bool, +>( + pre_compute: &LessThanPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1 = vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2 = if E_IS_IMM { + pre_compute.c.to_le_bytes() + } else { + vm_state.vm_read::(RV32_REGISTER_AS, pre_compute.c) + }; + let cmp_result = if IS_U32 { + u32::from_le_bytes(rs1) < u32::from_le_bytes(rs2) + } else { + i32::from_le_bytes(rs1) < i32::from_le_bytes(rs2) + }; + let mut rd = [0u8; RV32_REGISTER_NUM_LIMBS]; + rd[0] = cmp_result as u8; + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const E_IS_IMM: bool, + const IS_U32: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &LessThanPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + const E_IS_IMM: bool, + const IS_U32: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl LessThanStep { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut LessThanPreCompute, + ) -> Result<(bool, bool)> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS + || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS) + { + return Err(InvalidInstruction(pc)); + } + let local_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let is_imm = e_u32 == RV32_IMM_AS; + let c_u32 = c.as_canonical_u32(); + + *data = LessThanPreCompute { + c: if is_imm { + u32::from_le_bytes(imm_to_bytes(c_u32)) + } else { + c_u32 + }, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + Ok((is_imm, local_opcode == LessThanOpcode::SLTU)) } } // Returns (cmp_result, diff_idx, x_sign, y_sign) +#[inline(always)] pub(super) fn run_less_than( - opcode: LessThanOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + is_slt: bool, + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize, bool, bool) { - let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT; - let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT; + let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt; + let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && is_slt; for i in (0..NUM_LIMBS).rev() { if x[i] != y[i] { return ((x[i] < y[i]) ^ x_sign ^ y_sign, i, x_sign, y_sign); diff --git a/extensions/rv32im/circuit/src/less_than/mod.rs b/extensions/rv32im/circuit/src/less_than/mod.rs index f8247d2d33..48a877527c 100644 --- a/extensions/rv32im/circuit/src/less_than/mod.rs +++ b/extensions/rv32im/circuit/src/less_than/mod.rs @@ -1,6 +1,8 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -8,8 +10,9 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32LessThanChip = VmChipWrapper< - F, - Rv32BaseAluAdapterChip, - LessThanCoreChip, ->; +pub type Rv32LessThanAir = + VmAirWrapper>; +pub type Rv32LessThanStep = + LessThanStep, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>; +pub type Rv32LessThanChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/less_than/tests.rs b/extensions/rv32im/circuit/src/less_than/tests.rs index 18d64bf5f6..50bcac3bd0 100644 --- a/extensions/rv32im/circuit/src/less_than/tests.rs +++ b/extensions/rv32im/circuit/src/less_than/tests.rs @@ -1,17 +1,17 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + InstructionExecutor, VmAirWrapper, }, - utils::{generate_long_number, i32_to_f}, + utils::i32_to_f, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::LessThanOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::LessThanOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, @@ -20,20 +20,96 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_less_than, LessThanCoreChip, Rv32LessThanChip}; +use super::{core::run_less_than, LessThanCoreAir, LessThanStep, Rv32LessThanChip}; use crate::{ - adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, less_than::LessThanCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; + +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32LessThanChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Rv32LessThanChip::::new( + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + LessThanCoreAir::new(bitwise_bus, LessThanOpcode::CLASS_OFFSET), + ), + LessThanStep::new( + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), + bitwise_chip.clone(), + LessThanOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: LessThanOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) + } else { + generate_rv32_is_type_immediate(rng) + }; + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), + ) + }; + + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(chip, &instruction); + + let (cmp, _, _, _) = + run_less_than::(opcode == SLT, &b, &c); + let mut a = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; + a[0] = F::from_bool(cmp); + assert_eq!(a, tester.read::(1, rd)); +} ////////////////////////////////////////////////////////////////////////////////////// // POSITIVE TESTS @@ -42,100 +118,51 @@ type F = BabyBear; // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// +#[test_case(SLT, 100)] +#[test_case(SLTU, 100)] fn run_rv32_lt_rand_test(opcode: LessThanOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32LessThanChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - LessThanCoreChip::new(bitwise_chip.clone(), LessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) - } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) - }; - - let (instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let (cmp, _, _, _) = - run_less_than::(opcode, &b, &c); - let mut a = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; - a[0] = F::from_bool(cmp); - assert_eq!(a, tester.read::(1, rd)); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); } // Test special case where b = c let b = [101, 128, 202, 255]; - let (instruction, _) = rv32_rand_write_register_or_imm( + set_and_execute( &mut tester, - b, - b, - None, - opcode.global_opcode().as_usize(), + &mut chip, &mut rng, + opcode, + Some(b), + Some(false), + Some(b), ); - tester.execute(&mut chip, &instruction); let b = [36, 0, 0, 0]; - let (instruction, _) = rv32_rand_write_register_or_imm( + set_and_execute( &mut tester, - b, - b, - Some(36), - opcode.global_opcode().as_usize(), + &mut chip, &mut rng, + opcode, + Some(b), + Some(true), + Some(b), ); - tester.execute(&mut chip, &instruction); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_slt_rand_test() { - run_rv32_lt_rand_test(LessThanOpcode::SLT, 100); -} - -#[test] -fn rv32_sltu_rand_test() { - run_rv32_lt_rand_test(LessThanOpcode::SLTU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32LessThanTestChip = - VmChipWrapper, LessThanCoreChip>; - #[derive(Clone, Copy, Default, PartialEq)] struct LessThanPrankValues { pub b_msb: Option, @@ -145,67 +172,29 @@ struct LessThanPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_lt_negative_test( +fn run_negative_less_than_test( opcode: LessThanOpcode, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: bool, prank_vals: LessThanPrankValues, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32LessThanTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - LessThanCoreChip::new(bitwise_chip.clone(), LessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + &mut rng, + opcode, + Some(b), + Some(false), + Some(c), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, _, b_sign, c_sign) = - run_less_than::(opcode, &b, &c); - - if prank_vals != LessThanPrankValues::default() { - debug_assert!(prank_vals.diff_val.is_some()); - let b_msb = prank_vals.b_msb.unwrap_or( - b[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if b_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let c_msb = prank_vals.c_msb.unwrap_or( - c[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if c_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let sign_offset = if opcode == LessThanOpcode::SLT { - 1 << (RV32_CELL_BITS - 1) - } else { - 0 - }; - - bitwise_chip.clear(); - bitwise_chip.request_range( - (b_msb + sign_offset) as u8 as u32, - (c_msb + sign_offset) as u8 as u32, - ); - - let diff_val = prank_vals - .diff_val - .unwrap() - .clamp(0, (1 << RV32_CELL_BITS) - 1); - if diff_val > 0 { - bitwise_chip.request_range(diff_val - 1, 0); - } - }; - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut LessThanCoreCols = @@ -223,9 +212,9 @@ fn run_rv32_lt_negative_test( if let Some(diff_val) = prank_vals.diff_val { cols.diff_val = F::from_canonical_u32(diff_val); } - cols.cmp_result = F::from_bool(cmp_result); + cols.cmp_result = F::from_bool(prank_cmp_result); - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -234,11 +223,7 @@ fn run_rv32_lt_negative_test( .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -246,8 +231,8 @@ fn rv32_lt_wrong_false_cmp_negative_test() { let b = [145, 34, 25, 205]; let c = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -255,8 +240,8 @@ fn rv32_lt_wrong_true_cmp_negative_test() { let b = [73, 35, 25, 205]; let c = [145, 34, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -264,8 +249,8 @@ fn rv32_lt_wrong_eq_negative_test() { let b = [73, 35, 25, 205]; let c = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -276,8 +261,8 @@ fn rv32_lt_fake_diff_val_negative_test() { diff_val: Some(F::NEG_ONE.as_canonical_u32()), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } #[test] @@ -289,8 +274,8 @@ fn rv32_lt_zero_diff_val_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } #[test] @@ -302,8 +287,8 @@ fn rv32_lt_fake_diff_marker_negative_test() { diff_val: Some(72), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -315,8 +300,8 @@ fn rv32_lt_zero_diff_marker_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -329,7 +314,7 @@ fn rv32_slt_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); } #[test] @@ -342,7 +327,7 @@ fn rv32_slt_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); } #[test] @@ -355,7 +340,7 @@ fn rv32_slt_wrong_c_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); } #[test] @@ -368,7 +353,7 @@ fn rv32_slt_wrong_c_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, true); + run_negative_less_than_test(SLT, b, c, true, prank_vals, true); } #[test] @@ -381,7 +366,7 @@ fn rv32_sltu_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -394,7 +379,7 @@ fn rv32_sltu_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, true); } #[test] @@ -407,7 +392,7 @@ fn rv32_sltu_wrong_c_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -420,7 +405,7 @@ fn rv32_sltu_wrong_c_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -431,10 +416,10 @@ fn rv32_sltu_wrong_c_msb_sign_negative_test() { #[test] fn run_sltu_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLTU, &x, &y); + run_less_than::(false, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(!x_sign); // unsigned @@ -443,10 +428,10 @@ fn run_sltu_sanity_test() { #[test] fn run_slt_same_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &y); + run_less_than::(true, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(x_sign); // negative @@ -455,10 +440,10 @@ fn run_slt_same_sign_sanity_test() { #[test] fn run_slt_diff_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &y); + run_less_than::(true, &x, &y); assert!(!cmp_result); assert_eq!(diff_idx, 3); assert!(!x_sign); // positive @@ -467,9 +452,9 @@ fn run_slt_diff_sign_sanity_test() { #[test] fn run_less_than_equal_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &x); + run_less_than::(true, &x, &x); assert!(!cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert!(!x_sign); // positive diff --git a/extensions/rv32im/circuit/src/load_sign_extend/core.rs b/extensions/rv32im/circuit/src/load_sign_extend/core.rs index 2284d6815c..088c1024bf 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/core.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/core.rs @@ -3,15 +3,29 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::{self, InvalidInstruction}, + RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, + VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory, POINTER_MAX_BITS}, }; use openvm_circuit_primitives::{ utils::select, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -19,8 +33,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use crate::adapters::LoadStoreInstruction; @@ -46,20 +58,7 @@ pub struct LoadSignExtendCoreCols { pub prev_data: [T; NUM_CELLS], } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Serialize + DeserializeOwned")] -pub struct LoadSignExtendCoreRecord { - #[serde(with = "BigArray")] - pub shifted_read_data: [F; NUM_CELLS], - #[serde(with = "BigArray")] - pub prev_data: [F; NUM_CELLS], - pub opcode: Rv32LoadStoreOpcode, - pub shift_amount: u32, - pub most_sig_bit: bool, -} - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct LoadSignExtendCoreAir { pub range_bus: VariableRangeCheckerBus, } @@ -178,135 +177,345 @@ where } } -pub struct LoadSignExtendCoreChip { - pub air: LoadSignExtendCoreAir, - pub range_checker_chip: SharedVariableRangeCheckerChip, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct LoadSignExtendCoreRecord { + pub is_byte: bool, + pub shift_amount: u8, + pub read_data: [u8; NUM_CELLS], + pub prev_data: [u8; NUM_CELLS], } -impl LoadSignExtendCoreChip { - pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self { - Self { - air: LoadSignExtendCoreAir:: { - range_bus: range_checker_chip.bus(), - }, - range_checker_chip, - } - } +#[derive(derive_new::new)] +pub struct LoadSignExtendStep { + adapter: A, + pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl, const NUM_CELLS: usize, const LIMB_BITS: usize> - VmCoreChip for LoadSignExtendCoreChip +impl TraceStep + for LoadSignExtendStep where - I::Reads: Into<([[F; NUM_CELLS]; 2], F)>, - I::Writes: From<[[F; NUM_CELLS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8), + WriteData = [u32; NUM_CELLS], + >, { - type Record = LoadSignExtendCoreRecord; - type Air = LoadSignExtendCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = ( + A::RecordMut<'a>, + &'a mut LoadSignExtendCoreRecord, + ); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET) + ) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let Instruction { opcode, .. } = instruction; + let local_opcode = Rv32LoadStoreOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let (data, shift_amount) = reads.into(); - let shift_amount = shift_amount.as_canonical_u32(); - let write_data: [F; NUM_CELLS] = run_write_data_sign_extend::<_, NUM_CELLS, LIMB_BITS>( + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let tmp = self + .adapter + .read(state.memory, instruction, &mut adapter_record); + + core_record.is_byte = local_opcode == LOADB; + core_record.prev_data = tmp.0 .0.map(|x| x as u8); + core_record.read_data = tmp.0 .1; + core_record.shift_amount = tmp.1; + + let write_data = run_write_data_sign_extend( local_opcode, - data[1], - data[0], - shift_amount, + core_record.read_data, + core_record.shift_amount as usize, ); - let output = AdapterRuntimeContext::without_pc([write_data]); - let most_sig_limb = match local_opcode { - LOADB => write_data[0], - LOADH => write_data[NUM_CELLS / 2 - 1], - _ => unreachable!(), - } - .as_canonical_u32(); + self.adapter.write( + state.memory, + instruction, + write_data.map(u32::from), + &mut adapter_record, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for LoadSignExtendStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &LoadSignExtendCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; - let most_sig_bit = most_sig_limb & (1 << (LIMB_BITS - 1)); + let core_row: &mut LoadSignExtendCoreCols = core_row.borrow_mut(); + + let shift = record.shift_amount; + let most_sig_limb = if record.is_byte { + record.read_data[shift as usize] + } else { + record.read_data[NUM_CELLS / 2 - 1 + shift as usize] + }; + + let most_sig_bit = most_sig_limb & (1 << 7); self.range_checker_chip - .add_count(most_sig_limb - most_sig_bit, LIMB_BITS - 1); - - let read_shift = shift_amount & 2; - - Ok(( - output, - LoadSignExtendCoreRecord { - opcode: local_opcode, - most_sig_bit: most_sig_bit != 0, - prev_data: data[0], - shifted_read_data: array::from_fn(|i| { - data[1][(i + read_shift as usize) % NUM_CELLS] - }), - shift_amount, - }, - )) + .add_count((most_sig_limb - most_sig_bit) as u32, 7); + + core_row.prev_data = record.prev_data.map(F::from_canonical_u8); + core_row.shifted_read_data = record.read_data.map(F::from_canonical_u8); + core_row.shifted_read_data.rotate_left((shift & 2) as usize); + + core_row.data_most_sig_bit = F::from_bool(most_sig_bit != 0); + core_row.shift_most_sig_bit = F::from_bool(shift & 2 == 2); + core_row.opcode_loadh_flag = F::from_bool(!record.is_byte); + core_row.opcode_loadb_flag1 = F::from_bool(record.is_byte && ((shift & 1) == 1)); + core_row.opcode_loadb_flag0 = F::from_bool(record.is_byte && ((shift & 1) == 0)); } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET) - ) +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct LoadSignExtendPreCompute { + imm_extended: u32, + a: u8, + b: u8, + e: u8, +} + +impl StepExecutorE1 + for LoadSignExtendStep +where + F: PrimeField32, +{ + fn pre_compute_size(&self) -> usize { + size_of::() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut LoadSignExtendCoreCols = row_slice.borrow_mut(); - let opcode = record.opcode; - let shift = record.shift_amount; - core_cols.opcode_loadb_flag0 = F::from_bool(opcode == LOADB && (shift & 1) == 0); - core_cols.opcode_loadb_flag1 = F::from_bool(opcode == LOADB && (shift & 1) == 1); - core_cols.opcode_loadh_flag = F::from_bool(opcode == LOADH); - core_cols.shift_most_sig_bit = F::from_canonical_u32((shift & 2) >> 1); - core_cols.data_most_sig_bit = F::from_bool(record.most_sig_bit); - core_cols.prev_data = record.prev_data; - core_cols.shifted_read_data = record.shifted_read_data; + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut LoadSignExtendPreCompute = data.borrow_mut(); + let (is_loadb, enabled) = self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = match (is_loadb, enabled) { + (true, true) => execute_e1_impl::<_, _, true, true>, + (true, false) => execute_e1_impl::<_, _, true, false>, + (false, true) => execute_e1_impl::<_, _, false, true>, + (false, false) => execute_e1_impl::<_, _, false, false>, + }; + Ok(fn_ptr) + } +} + +impl StepExecutorE2 + for LoadSignExtendStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (is_loadb, enabled) = self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = match (is_loadb, enabled) { + (true, true) => execute_e2_impl::<_, _, true, true>, + (true, false) => execute_e2_impl::<_, _, true, false>, + (false, true) => execute_e2_impl::<_, _, false, true>, + (false, false) => execute_e2_impl::<_, _, false, false>, + }; + Ok(fn_ptr) } +} + +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_LOADB: bool, + const ENABLED: bool, +>( + pre_compute: &LoadSignExtendPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1_bytes: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1_val = u32::from_le_bytes(rs1_bytes); + let ptr_val = rs1_val.wrapping_add(pre_compute.imm_extended); + // sign_extend([r32{c,g}(b):2]_e)` + debug_assert!(ptr_val < (1 << POINTER_MAX_BITS)); + let shift_amount = ptr_val % 4; + let ptr_val = ptr_val - shift_amount; // aligned ptr + + let read_data: [u8; RV32_REGISTER_NUM_LIMBS] = vm_state.vm_read(pre_compute.e as u32, ptr_val); + + let write_data = if IS_LOADB { + let byte = read_data[shift_amount as usize]; + let sign_extended = (byte as i8) as i32; + sign_extended.to_le_bytes() + } else { + if shift_amount != 0 && shift_amount != 2 { + vm_state.exit_code = Err(ExecutionError::Fail { pc: vm_state.pc }); + return; + } + let half: [u8; 2] = array::from_fn(|i| read_data[shift_amount as usize + i]); + (i16::from_le_bytes(half) as i32).to_le_bytes() + }; - fn air(&self) -> &Self::Air { - &self.air + if ENABLED { + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data); } + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_LOADB: bool, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &LoadSignExtendPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); } -pub(super) fn run_write_data_sign_extend< +unsafe fn execute_e2_impl< F: PrimeField32, - const NUM_CELLS: usize, - const LIMB_BITS: usize, + CTX: E2ExecutionCtx, + const IS_LOADB: bool, + const ENABLED: bool, >( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl LoadSignExtendStep { + /// Return (is_loadb, enabled) + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut LoadSignExtendPreCompute, + ) -> Result<(bool, bool)> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 == RV32_IMM_AS { + return Err(InvalidInstruction(pc)); + } + + let local_opcode = Rv32LoadStoreOpcode::from_usize( + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + ); + match local_opcode { + LOADB | LOADH => {} + _ => unreachable!("LoadSignExtendStep should only handle LOADB/LOADH opcodes"), + } + + let imm = c.as_canonical_u32(); + let imm_sign = g.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + + *data = LoadSignExtendPreCompute { + imm_extended, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + e: e_u32 as u8, + }; + let enabled = !f.is_zero(); + Ok((local_opcode == LOADB, enabled)) + } +} + +// Returns write_data +#[inline(always)] +pub(super) fn run_write_data_sign_extend( opcode: Rv32LoadStoreOpcode, - read_data: [F; NUM_CELLS], - _prev_data: [F; NUM_CELLS], - shift: u32, -) -> [F; NUM_CELLS] { - let shift = shift as usize; - let mut write_data = read_data; + read_data: [u8; NUM_CELLS], + shift: usize, +) -> [u8; NUM_CELLS] { match (opcode, shift) { (LOADH, 0) | (LOADH, 2) => { - let ext = read_data[NUM_CELLS / 2 - 1 + shift].as_canonical_u32(); - let ext = (ext >> (LIMB_BITS - 1)) * ((1 << LIMB_BITS) - 1); - for cell in write_data.iter_mut().take(NUM_CELLS).skip(NUM_CELLS / 2) { - *cell = F::from_canonical_u32(ext); - } - write_data[0..NUM_CELLS / 2] - .copy_from_slice(&read_data[shift..(NUM_CELLS / 2 + shift)]); + let ext = (read_data[NUM_CELLS / 2 - 1 + shift] >> 7) * u8::MAX; + array::from_fn(|i| { + if i < NUM_CELLS / 2 { + read_data[i + shift] + } else { + ext + } + }) } (LOADB, 0) | (LOADB, 1) | (LOADB, 2) | (LOADB, 3) => { - let ext = read_data[shift].as_canonical_u32(); - let ext = (ext >> (LIMB_BITS - 1)) * ((1 << LIMB_BITS) - 1); - for cell in write_data.iter_mut().take(NUM_CELLS).skip(1) { - *cell = F::from_canonical_u32(ext); - } - write_data[0] = read_data[shift]; + let ext = (read_data[shift] >> 7) * u8::MAX; + array::from_fn(|i| { + if i == 0 { + read_data[i + shift] + } else { + ext + } + }) } // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes. // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4. @@ -314,6 +523,5 @@ pub(super) fn run_write_data_sign_extend< _ => unreachable!( "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}" ), - }; - write_data + } } diff --git a/extensions/rv32im/circuit/src/load_sign_extend/mod.rs b/extensions/rv32im/circuit/src/load_sign_extend/mod.rs index 79efbe912e..3c02546e7c 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/mod.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32LoadStoreAdapterChip; +use crate::adapters::{Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterStep}; mod core; pub use core::*; @@ -9,8 +9,11 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32LoadSignExtendChip = VmChipWrapper< - F, - Rv32LoadStoreAdapterChip, - LoadSignExtendCoreChip, +pub type Rv32LoadSignExtendAir = VmAirWrapper< + Rv32LoadStoreAdapterAir, + LoadSignExtendCoreAir, >; +pub type Rv32LoadSignExtendStep = + LoadSignExtendStep; +pub type Rv32LoadSignExtendChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/load_sign_extend/tests.rs b/extensions/rv32im/circuit/src/load_sign_extend/tests.rs index 0fe6d859d1..7057e2511b 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/tests.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/tests.rs @@ -2,7 +2,7 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::arch::{ testing::{memory::gen_pointer, VmChipTestBuilder}, - VmAdapterChip, + VmAirWrapper, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; @@ -14,24 +14,45 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::run_write_data_sign_extend; +use super::{run_write_data_sign_extend, LoadSignExtendCoreAir}; use crate::{ - adapters::{compose, Rv32LoadStoreAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterStep, RV32_REGISTER_NUM_LIMBS}, load_sign_extend::LoadSignExtendCoreCols, - LoadSignExtendCoreChip, Rv32LoadSignExtendChip, + test_utils::get_verification_error, + LoadSignExtendStep, Rv32LoadSignExtendChip, }; const IMM_BITS: usize = 16; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; -fn into_limbs(num: u32) -> [u32; NUM_LIMBS] { - array::from_fn(|i| (num >> (LIMB_BITS * i)) & ((1 << LIMB_BITS) - 1)) +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Rv32LoadSignExtendChip { + let range_checker_chip = tester.memory_controller().range_checker.clone(); + let mut chip = Rv32LoadSignExtendChip::::new( + VmAirWrapper::new( + Rv32LoadStoreAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + range_checker_chip.bus(), + tester.address_bits(), + ), + LoadSignExtendCoreAir::new(range_checker_chip.bus()), + ), + LoadSignExtendStep::new( + Rv32LoadStoreAdapterStep::new(tester.address_bits(), range_checker_chip.clone()), + range_checker_chip.clone(), + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + chip } #[allow(clippy::too_many_arguments)] @@ -40,53 +61,44 @@ fn set_and_execute( chip: &mut Rv32LoadSignExtendChip, rng: &mut StdRng, opcode: Rv32LoadStoreOpcode, - read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + read_data: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + rs1: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, ) { let imm = imm.unwrap_or(rng.gen_range(0..(1 << IMM_BITS))); let imm_sign = imm_sign.unwrap_or(rng.gen_range(0..2)); - let imm_ext = imm + imm_sign * (0xffffffff ^ ((1 << IMM_BITS) - 1)); + let imm_ext = imm + imm_sign * (0xffff0000); let alignment = match opcode { LOADB => 0, LOADH => 1, _ => unreachable!(), }; - let ptr_val = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - alignment)), - ) << alignment; - - let rs1 = rs1 - .unwrap_or(into_limbs::( - (ptr_val as u32).wrapping_sub(imm_ext), - )) - .map(F::from_canonical_u32); + + let ptr_val: u32 = rng.gen_range(0..(1 << (tester.address_bits() - alignment))) << alignment; + let rs1 = rs1.unwrap_or(ptr_val.wrapping_sub(imm_ext).to_le_bytes()); + let ptr_val = imm_ext.wrapping_add(u32::from_le_bytes(rs1)); let a = gen_pointer(rng, 4); let b = gen_pointer(rng, 4); - let ptr_val = imm_ext.wrapping_add(compose(rs1)); let shift_amount = ptr_val % 4; - tester.write(1, b, rs1); + tester.write(1, b, rs1.map(F::from_canonical_u8)); let some_prev_data: [F; RV32_REGISTER_NUM_LIMBS] = if a != 0 { - array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..(1 << RV32_CELL_BITS)))) + array::from_fn(|_| F::from_canonical_u8(rng.gen())) } else { [F::ZERO; RV32_REGISTER_NUM_LIMBS] }; - let read_data: [F; RV32_REGISTER_NUM_LIMBS] = read_data - .unwrap_or(array::from_fn(|_| rng.gen_range(0..(1 << RV32_CELL_BITS)))) - .map(F::from_canonical_u32); + let read_data: [u8; RV32_REGISTER_NUM_LIMBS] = + read_data.unwrap_or(array::from_fn(|_| rng.gen())); tester.write(1, a, some_prev_data); - tester.write(2, (ptr_val - shift_amount) as usize, read_data); + tester.write( + 2, + (ptr_val - shift_amount) as usize, + read_data.map(F::from_canonical_u8), + ); tester.execute( chip, @@ -104,16 +116,11 @@ fn set_and_execute( ), ); - let write_data = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - opcode, - read_data, - some_prev_data, - shift_amount, - ); + let write_data = run_write_data_sign_extend(opcode, read_data, shift_amount as usize); if a != 0 { - assert_eq!(write_data, tester.read::<4>(1, a)); + assert_eq!(write_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } else { - assert_eq!([F::ZERO; RV32_REGISTER_NUM_LIMBS], tester.read::<4>(1, a)); + assert_eq!([F::ZERO; 4], tester.read::<4>(1, a)); } } @@ -123,40 +130,19 @@ fn set_and_execute( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_load_sign_extend_test() { - setup_tracing(); +#[test_case(LOADB, 100)] +#[test_case(LOADH, 100)] +fn rand_load_sign_extend_test(opcode: Rv32LoadStoreOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADB, - None, - None, - None, - None, - ); + let mut chip = create_test_chip(&mut tester); + for _ in 0..num_ops { set_and_execute( &mut tester, &mut chip, &mut rng, - LOADH, + opcode, None, None, None, @@ -172,36 +158,29 @@ fn rand_load_sign_extend_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_negative_loadstore_test( - opcode: Rv32LoadStoreOpcode, - read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, +#[derive(Clone, Copy, Default, PartialEq)] +struct LoadSignExtPrankValues { data_most_sig_bit: Option, shift_most_sig_bit: Option, opcode_flags: Option<[bool; 3]>, - rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, +} + +#[allow(clippy::too_many_arguments)] +fn run_negative_load_sign_extend_test( + opcode: Rv32LoadStoreOpcode, + read_data: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + rs1: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, - expected_error: VerificationError, + prank_vals: LoadSignExtPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip.clone()); - let adapter_width = BaseAir::::width(adapter.air()); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut chip = create_test_chip(&mut tester); set_and_execute( &mut tester, @@ -214,78 +193,78 @@ fn run_negative_loadstore_test( imm_sign, ); + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); - let (_, core_row) = trace_row.split_at_mut(adapter_width); let core_cols: &mut LoadSignExtendCoreCols = core_row.borrow_mut(); - if let Some(shifted_read_data) = read_data { - core_cols.shifted_read_data = shifted_read_data.map(F::from_canonical_u32); + core_cols.shifted_read_data = shifted_read_data.map(F::from_canonical_u8); } - - if let Some(data_most_sig_bit) = data_most_sig_bit { + if let Some(data_most_sig_bit) = prank_vals.data_most_sig_bit { core_cols.data_most_sig_bit = F::from_canonical_u32(data_most_sig_bit); } - if let Some(shift_most_sig_bit) = shift_most_sig_bit { + if let Some(shift_most_sig_bit) = prank_vals.shift_most_sig_bit { core_cols.shift_most_sig_bit = F::from_canonical_u32(shift_most_sig_bit); } - - if let Some(opcode_flags) = opcode_flags { + if let Some(opcode_flags) = prank_vals.opcode_flags { core_cols.opcode_loadb_flag0 = F::from_bool(opcode_flags[0]); core_cols.opcode_loadb_flag1 = F::from_bool(opcode_flags[1]); core_cols.opcode_loadh_flag = F::from_bool(opcode_flags[2]); } + *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn loadstore_negative_tests() { - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADB, Some([233, 187, 145, 238]), - Some(0), - None, None, None, None, - None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + data_most_sig_bit: Some(0), + ..Default::default() + }, + true, ); - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADH, None, - None, - Some(0), - None, Some([202, 109, 183, 26]), Some(31212), None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + shift_most_sig_bit: Some(0), + ..Default::default() + }, + true, ); - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADB, None, - None, - None, - Some([true, false, false]), Some([250, 132, 77, 5]), Some(47741), None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + opcode_flags: Some([true, false, false]), + ..Default::default() + }, + true, ); } @@ -294,119 +273,51 @@ fn loadstore_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADH, - None, - None, - None, - None, - ); - } -} #[test] fn solve_loadh_extend_sign_sanity_test() { - let read_data = [34, 159, 237, 151].map(F::from_canonical_u32); - let prev_data = [94, 183, 56, 241].map(F::from_canonical_u32); - let write_data0 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADH, read_data, prev_data, 0, - ); - let write_data2 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADH, read_data, prev_data, 2, - ); + let read_data = [34, 159, 237, 151]; + let write_data0 = run_write_data_sign_extend::(LOADH, read_data, 0); + let write_data2 = run_write_data_sign_extend::(LOADH, read_data, 2); - assert_eq!(write_data0, [34, 159, 255, 255].map(F::from_canonical_u32)); - assert_eq!(write_data2, [237, 151, 255, 255].map(F::from_canonical_u32)); + assert_eq!(write_data0, [34, 159, 255, 255]); + assert_eq!(write_data2, [237, 151, 255, 255]); } #[test] fn solve_loadh_extend_zero_sanity_test() { - let read_data = [34, 121, 237, 97].map(F::from_canonical_u32); - let prev_data = [94, 183, 56, 241].map(F::from_canonical_u32); - let write_data0 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADH, read_data, prev_data, 0, - ); - let write_data2 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADH, read_data, prev_data, 2, - ); + let read_data = [34, 121, 237, 97]; + let write_data0 = run_write_data_sign_extend::(LOADH, read_data, 0); + let write_data2 = run_write_data_sign_extend::(LOADH, read_data, 2); - assert_eq!(write_data0, [34, 121, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data2, [237, 97, 0, 0].map(F::from_canonical_u32)); + assert_eq!(write_data0, [34, 121, 0, 0]); + assert_eq!(write_data2, [237, 97, 0, 0]); } #[test] fn solve_loadb_extend_sign_sanity_test() { - let read_data = [45, 82, 99, 127].map(F::from_canonical_u32); - let prev_data = [53, 180, 29, 244].map(F::from_canonical_u32); - let write_data0 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 0, - ); - let write_data1 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 1, - ); - let write_data2 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 2, - ); - let write_data3 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 3, - ); - - assert_eq!(write_data0, [45, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data1, [82, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data2, [99, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data3, [127, 0, 0, 0].map(F::from_canonical_u32)); + let read_data = [45, 82, 99, 127]; + let write_data0 = run_write_data_sign_extend::(LOADB, read_data, 0); + let write_data1 = run_write_data_sign_extend::(LOADB, read_data, 1); + let write_data2 = run_write_data_sign_extend::(LOADB, read_data, 2); + let write_data3 = run_write_data_sign_extend::(LOADB, read_data, 3); + + assert_eq!(write_data0, [45, 0, 0, 0]); + assert_eq!(write_data1, [82, 0, 0, 0]); + assert_eq!(write_data2, [99, 0, 0, 0]); + assert_eq!(write_data3, [127, 0, 0, 0]); } #[test] fn solve_loadb_extend_zero_sanity_test() { - let read_data = [173, 210, 227, 255].map(F::from_canonical_u32); - let prev_data = [53, 180, 29, 244].map(F::from_canonical_u32); - let write_data0 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 0, - ); - let write_data1 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 1, - ); - let write_data2 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 2, - ); - let write_data3 = run_write_data_sign_extend::<_, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>( - LOADB, read_data, prev_data, 3, - ); - - assert_eq!(write_data0, [173, 255, 255, 255].map(F::from_canonical_u32)); - assert_eq!(write_data1, [210, 255, 255, 255].map(F::from_canonical_u32)); - assert_eq!(write_data2, [227, 255, 255, 255].map(F::from_canonical_u32)); - assert_eq!(write_data3, [255, 255, 255, 255].map(F::from_canonical_u32)); + let read_data = [173, 210, 227, 255]; + let write_data0 = run_write_data_sign_extend::(LOADB, read_data, 0); + let write_data1 = run_write_data_sign_extend::(LOADB, read_data, 1); + let write_data2 = run_write_data_sign_extend::(LOADB, read_data, 2); + let write_data3 = run_write_data_sign_extend::(LOADB, read_data, 3); + + assert_eq!(write_data0, [173, 255, 255, 255]); + assert_eq!(write_data1, [210, 255, 255, 255]); + assert_eq!(write_data2, [227, 255, 255, 255]); + assert_eq!(write_data3, [255, 255, 255, 255]); } diff --git a/extensions/rv32im/circuit/src/loadstore/core.rs b/extensions/rv32im/circuit/src/loadstore/core.rs index 36beb10629..e8e9a19909 100644 --- a/extensions/rv32im/circuit/src/loadstore/core.rs +++ b/extensions/rv32im/circuit/src/loadstore/core.rs @@ -1,10 +1,20 @@ -use std::borrow::{Borrow, BorrowMut}; +use std::{ + array, + borrow::{Borrow, BorrowMut}, + fmt::Debug, +}; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::E1ExecutionCtx, AdapterAirContext, AdapterTraceStep, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory, POINTER_MAX_BITS}, +}; +use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode, NATIVE_AS, }; -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -12,8 +22,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use crate::adapters::LoadStoreInstruction; @@ -35,6 +43,12 @@ enum InstructionOpcode { StoreB3, } +use openvm_circuit::arch::{ + execution_mode::E2ExecutionCtx, get_record_from_slice, AdapterTraceFiller, E2PreCompute, + EmptyAdapterCoreLayout, ExecuteFunc, ExecutionError, ExecutionError::InvalidInstruction, + RecordArena, StepExecutorE2, TraceFiller, VmSegmentState, +}; +use openvm_instructions::riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}; use InstructionOpcode::*; /// LoadStore Core Chip handles byte/halfword into word conversions and unsigned extends @@ -56,21 +70,7 @@ pub struct LoadStoreCoreCols { pub write_data: [T; NUM_CELLS], } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Serialize + DeserializeOwned")] -pub struct LoadStoreCoreRecord { - pub opcode: Rv32LoadStoreOpcode, - pub shift: u32, - #[serde(with = "BigArray")] - pub read_data: [F; NUM_CELLS], - #[serde(with = "BigArray")] - pub prev_data: [F; NUM_CELLS], - #[serde(with = "BigArray")] - pub write_data: [F; NUM_CELLS], -} - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct LoadStoreCoreAir { pub offset: usize, } @@ -246,70 +246,109 @@ where } } -#[derive(Debug)] -pub struct LoadStoreCoreChip { - pub air: LoadStoreCoreAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct LoadStoreCoreRecord { + pub local_opcode: u8, + pub shift_amount: u8, + pub read_data: [u8; NUM_CELLS], + // Note: `prev_data` can be from native address space, so we need to use u32 + pub prev_data: [u32; NUM_CELLS], } -impl LoadStoreCoreChip { - pub fn new(offset: usize) -> Self { - Self { - air: LoadStoreCoreAir { offset }, - } - } +#[derive(derive_new::new)] +pub struct LoadStoreStep { + adapter: A, + pub offset: usize, } -impl, const NUM_CELLS: usize> VmCoreChip - for LoadStoreCoreChip +impl TraceStep for LoadStoreStep where - I::Reads: Into<([[F; NUM_CELLS]; 2], F)>, - I::Writes: From<[[F; NUM_CELLS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8), + WriteData = [u32; NUM_CELLS], + >, { - type Record = LoadStoreCoreRecord; - type Air = LoadStoreCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = - Rv32LoadStoreOpcode::from_usize(instruction.opcode.local_opcode_idx(self.air.offset)); - - let (reads, shift_amount) = reads.into(); - let shift = shift_amount.as_canonical_u32(); - let prev_data = reads[0]; - let read_data = reads[1]; - let write_data = run_write_data(local_opcode, read_data, prev_data, shift); - let output = AdapterRuntimeContext::without_pc([write_data]); - - Ok(( - output, - LoadStoreCoreRecord { - opcode: local_opcode, - shift, - prev_data, - read_data, - write_data, - }, - )) - } + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = (A::RecordMut<'a>, &'a mut LoadStoreCoreRecord); fn get_opcode_name(&self, opcode: usize) -> String { format!( "{:?}", - Rv32LoadStoreOpcode::from_usize(opcode - self.air.offset) + Rv32LoadStoreOpcode::from_usize(opcode - self.offset) ) } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut LoadStoreCoreCols = row_slice.borrow_mut(); - let opcode = record.opcode; - let flags = &mut core_cols.flags; + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let Instruction { opcode, .. } = instruction; + + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + ( + (core_record.prev_data, core_record.read_data), + core_record.shift_amount, + ) = self + .adapter + .read(state.memory, instruction, &mut adapter_record); + + let local_opcode = Rv32LoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + core_record.local_opcode = local_opcode as u8; + + let write_data = run_write_data( + local_opcode, + core_record.read_data, + core_record.prev_data, + core_record.shift_amount as usize, + ); + self.adapter + .write(state.memory, instruction, write_data, &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for LoadStoreStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &LoadStoreCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut LoadStoreCoreCols = core_row.borrow_mut(); + + let opcode = Rv32LoadStoreOpcode::from_usize(record.local_opcode as usize); + let shift = record.shift_amount; + + let write_data = run_write_data(opcode, record.read_data, record.prev_data, shift as usize); + // Writing in reverse order + core_row.write_data = write_data.map(F::from_canonical_u32); + core_row.prev_data = record.prev_data.map(F::from_canonical_u32); + core_row.read_data = record.read_data.map(F::from_canonical_u8); + core_row.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode)); + core_row.is_valid = F::ONE; + let flags = &mut core_row.flags; *flags = [F::ZERO; 4]; - match (opcode, record.shift) { + match (opcode, shift) { (LOADW, 0) => flags[0] = F::TWO, (LOADHU, 0) => flags[1] = F::TWO, (LOADHU, 2) => flags[2] = F::TWO, @@ -328,51 +367,442 @@ where (STOREB, 3) => (flags[2], flags[3]) = (F::ONE, F::ONE), _ => unreachable!(), }; - core_cols.prev_data = record.prev_data; - core_cols.read_data = record.read_data; - core_cols.is_valid = F::ONE; - core_cols.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode)); - core_cols.write_data = record.write_data; + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct LoadStorePreCompute { + imm_extended: u32, + a: u8, + b: u8, + e: u8, +} + +impl StepExecutorE1 for LoadStoreStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut LoadStorePreCompute = data.borrow_mut(); + let (local_opcode, enabled, is_native_store) = + self.pre_compute_impl(pc, inst, pre_compute)?; + let fn_ptr = match (local_opcode, enabled, is_native_store) { + (LOADW, true, _) => execute_e1_impl::<_, _, U8, LoadWOp, true>, + (LOADW, false, _) => execute_e1_impl::<_, _, U8, LoadWOp, false>, + (LOADHU, true, _) => execute_e1_impl::<_, _, U8, LoadHUOp, true>, + (LOADHU, false, _) => execute_e1_impl::<_, _, U8, LoadHUOp, false>, + (LOADBU, true, _) => execute_e1_impl::<_, _, U8, LoadBUOp, true>, + (LOADBU, false, _) => execute_e1_impl::<_, _, U8, LoadBUOp, false>, + (STOREW, true, false) => execute_e1_impl::<_, _, U8, StoreWOp, true>, + (STOREW, false, false) => execute_e1_impl::<_, _, U8, StoreWOp, false>, + (STOREW, true, true) => execute_e1_impl::<_, _, F, StoreWOp, true>, + (STOREW, false, true) => execute_e1_impl::<_, _, F, StoreWOp, false>, + (STOREH, true, false) => execute_e1_impl::<_, _, U8, StoreHOp, true>, + (STOREH, false, false) => execute_e1_impl::<_, _, U8, StoreHOp, false>, + (STOREH, true, true) => execute_e1_impl::<_, _, F, StoreHOp, true>, + (STOREH, false, true) => execute_e1_impl::<_, _, F, StoreHOp, false>, + (STOREB, true, false) => execute_e1_impl::<_, _, U8, StoreBOp, true>, + (STOREB, false, false) => execute_e1_impl::<_, _, U8, StoreBOp, false>, + (STOREB, true, true) => execute_e1_impl::<_, _, F, StoreBOp, true>, + (STOREB, false, true) => execute_e1_impl::<_, _, F, StoreBOp, false>, + (_, _, _) => unreachable!(), + }; + Ok(fn_ptr) + } +} + +impl StepExecutorE2 for LoadStoreStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn air(&self) -> &Self::Air { - &self.air + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let (local_opcode, enabled, is_native_store) = + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + let fn_ptr = match (local_opcode, enabled, is_native_store) { + (LOADW, true, _) => execute_e2_impl::<_, _, U8, LoadWOp, true>, + (LOADW, false, _) => execute_e2_impl::<_, _, U8, LoadWOp, false>, + (LOADHU, true, _) => execute_e2_impl::<_, _, U8, LoadHUOp, true>, + (LOADHU, false, _) => execute_e2_impl::<_, _, U8, LoadHUOp, false>, + (LOADBU, true, _) => execute_e2_impl::<_, _, U8, LoadBUOp, true>, + (LOADBU, false, _) => execute_e2_impl::<_, _, U8, LoadBUOp, false>, + (STOREW, true, false) => execute_e2_impl::<_, _, U8, StoreWOp, true>, + (STOREW, false, false) => execute_e2_impl::<_, _, U8, StoreWOp, false>, + (STOREW, true, true) => execute_e2_impl::<_, _, F, StoreWOp, true>, + (STOREW, false, true) => execute_e2_impl::<_, _, F, StoreWOp, false>, + (STOREH, true, false) => execute_e2_impl::<_, _, U8, StoreHOp, true>, + (STOREH, false, false) => execute_e2_impl::<_, _, U8, StoreHOp, false>, + (STOREH, true, true) => execute_e2_impl::<_, _, F, StoreHOp, true>, + (STOREH, false, true) => execute_e2_impl::<_, _, F, StoreHOp, false>, + (STOREB, true, false) => execute_e2_impl::<_, _, U8, StoreBOp, true>, + (STOREB, false, false) => execute_e2_impl::<_, _, U8, StoreBOp, false>, + (STOREB, true, true) => execute_e2_impl::<_, _, F, StoreBOp, true>, + (STOREB, false, true) => execute_e2_impl::<_, _, F, StoreBOp, false>, + (_, _, _) => unreachable!(), + }; + Ok(fn_ptr) } } -pub(super) fn run_write_data( +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + T: Copy + Debug + Default, + OP: LoadStoreOp, + const ENABLED: bool, +>( + pre_compute: &LoadStorePreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1_bytes: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let rs1_val = u32::from_le_bytes(rs1_bytes); + let ptr_val = rs1_val.wrapping_add(pre_compute.imm_extended); + // sign_extend([r32{c,g}(b):2]_e)` + debug_assert!(ptr_val < (1 << POINTER_MAX_BITS)); + let shift_amount = ptr_val % 4; + let ptr_val = ptr_val - shift_amount; // aligned ptr + + let read_data: [u8; RV32_REGISTER_NUM_LIMBS] = if OP::IS_LOAD { + vm_state.vm_read(pre_compute.e as u32, ptr_val) + } else { + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32) + }; + + // We need to write 4 u32s for STORE. + let mut write_data: [T; RV32_REGISTER_NUM_LIMBS] = if OP::HOST_READ { + vm_state.host_read(pre_compute.e as u32, ptr_val) + } else { + [T::default(); RV32_REGISTER_NUM_LIMBS] + }; + + if !OP::compute_write_data(&mut write_data, read_data, shift_amount as usize) { + vm_state.exit_code = Err(ExecutionError::Fail { pc: vm_state.pc }); + return; + } + + if ENABLED { + if OP::IS_LOAD { + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data); + } else { + vm_state.vm_write(pre_compute.e as u32, ptr_val, &write_data); + } + } + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + T: Copy + Debug + Default, + OP: LoadStoreOp, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &LoadStorePreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl< + F: PrimeField32, + CTX: E2ExecutionCtx, + T: Copy + Debug + Default, + OP: LoadStoreOp, + const ENABLED: bool, +>( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} + +impl LoadStoreStep { + /// Return (local_opcode, enabled, is_native_store) + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut LoadStorePreCompute, + ) -> Result<(Rv32LoadStoreOpcode, bool, bool)> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + f, + g, + .. + } = inst; + let enabled = !f.is_zero(); + + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 == RV32_IMM_AS { + return Err(InvalidInstruction(pc)); + } + + let local_opcode = Rv32LoadStoreOpcode::from_usize( + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + ); + match local_opcode { + LOADW | LOADBU | LOADHU => {} + STOREW | STOREH | STOREB => { + if !enabled { + return Err(InvalidInstruction(pc)); + } + } + _ => unreachable!("LoadStoreStep should not handle LOADB/LOADH opcodes"), + } + + let imm = c.as_canonical_u32(); + let imm_sign = g.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + let is_native_store = e_u32 == NATIVE_AS; + + *data = LoadStorePreCompute { + imm_extended, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + e: e_u32 as u8, + }; + Ok((local_opcode, enabled, is_native_store)) + } +} + +trait LoadStoreOp { + const IS_LOAD: bool; + const HOST_READ: bool; + + /// Return if the operation is valid. + fn compute_write_data( + write_data: &mut [T; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool; +} +/// Wrapper type for u8 so we can implement `LoadStoreOp` for `F: PrimeField32`. +/// For memory read/write, this type behaves as same as `u8`. +#[allow(dead_code)] +#[derive(Copy, Clone, Debug, Default)] +struct U8(u8); +struct LoadWOp; +struct LoadHUOp; +struct LoadBUOp; +struct StoreWOp; +struct StoreHOp; +struct StoreBOp; +impl LoadStoreOp for LoadWOp { + const IS_LOAD: bool = true; + const HOST_READ: bool = false; + + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + _shift_amount: usize, + ) -> bool { + *write_data = read_data.map(U8); + true + } +} + +impl LoadStoreOp for LoadHUOp { + const IS_LOAD: bool = true; + const HOST_READ: bool = false; + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + if shift_amount != 0 && shift_amount != 2 { + return false; + } + write_data[0] = U8(read_data[shift_amount]); + write_data[1] = U8(read_data[shift_amount + 1]); + true + } +} +impl LoadStoreOp for LoadBUOp { + const IS_LOAD: bool = true; + const HOST_READ: bool = false; + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + write_data[0] = U8(read_data[shift_amount]); + true + } +} + +impl LoadStoreOp for StoreWOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = false; + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + _shift_amount: usize, + ) -> bool { + *write_data = read_data.map(U8); + true + } +} +impl LoadStoreOp for StoreHOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = true; + + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + if shift_amount != 0 && shift_amount != 2 { + return false; + } + write_data[shift_amount] = U8(read_data[0]); + write_data[shift_amount + 1] = U8(read_data[1]); + true + } +} +impl LoadStoreOp for StoreBOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = true; + #[inline(always)] + fn compute_write_data( + write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + write_data[shift_amount] = U8(read_data[0]); + true + } +} + +impl LoadStoreOp for StoreWOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = false; + #[inline(always)] + fn compute_write_data( + write_data: &mut [F; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + _shift_amount: usize, + ) -> bool { + *write_data = read_data.map(F::from_canonical_u8); + true + } +} +impl LoadStoreOp for StoreHOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = true; + + #[inline(always)] + fn compute_write_data( + write_data: &mut [F; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + if shift_amount != 0 && shift_amount != 2 { + return false; + } + write_data[shift_amount] = F::from_canonical_u8(read_data[0]); + write_data[shift_amount + 1] = F::from_canonical_u8(read_data[1]); + true + } +} +impl LoadStoreOp for StoreBOp { + const IS_LOAD: bool = false; + const HOST_READ: bool = true; + #[inline(always)] + fn compute_write_data( + write_data: &mut [F; RV32_REGISTER_NUM_LIMBS], + read_data: [u8; RV32_REGISTER_NUM_LIMBS], + shift_amount: usize, + ) -> bool { + write_data[shift_amount] = F::from_canonical_u8(read_data[0]); + true + } +} + +// Returns the write data +#[inline(always)] +pub(super) fn run_write_data( opcode: Rv32LoadStoreOpcode, - read_data: [F; NUM_CELLS], - prev_data: [F; NUM_CELLS], - shift: u32, -) -> [F; NUM_CELLS] { - let shift = shift as usize; - let mut write_data = read_data; + read_data: [u8; NUM_CELLS], + prev_data: [u32; NUM_CELLS], + shift: usize, +) -> [u32; NUM_CELLS] { match (opcode, shift) { - (LOADW, 0) => (), + (LOADW, 0) => { + read_data.map(|x| x as u32) + }, (LOADBU, 0) | (LOADBU, 1) | (LOADBU, 2) | (LOADBU, 3) => { - for cell in write_data.iter_mut().take(NUM_CELLS).skip(1) { - *cell = F::ZERO; - } - write_data[0] = read_data[shift]; + let mut wrie_data = [0; NUM_CELLS]; + wrie_data[0] = read_data[shift] as u32; + wrie_data } (LOADHU, 0) | (LOADHU, 2) => { - for cell in write_data.iter_mut().take(NUM_CELLS).skip(NUM_CELLS / 2) { - *cell = F::ZERO; - } + let mut write_data = [0; NUM_CELLS]; for (i, cell) in write_data.iter_mut().take(NUM_CELLS / 2).enumerate() { - *cell = read_data[i + shift]; + *cell = read_data[i + shift] as u32; } + write_data } - (STOREW, 0) => (), + (STOREW, 0) => { + read_data.map(|x| x as u32) + }, (STOREB, 0) | (STOREB, 1) | (STOREB, 2) | (STOREB, 3) => { - write_data = prev_data; - write_data[shift] = read_data[0]; + let mut write_data = prev_data; + write_data[shift] = read_data[0] as u32; + write_data } (STOREH, 0) | (STOREH, 2) => { - write_data = prev_data; - write_data[shift..(NUM_CELLS / 2 + shift)] - .copy_from_slice(&read_data[..(NUM_CELLS / 2)]); + array::from_fn(|i| { + if i >= shift && i < (NUM_CELLS / 2 + shift){ + read_data[i - shift] as u32 + } else { + prev_data[i] + } + }) } // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes. // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4. @@ -380,6 +810,5 @@ pub(super) fn run_write_data( _ => unreachable!( "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}" ), - }; - write_data + } } diff --git a/extensions/rv32im/circuit/src/loadstore/mod.rs b/extensions/rv32im/circuit/src/loadstore/mod.rs index 825f82166c..eb439a1bd6 100644 --- a/extensions/rv32im/circuit/src/loadstore/mod.rs +++ b/extensions/rv32im/circuit/src/loadstore/mod.rs @@ -2,12 +2,16 @@ mod core; pub use core::*; -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32LoadStoreAdapterChip, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::RV32_REGISTER_NUM_LIMBS; +use crate::adapters::{Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterStep}; #[cfg(test)] mod tests; +pub type Rv32LoadStoreAir = + VmAirWrapper>; +pub type Rv32LoadStoreStep = LoadStoreStep; pub type Rv32LoadStoreChip = - VmChipWrapper, LoadStoreCoreChip>; + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/loadstore/tests.rs b/extensions/rv32im/circuit/src/loadstore/tests.rs index 0fbfa137b9..f29bdda263 100644 --- a/extensions/rv32im/circuit/src/loadstore/tests.rs +++ b/extensions/rv32im/circuit/src/loadstore/tests.rs @@ -3,49 +3,78 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ testing::{memory::gen_pointer, VmChipTestBuilder}, - VmAdapterChip, + MemoryConfig, VmAirWrapper, }, - utils::u32_into_limbs, + system::memory::merkle::public_values::PUBLIC_VALUES_AS, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, - p3_field::FieldAlgebra, + p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, seq::SliceRandom, Rng}; +use test_case::test_case; -use super::{run_write_data, LoadStoreCoreChip, Rv32LoadStoreChip}; +use super::{run_write_data, LoadStoreCoreAir, LoadStoreStep, Rv32LoadStoreChip}; use crate::{ - adapters::{compose, Rv32LoadStoreAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterCols, Rv32LoadStoreAdapterStep, + RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, loadstore::LoadStoreCoreCols, + test_utils::get_verification_error, }; const IMM_BITS: usize = 16; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Rv32LoadStoreChip { + let range_checker_chip = tester.range_checker(); + + let mut chip = Rv32LoadStoreChip::::new( + VmAirWrapper::new( + Rv32LoadStoreAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + range_checker_chip.bus(), + tester.address_bits(), + ), + LoadStoreCoreAir::new(Rv32LoadStoreOpcode::CLASS_OFFSET), + ), + LoadStoreStep::new( + Rv32LoadStoreAdapterStep::new(tester.address_bits(), range_checker_chip.clone()), + Rv32LoadStoreOpcode::CLASS_OFFSET, + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + chip +} + #[allow(clippy::too_many_arguments)] fn set_and_execute( tester: &mut VmChipTestBuilder, chip: &mut Rv32LoadStoreChip, rng: &mut StdRng, opcode: Rv32LoadStoreOpcode, - rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + rs1: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, mem_as: Option, ) { let imm = imm.unwrap_or(rng.gen_range(0..(1 << IMM_BITS))); let imm_sign = imm_sign.unwrap_or(rng.gen_range(0..2)); - let imm_ext = imm + imm_sign * (0xffffffff ^ ((1 << IMM_BITS) - 1)); + let imm_ext = imm + imm_sign * 0xffff0000; let alignment = match opcode { LOADW | STOREW => 2, @@ -54,33 +83,21 @@ fn set_and_execute( _ => unreachable!(), }; - let ptr_val = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - alignment)), - ) << alignment; - - let rs1 = rs1 - .unwrap_or(u32_into_limbs::( - (ptr_val as u32).wrapping_sub(imm_ext), - )) - .map(F::from_canonical_u32); + let ptr_val: u32 = rng.gen_range(0..(1 << (tester.address_bits() - alignment))) << alignment; + let rs1 = rs1.unwrap_or(ptr_val.wrapping_sub(imm_ext).to_le_bytes()); + let ptr_val = imm_ext.wrapping_add(u32::from_le_bytes(rs1)); let a = gen_pointer(rng, 4); let b = gen_pointer(rng, 4); + let is_load = [LOADW, LOADHU, LOADBU].contains(&opcode); let mem_as = mem_as.unwrap_or(if is_load { - *[1, 2].choose(rng).unwrap() + 2 } else { *[2, 3, 4].choose(rng).unwrap() }); - let ptr_val = imm_ext.wrapping_add(compose(rs1)); let shift_amount = ptr_val % 4; - tester.write(1, b, rs1); + tester.write(1, b, rs1.map(F::from_canonical_u8)); let mut some_prev_data: [F; RV32_REGISTER_NUM_LIMBS] = array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..(1 << RV32_CELL_BITS)))); @@ -92,11 +109,11 @@ fn set_and_execute( some_prev_data = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; } tester.write(1, a, some_prev_data); - if mem_as == 1 && ptr_val - shift_amount == 0 { - read_data = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; - } tester.write(mem_as, (ptr_val - shift_amount) as usize, read_data); } else { + if mem_as == 4 { + some_prev_data = array::from_fn(|_| rng.gen()); + } if a == 0 { read_data = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; } @@ -122,7 +139,13 @@ fn set_and_execute( ), ); - let write_data = run_write_data(opcode, read_data, some_prev_data, shift_amount); + let write_data = run_write_data( + opcode, + read_data.map(|x| x.as_canonical_u32() as u8), + some_prev_data.map(|x| x.as_canonical_u32()), + shift_amount as usize, + ) + .map(F::from_canonical_u32); if is_load { if enabled_write { assert_eq!(write_data, tester.read::<4>(1, a)); @@ -143,80 +166,27 @@ fn set_and_execute( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_loadstore_test() { - setup_tracing(); +#[test_case(LOADW, 100)] +#[test_case(LOADBU, 100)] +#[test_case(LOADHU, 100)] +#[test_case(STOREW, 100)] +#[test_case(STOREB, 100)] +#[test_case(STOREH, 100)] +fn rand_loadstore_test(opcode: Rv32LoadStoreOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut mem_config = MemoryConfig::default(); + if [STOREW, STOREB, STOREH].contains(&opcode) { + mem_config.addr_space_sizes[PUBLIC_VALUES_AS as usize] = 1 << 29; + } + let mut tester = VmChipTestBuilder::volatile(mem_config); + let mut chip = create_test_chip(&mut tester); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADBU, - None, - None, - None, - None, - ); + for _ in 0..num_ops { set_and_execute( &mut tester, &mut chip, &mut rng, - LOADHU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREH, + opcode, None, None, None, @@ -224,7 +194,6 @@ fn rand_loadstore_test() { ); } - drop(range_checker_chip); let tester = tester.build().load(chip).finalize(); tester.simple_test().expect("Verification failed"); } @@ -233,38 +202,35 @@ fn rand_loadstore_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_negative_loadstore_test( - opcode: Rv32LoadStoreOpcode, +#[derive(Clone, Copy, Default, PartialEq)] +struct LoadStorePrankValues { read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, prev_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, write_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, flags: Option<[u32; 4]>, is_load: Option, - rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + mem_as: Option, +} + +#[allow(clippy::too_many_arguments)] +fn run_negative_loadstore_test( + opcode: Rv32LoadStoreOpcode, + rs1: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, - mem_as: Option, - expected_error: VerificationError, + prank_vals: LoadStorePrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let adapter_width = BaseAir::::width(adapter.air()); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut mem_config = MemoryConfig::default(); + if [STOREW, STOREB, STOREH].contains(&opcode) { + mem_config.addr_space_sizes[PUBLIC_VALUES_AS as usize] = 1 << 29; + } + let mut tester = VmChipTestBuilder::volatile(mem_config); + let mut chip = create_test_chip(&mut tester); set_and_execute( &mut tester, @@ -274,38 +240,45 @@ fn run_negative_loadstore_test( rs1, imm, imm_sign, - mem_as, + None, ); + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); - let (_, core_row) = trace_row.split_at_mut(adapter_width); + let (adapter_row, core_row) = trace_row.split_at_mut(adapter_width); + let adapter_cols: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); let core_cols: &mut LoadStoreCoreCols = core_row.borrow_mut(); - if let Some(read_data) = read_data { + + if let Some(read_data) = prank_vals.read_data { core_cols.read_data = read_data.map(F::from_canonical_u32); } - if let Some(prev_data) = prev_data { + if let Some(prev_data) = prank_vals.prev_data { core_cols.prev_data = prev_data.map(F::from_canonical_u32); } - if let Some(write_data) = write_data { + if let Some(write_data) = prank_vals.write_data { core_cols.write_data = write_data.map(F::from_canonical_u32); } - if let Some(flags) = flags { + if let Some(flags) = prank_vals.flags { core_cols.flags = flags.map(F::from_canonical_u32); } - if let Some(is_load) = is_load { + if let Some(is_load) = prank_vals.is_load { core_cols.is_load = F::from_bool(is_load); } + if let Some(mem_as) = prank_vals.mem_as { + adapter_cols.mem_as = F::from_canonical_u32(mem_as); + } + *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -315,41 +288,36 @@ fn negative_wrong_opcode_tests() { None, None, None, - None, - Some(false), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + is_load: Some(false), + ..Default::default() + }, + false, ); run_negative_loadstore_test( LOADBU, - None, - None, - None, - Some([0, 0, 0, 2]), - None, Some([4, 0, 0, 0]), Some(1), None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + flags: Some([0, 0, 0, 2]), + ..Default::default() + }, + false, ); run_negative_loadstore_test( STOREH, - None, - None, - None, - Some([1, 0, 1, 0]), - Some(true), Some([11, 169, 76, 28]), Some(37121), None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + flags: Some([1, 0, 1, 0]), + is_load: Some(true), + ..Default::default() + }, + false, ); } @@ -357,30 +325,34 @@ fn negative_wrong_opcode_tests() { fn negative_write_data_tests() { run_negative_loadstore_test( LOADHU, - Some([175, 33, 198, 250]), - Some([90, 121, 64, 205]), - Some([175, 33, 0, 0]), - Some([0, 2, 0, 0]), - Some(true), Some([13, 11, 156, 23]), Some(43641), None, - None, - VerificationError::ChallengePhaseError, + LoadStorePrankValues { + read_data: Some([175, 33, 198, 250]), + prev_data: Some([90, 121, 64, 205]), + write_data: Some([175, 33, 0, 0]), + flags: Some([0, 2, 0, 0]), + is_load: Some(true), + mem_as: None, + }, + true, ); run_negative_loadstore_test( STOREB, - Some([175, 33, 198, 250]), - Some([90, 121, 64, 205]), - Some([175, 121, 64, 205]), - Some([0, 0, 1, 1]), - None, Some([45, 123, 87, 24]), Some(28122), Some(0), - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + read_data: Some([175, 33, 198, 250]), + prev_data: Some([90, 121, 64, 205]), + write_data: Some([175, 121, 64, 205]), + flags: Some([0, 0, 1, 1]), + is_load: None, + mem_as: None, + }, + false, ); } @@ -391,39 +363,35 @@ fn negative_wrong_address_space_tests() { None, None, None, - None, - None, - None, - None, - None, - Some(3), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(3), + ..Default::default() + }, + false, ); + run_negative_loadstore_test( LOADW, None, None, None, - None, - None, - None, - None, - None, - Some(4), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(4), + ..Default::default() + }, + false, ); + run_negative_loadstore_test( STOREW, None, None, None, - None, - None, - None, - None, - None, - Some(1), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(1), + ..Default::default() + }, + false, ); } @@ -432,140 +400,60 @@ fn negative_wrong_address_space_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADBU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADHU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREH, - None, - None, - None, - None, - ); - } -} - #[test] fn run_loadw_storew_sanity_test() { - let read_data = [138, 45, 202, 76].map(F::from_canonical_u32); - let prev_data = [159, 213, 89, 34].map(F::from_canonical_u32); + let read_data = [138, 45, 202, 76]; + let prev_data = [159, 213, 89, 34]; let store_write_data = run_write_data(STOREW, read_data, prev_data, 0); let load_write_data = run_write_data(LOADW, read_data, prev_data, 0); - assert_eq!(store_write_data, read_data); - assert_eq!(load_write_data, read_data); + assert_eq!(store_write_data, read_data.map(u32::from)); + assert_eq!(load_write_data, read_data.map(u32::from)); } #[test] fn run_storeh_sanity_test() { - let read_data = [250, 123, 67, 198].map(F::from_canonical_u32); - let prev_data = [144, 56, 175, 92].map(F::from_canonical_u32); + let read_data = [250, 123, 67, 198]; + let prev_data = [144, 56, 175, 92]; let write_data = run_write_data(STOREH, read_data, prev_data, 0); let write_data2 = run_write_data(STOREH, read_data, prev_data, 2); - assert_eq!(write_data, [250, 123, 175, 92].map(F::from_canonical_u32)); - assert_eq!(write_data2, [144, 56, 250, 123].map(F::from_canonical_u32)); + assert_eq!(write_data, [250, 123, 175, 92]); + assert_eq!(write_data2, [144, 56, 250, 123]); } #[test] fn run_storeb_sanity_test() { - let read_data = [221, 104, 58, 147].map(F::from_canonical_u32); - let prev_data = [199, 83, 243, 12].map(F::from_canonical_u32); + let read_data = [221, 104, 58, 147]; + let prev_data = [199, 83, 243, 12]; let write_data = run_write_data(STOREB, read_data, prev_data, 0); let write_data1 = run_write_data(STOREB, read_data, prev_data, 1); let write_data2 = run_write_data(STOREB, read_data, prev_data, 2); let write_data3 = run_write_data(STOREB, read_data, prev_data, 3); - assert_eq!(write_data, [221, 83, 243, 12].map(F::from_canonical_u32)); - assert_eq!(write_data1, [199, 221, 243, 12].map(F::from_canonical_u32)); - assert_eq!(write_data2, [199, 83, 221, 12].map(F::from_canonical_u32)); - assert_eq!(write_data3, [199, 83, 243, 221].map(F::from_canonical_u32)); + assert_eq!(write_data, [221, 83, 243, 12]); + assert_eq!(write_data1, [199, 221, 243, 12]); + assert_eq!(write_data2, [199, 83, 221, 12]); + assert_eq!(write_data3, [199, 83, 243, 221]); } #[test] fn run_loadhu_sanity_test() { - let read_data = [175, 33, 198, 250].map(F::from_canonical_u32); - let prev_data = [90, 121, 64, 205].map(F::from_canonical_u32); + let read_data = [175, 33, 198, 250]; + let prev_data = [90, 121, 64, 205]; let write_data = run_write_data(LOADHU, read_data, prev_data, 0); let write_data2 = run_write_data(LOADHU, read_data, prev_data, 2); - assert_eq!(write_data, [175, 33, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data2, [198, 250, 0, 0].map(F::from_canonical_u32)); + assert_eq!(write_data, [175, 33, 0, 0]); + assert_eq!(write_data2, [198, 250, 0, 0]); } #[test] fn run_loadbu_sanity_test() { - let read_data = [131, 74, 186, 29].map(F::from_canonical_u32); - let prev_data = [142, 67, 210, 88].map(F::from_canonical_u32); + let read_data = [131, 74, 186, 29]; + let prev_data = [142, 67, 210, 88]; let write_data = run_write_data(LOADBU, read_data, prev_data, 0); let write_data1 = run_write_data(LOADBU, read_data, prev_data, 1); let write_data2 = run_write_data(LOADBU, read_data, prev_data, 2); let write_data3 = run_write_data(LOADBU, read_data, prev_data, 3); - assert_eq!(write_data, [131, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data1, [74, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data2, [186, 0, 0, 0].map(F::from_canonical_u32)); - assert_eq!(write_data3, [29, 0, 0, 0].map(F::from_canonical_u32)); + assert_eq!(write_data, [131, 0, 0, 0]); + assert_eq!(write_data1, [74, 0, 0, 0]); + assert_eq!(write_data2, [186, 0, 0, 0]); + assert_eq!(write_data3, [29, 0, 0, 0]); } diff --git a/extensions/rv32im/circuit/src/mul/core.rs b/extensions/rv32im/circuit/src/mul/core.rs index fa65a6cf09..3d89574e64 100644 --- a/extensions/rv32im/circuit/src/mul/core.rs +++ b/extensions/rv32im/circuit/src/mul/core.rs @@ -3,13 +3,28 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + MinimalInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, +}; +use openvm_circuit_primitives::{ + range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + AlignedBytesBorrow, }; -use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::MulOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -17,8 +32,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; #[repr(C)] #[derive(AlignedBorrow)] @@ -29,7 +42,7 @@ pub struct MultiplicationCoreCols { pub bus: RangeTupleCheckerBus<2>, pub offset: usize, @@ -109,14 +122,28 @@ where } } +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MultiplicationCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], +} + #[derive(Debug)] -pub struct MultiplicationCoreChip { - pub air: MultiplicationCoreAir, +pub struct MultiplicationStep { + adapter: A, + pub offset: usize, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl MultiplicationCoreChip { - pub fn new(range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize) -> Self { +impl + MultiplicationStep +{ + pub fn new( + adapter: A, + range_tuple_chip: SharedRangeTupleCheckerChip<2>, + offset: usize, + ) -> Self { // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i // < NUM_LIMBS. a[i] must have LIMB_BITS bits and carry[i] is the sum of i + 1 bytes // (with LIMB_BITS bits). @@ -132,102 +159,237 @@ impl MultiplicationCoreChip { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], -} - -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for MultiplicationCoreChip +impl TraceStep + for MultiplicationStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, { - type Record = MultiplicationCoreRecord; - type Air = MultiplicationCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = ( + A::RecordMut<'a>, + &'a mut MultiplicationCoreRecord, + ); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", MulOpcode::from_usize(opcode - self.offset)) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { let Instruction { opcode, .. } = instruction; - assert_eq!( - MulOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)), + + debug_assert_eq!( + MulOpcode::from_usize(opcode.local_opcode_idx(self.offset)), MulOpcode::MUL ); + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let (a, _) = run_mul::(&rs1, &rs2); + + core_record.b = rs1; + core_record.c = rs2; + + self.adapter + .write(state.memory, instruction, [a].into(), &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } +} +impl TraceFiller + for MultiplicationStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &MultiplicationCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + + let core_row: &mut MultiplicationCoreCols = core_row.borrow_mut(); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (a, carry) = run_mul::(&b, &c); + let (a, carry) = run_mul::(&record.b, &record.c); for (a, carry) in a.iter().zip(carry.iter()) { - self.range_tuple_chip.add_count(&[*a, *carry]); + self.range_tuple_chip.add_count(&[*a as u32, *carry]); } - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = MultiplicationCoreRecord { - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - }; + // write in reverse order + core_row.is_valid = F::ONE; + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = a.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MultiPreCompute { + a: u8, + b: u8, + c: u8, +} - Ok((output, record)) +impl StepExecutorE1 + for MultiplicationStep +where + F: PrimeField32, +{ + fn pre_compute_size(&self) -> usize { + size_of::() } + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + let pre_compute: &mut MultiPreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, pre_compute)?; + Ok(execute_e1_impl) + } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", MulOpcode::from_usize(opcode - self.air.offset)) +impl StepExecutorE2 + for MultiplicationStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> = - row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.is_valid = F::ONE; + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut pre_compute.data)?; + Ok(execute_e2_impl) } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MultiPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let rs1 = u32::from_le_bytes(rs1); + let rs2 = u32::from_le_bytes(rs2); + let rd = rs1.wrapping_mul(rs2); + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd.to_le_bytes()); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &MultiPreCompute = pre_compute.borrow(); + execute_e12_impl(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl(&pre_compute.data, vm_state); +} - fn air(&self) -> &Self::Air { - &self.air +impl MultiplicationStep { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut MultiPreCompute, + ) -> Result<()> { + assert_eq!( + MulOpcode::from_usize(inst.opcode.local_opcode_idx(self.offset)), + MulOpcode::MUL + ); + if inst.d.as_canonical_u32() != RV32_REGISTER_AS { + return Err(InvalidInstruction(pc)); + } + + *data = MultiPreCompute { + a: inst.a.as_canonical_u32() as u8, + b: inst.b.as_canonical_u32() as u8, + c: inst.c.as_canonical_u32() as u8, + }; + Ok(()) } } // returns mul, carry +#[inline(always)] pub(super) fn run_mul( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS]) { - let mut result = [0; NUM_LIMBS]; - let mut carry = [0; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], [u32; NUM_LIMBS]) { + let mut result = [0u8; NUM_LIMBS]; + let mut carry = [0u32; NUM_LIMBS]; for i in 0..NUM_LIMBS { + let mut res = 0u32; if i > 0 { - result[i] = carry[i - 1]; + res = carry[i - 1]; } for j in 0..=i { - result[i] += x[j] * y[i - j]; + res += (x[j] as u32) * (y[i - j] as u32); } - carry[i] = result[i] >> LIMB_BITS; - result[i] %= 1 << LIMB_BITS; + carry[i] = res >> LIMB_BITS; + res %= 1u32 << LIMB_BITS; + result[i] = res as u8; } (result, carry) } diff --git a/extensions/rv32im/circuit/src/mul/mod.rs b/extensions/rv32im/circuit/src/mul/mod.rs index 5f28439977..ec7b2b43cc 100644 --- a/extensions/rv32im/circuit/src/mul/mod.rs +++ b/extensions/rv32im/circuit/src/mul/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep}; mod core; pub use core::*; @@ -8,8 +9,11 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32MultiplicationChip = VmChipWrapper< - F, - Rv32MultAdapterChip, - MultiplicationCoreChip, +pub type Rv32MultiplicationAir = VmAirWrapper< + Rv32MultAdapterAir, + MultiplicationCoreAir, >; +pub type Rv32MultiplicationStep = + MultiplicationStep; +pub type Rv32MultiplicationChip = + NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/mul/tests.rs b/extensions/rv32im/circuit/src/mul/tests.rs index b942c24cc3..320c7eef9f 100644 --- a/extensions/rv32im/circuit/src/mul/tests.rs +++ b/extensions/rv32im/circuit/src/mul/tests.rs @@ -1,15 +1,12 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, RANGE_TUPLE_CHECKER_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, - }, - utils::generate_long_number, +use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, RANGE_TUPLE_CHECKER_BUS}, + InstructionExecutor, VmAirWrapper, }; use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::MulOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::MulOpcode::{self, MUL}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -18,69 +15,90 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; use super::core::run_mul; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - mul::{MultiplicationCoreChip, MultiplicationCoreCols, Rv32MultiplicationChip}, - test_utils::rv32_rand_write_register_or_imm, + adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + mul::{MultiplicationCoreCols, MultiplicationStep, Rv32MultiplicationChip}, + test_utils::{get_verification_error, rv32_rand_write_register_or_imm}, + MultiplicationCoreAir, }; +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; type F = BabyBear; -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// - -fn run_rv32_mul_rand_test(num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; - let mut rng = create_seeded_rng(); - +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> (Rv32MultiplicationChip, SharedRangeTupleCheckerChip<2>) { let range_tuple_bus = RangeTupleCheckerBus::new( RANGE_TUPLE_CHECKER_BUS, [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], ); let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32MultiplicationChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), + VmAirWrapper::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + MultiplicationCoreAir::new(range_tuple_bus, MulOpcode::CLASS_OFFSET), + ), + MultiplicationStep::new( + Rv32MultAdapterStep::new(), + range_tuple_checker.clone(), + MulOpcode::CLASS_OFFSET, ), - MultiplicationCoreChip::new(range_tuple_checker.clone(), MulOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, range_tuple_checker) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: MulOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let c = c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + + let (mut instruction, rd) = + rv32_rand_write_register_or_imm(tester, b, c, None, opcode.global_opcode().as_usize(), rng); + + instruction.e = F::ZERO; + tester.execute(chip, &instruction); + + let (a, _) = run_mul::(&b, &c); + assert_eq!( + a.map(F::from_canonical_u8), + tester.read::(1, rd) + ) +} + +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn run_rv32_mul_rand_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, range_tuple_checker) = create_test_chip(&mut tester); + let num_ops = 100; for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - - let (mut instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - None, - MulOpcode::MUL.global_opcode().as_usize(), - &mut rng, - ); - instruction.e = F::ZERO; - tester.execute(&mut chip, &instruction); - - let (a, _) = run_mul::(&b, &c); - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) - ) + set_and_execute(&mut tester, &mut chip, &mut rng, MUL, None, None); } let tester = tester @@ -91,74 +109,36 @@ fn run_rv32_mul_rand_test(num_ops: usize) { tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_mul_rand_test() { - run_rv32_mul_rand_test(1); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32MultiplicationTestChip = VmChipWrapper< - F, - TestAdapterChip, - MultiplicationCoreChip, ->; - #[allow(clippy::too_many_arguments)] -fn run_rv32_mul_negative_test( - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - is_valid: bool, +fn run_negative_mul_test( + opcode: MulOpcode, + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_is_valid: bool, interaction_error: bool, ) { - const MAX_NUM_LIMBS: u32 = 32; - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MultiplicationTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - MultiplicationCoreChip::new(range_tuple_chip.clone(), MulOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - tester.execute( - &mut chip, - &Instruction::from_usize(MulOpcode::MUL.global_opcode(), [0, 0, 0, 1, 0]), - ); - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, carry) = run_mul::(&b, &c); + let (mut chip, range_tuple_chip) = create_test_chip(&mut tester); - range_tuple_chip.clear(); - if is_valid { - for (a, carry) in a.iter().zip(carry.iter()) { - range_tuple_chip.add_count(&[*a, *carry]); - } - } + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, Some(b), Some(c)); + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut MultiplicationCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - cols.is_valid = F::from_bool(is_valid); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + cols.is_valid = F::from_bool(prank_is_valid); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -167,16 +147,13 @@ fn run_rv32_mul_negative_test( .load_and_prank_trace(chip, modify_trace) .load(range_tuple_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_mul_wrong_negative_test() { - run_rv32_mul_negative_test( + run_negative_mul_test( + MUL, [63, 247, 125, 234], [51, 109, 78, 142], [197, 85, 150, 32], @@ -187,7 +164,8 @@ fn rv32_mul_wrong_negative_test() { #[test] fn rv32_mul_is_valid_false_negative_test() { - run_rv32_mul_negative_test( + run_negative_mul_test( + MUL, [63, 247, 125, 234], [51, 109, 78, 142], [197, 85, 150, 32], @@ -204,9 +182,9 @@ fn rv32_mul_is_valid_false_negative_test() { #[test] fn run_mul_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [197, 85, 150, 32]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [51, 109, 78, 142]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [63, 247, 125, 232]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [197, 85, 150, 32]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [51, 109, 78, 142]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [63, 247, 125, 232]; let c: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (result, carry) = run_mul::(&x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { diff --git a/extensions/rv32im/circuit/src/mulh/core.rs b/extensions/rv32im/circuit/src/mulh/core.rs index 16aa8fd550..b5c61c188f 100644 --- a/extensions/rv32im/circuit/src/mulh/core.rs +++ b/extensions/rv32im/circuit/src/mulh/core.rs @@ -3,16 +3,28 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, MinimalInstruction, RecordArena, Result, + StepExecutorE1, StepExecutorE2, TraceFiller, TraceStep, VmAdapterInterface, VmCoreAir, + VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::MulHOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,8 +32,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -40,7 +50,7 @@ pub struct MulHCoreCols { pub opcode_mulhu_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct MulHCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_tuple_bus: RangeTupleCheckerBus<2>, @@ -183,14 +193,23 @@ where } } -pub struct MulHCoreChip { - pub air: MulHCoreAir, +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug)] +pub struct MulHCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + pub local_opcode: u8, +} + +pub struct MulHStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl MulHCoreChip { +impl MulHStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_tuple_chip: SharedRangeTupleCheckerChip<2>, ) -> Self { @@ -209,55 +228,96 @@ impl MulHCoreChip { - pub opcode: MulHOpcode, - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub a_mul: [T; NUM_LIMBS], - pub b_ext: T, - pub c_ext: T, -} - -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for MulHCoreChip +impl TraceStep + for MulHStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, { - type Record = MulHCoreRecord; - type Air = MulHCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = ( + A::RecordMut<'a>, + &'a mut MulHCoreRecord, + ); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + MulHOpcode::from_usize(opcode - MulHOpcode::CLASS_OFFSET) + ) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { let Instruction { opcode, .. } = instruction; - let mulh_opcode = MulHOpcode::from_usize(opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (a, a_mul, carry, b_ext, c_ext) = run_mulh::(mulh_opcode, &b, &c); + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); + + A::start(*state.pc, state.memory, &mut adapter_record); + + core_record.local_opcode = opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET) as u8; + let mulh_opcode = MulHOpcode::from_usize(core_record.local_opcode as usize); + + [core_record.b, core_record.c] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); + + let (a, _, _, _, _) = run_mulh::( + mulh_opcode, + &core_record.b.map(u32::from), + &core_record.c.map(u32::from), + ); + + let a = a.map(|x| x as u8); + self.adapter + .write(state.memory, instruction, [a].into(), &mut adapter_record); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller + for MulHStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + let record: &MulHCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + let core_row: &mut MulHCoreCols = core_row.borrow_mut(); + + let opcode = MulHOpcode::from_usize(record.local_opcode as usize); + let (a, a_mul, carry, b_ext, c_ext) = run_mulh::( + opcode, + &record.b.map(u32::from), + &record.c.map(u32::from), + ); for i in 0..NUM_LIMBS { self.range_tuple_chip.add_count(&[a_mul[i], carry[i]]); @@ -265,55 +325,182 @@ where .add_count(&[a[i], carry[NUM_LIMBS + i]]); } - if mulh_opcode != MulHOpcode::MULHU { + if opcode != MulHOpcode::MULHU { let b_sign_mask = if b_ext == 0 { 0 } else { 1 << (LIMB_BITS - 1) }; let c_sign_mask = if c_ext == 0 { 0 } else { 1 << (LIMB_BITS - 1) }; self.bitwise_lookup_chip.request_range( - (b[NUM_LIMBS - 1] - b_sign_mask) << 1, - (c[NUM_LIMBS - 1] - c_sign_mask) << ((mulh_opcode == MulHOpcode::MULH) as u32), + (record.b[NUM_LIMBS - 1] as u32 - b_sign_mask) << 1, + (record.c[NUM_LIMBS - 1] as u32 - c_sign_mask) + << ((opcode == MulHOpcode::MULH) as u32), ); } - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = MulHCoreRecord { - opcode: mulh_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - a_mul: a_mul.map(F::from_canonical_u32), - b_ext: F::from_canonical_u32(b_ext), - c_ext: F::from_canonical_u32(c_ext), + // Write in reverse order + core_row.opcode_mulhu_flag = F::from_bool(opcode == MulHOpcode::MULHU); + core_row.opcode_mulhsu_flag = F::from_bool(opcode == MulHOpcode::MULHSU); + core_row.opcode_mulh_flag = F::from_bool(opcode == MulHOpcode::MULH); + core_row.c_ext = F::from_canonical_u32(c_ext); + core_row.b_ext = F::from_canonical_u32(b_ext); + core_row.a_mul = a_mul.map(F::from_canonical_u32); + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = a.map(F::from_canonical_u32); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct MulHPreCompute { + a: u8, + b: u8, + c: u8, +} + +impl StepExecutorE1 + for MulHStep +where + F: PrimeField32, +{ + #[inline(always)] + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let pre_compute: &mut MulHPreCompute = data.borrow_mut(); + let local_opcode = self.pre_compute_e1(inst, pre_compute)?; + let fn_ptr = match local_opcode { + MulHOpcode::MULH => execute_e1_impl::<_, _, MulHOp>, + MulHOpcode::MULHSU => execute_e1_impl::<_, _, MulHSuOp>, + MulHOpcode::MULHU => execute_e1_impl::<_, _, MulHUOp>, }; + Ok(fn_ptr) + } +} - Ok((output, record)) +impl StepExecutorE2 + for MulHStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - MulHOpcode::from_usize(opcode - MulHOpcode::CLASS_OFFSET) - ) + fn pre_compute_e2( + &self, + chip_idx: usize, + _pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let pre_compute: &mut E2PreCompute = data.borrow_mut(); + pre_compute.chip_idx = chip_idx as u32; + let local_opcode = self.pre_compute_e1(inst, &mut pre_compute.data)?; + let fn_ptr = match local_opcode { + MulHOpcode::MULH => execute_e2_impl::<_, _, MulHOp>, + MulHOpcode::MULHSU => execute_e2_impl::<_, _, MulHSuOp>, + MulHOpcode::MULHU => execute_e2_impl::<_, _, MulHUOp>, + }; + Ok(fn_ptr) } +} + +#[inline(always)] +unsafe fn execute_e12_impl( + pre_compute: &MulHPreCompute, + vm_state: &mut VmSegmentState, +) { + let rs1: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = + vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let rd = ::compute(rs1, rs2); + vm_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + + vm_state.pc += DEFAULT_PC_STEP; + vm_state.instret += 1; +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &MulHPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, vm_state); +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut MulHCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.a_mul = record.a_mul; - row_slice.b_ext = record.b_ext; - row_slice.c_ext = record.c_ext; - row_slice.opcode_mulh_flag = F::from_bool(record.opcode == MulHOpcode::MULH); - row_slice.opcode_mulhsu_flag = F::from_bool(record.opcode == MulHOpcode::MULHSU); - row_slice.opcode_mulhu_flag = F::from_bool(record.opcode == MulHOpcode::MULHU); +impl MulHStep { + #[inline(always)] + fn pre_compute_e1( + &self, + inst: &Instruction, + data: &mut MulHPreCompute, + ) -> Result { + *data = MulHPreCompute { + a: inst.a.as_canonical_u32() as u8, + b: inst.b.as_canonical_u32() as u8, + c: inst.c.as_canonical_u32() as u8, + }; + Ok(MulHOpcode::from_usize( + inst.opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET), + )) } +} - fn air(&self) -> &Self::Air { - &self.air +trait MulHOperation { + fn compute(rs1: [u8; 4], rs1: [u8; 4]) -> [u8; 4]; +} +struct MulHOp; +struct MulHSuOp; +struct MulHUOp; +impl MulHOperation for MulHOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1 = i32::from_le_bytes(rs1) as i64; + let rs2 = i32::from_le_bytes(rs2) as i64; + ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes() + } +} +impl MulHOperation for MulHSuOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1 = i32::from_le_bytes(rs1) as i64; + let rs2 = u32::from_le_bytes(rs2) as i64; + ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes() + } +} +impl MulHOperation for MulHUOp { + #[inline(always)] + fn compute(rs1: [u8; 4], rs2: [u8; 4]) -> [u8; 4] { + let rs1 = u32::from_le_bytes(rs1) as i64; + let rs2 = u32::from_le_bytes(rs2) as i64; + ((rs1.wrapping_mul(rs2) >> 32) as u32).to_le_bytes() } } // returns mulh[[s]u], mul, carry, x_ext, y_ext +#[inline(always)] pub(super) fn run_mulh( opcode: MulHOpcode, x: &[u32; NUM_LIMBS], diff --git a/extensions/rv32im/circuit/src/mulh/mod.rs b/extensions/rv32im/circuit/src/mulh/mod.rs index 284b77191a..1e7df27cb4 100644 --- a/extensions/rv32im/circuit/src/mulh/mod.rs +++ b/extensions/rv32im/circuit/src/mulh/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep}; mod core; pub use core::*; @@ -8,5 +9,7 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32MulHChip = - VmChipWrapper, MulHCoreChip>; +pub type Rv32MulHAir = + VmAirWrapper>; +pub type Rv32MulHStep = MulHStep; +pub type Rv32MulHChip = NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/mulh/tests.rs b/extensions/rv32im/circuit/src/mulh/tests.rs index 1c7cf5b5cb..43f4504244 100644 --- a/extensions/rv32im/circuit/src/mulh/tests.rs +++ b/extensions/rv32im/circuit/src/mulh/tests.rs @@ -3,10 +3,9 @@ use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ testing::{ - memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, - RANGE_TUPLE_CHECKER_BUS, + memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS, }, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, + InstructionExecutor, VmAirWrapper, }, utils::generate_long_number, }; @@ -15,7 +14,7 @@ use openvm_circuit_primitives::{ range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::MulHOpcode; +use openvm_rv32im_transpiler::MulHOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -24,36 +23,75 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::rngs::StdRng; +use test_case::test_case; use super::core::run_mulh; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - mulh::{MulHCoreChip, MulHCoreCols, Rv32MulHChip}, + adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + mulh::{MulHCoreCols, MulHStep, Rv32MulHChip}, + test_utils::get_verification_error, + MulHCoreAir, }; +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; type F = BabyBear; -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32MulHChip, + SharedBitwiseOperationLookupChip, + SharedRangeTupleCheckerChip<2>, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], + ); + + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); + + let mut chip = Rv32MulHChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + MulHCoreAir::new(bitwise_bus, range_tuple_bus), + ), + MulHStep::new( + Rv32MultAdapterStep::new(), + bitwise_chip.clone(), + range_tuple_checker.clone(), + ), + tester.memory_helper(), + ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip, range_tuple_checker) +} #[allow(clippy::too_many_arguments)] -fn run_rv32_mulh_rand_write_execute>( - opcode: MulHOpcode, +fn set_and_execute>( tester: &mut VmChipTestBuilder, chip: &mut E, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], rng: &mut StdRng, + opcode: MulHOpcode, + b: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, ) { + let b = b.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let c = c.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); let rd = gen_pointer(rng, 4); @@ -61,47 +99,35 @@ fn run_rv32_mulh_rand_write_execute>( tester.write::(1, rs1, b.map(F::from_canonical_u32)); tester.write::(1, rs2, c.map(F::from_canonical_u32)); - let (a, _, _, _, _) = run_mulh::(opcode, &b, &c); tester.execute( chip, &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 0]), ); + let (a, _, _, _, _) = run_mulh::(opcode, &b, &c); assert_eq!( a.map(F::from_canonical_u32), tester.read::(1, rd) ); } +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(MULH, 100)] +#[test_case(MULHSU, 100)] +#[test_case(MULHU, 100)] fn run_rv32_mulh_rand_test(opcode: MulHOpcode, num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MulHChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - MulHCoreChip::new(bitwise_chip.clone(), range_tuple_checker.clone()), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip, range_tuple_checker) = create_test_chip(&mut tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - run_rv32_mulh_rand_write_execute(opcode, &mut tester, &mut chip, b, c, &mut rng); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None); } let tester = tester @@ -113,88 +139,40 @@ fn run_rv32_mulh_rand_test(opcode: MulHOpcode, num_ops: usize) { tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_mulh_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULH, 100); -} - -#[test] -fn rv32_mulhsu_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULHSU, 100); -} - -#[test] -fn rv32_mulhu_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULHU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32MulHTestChip = - VmChipWrapper, MulHCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_mulh_negative_test( +fn run_negative_mulh_test( opcode: MulHOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], b: [u32; RV32_REGISTER_NUM_LIMBS], c: [u32; RV32_REGISTER_NUM_LIMBS], - a_mul: [u32; RV32_REGISTER_NUM_LIMBS], - b_ext: u32, - c_ext: u32, + prank_a_mul: [u32; RV32_REGISTER_NUM_LIMBS], + prank_b_ext: u32, + prank_c_ext: u32, interaction_error: bool, ) { - const MAX_NUM_LIMBS: u32 = 32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MulHTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - MulHCoreChip::new(bitwise_chip.clone(), range_tuple_chip.clone()), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip, range_tuple_chip) = create_test_chip(&mut tester); - tester.execute( - &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 0]), - ); - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, _, carry, _, _) = run_mulh::(opcode, &b, &c); - - range_tuple_chip.clear(); - for i in 0..RV32_REGISTER_NUM_LIMBS { - range_tuple_chip.add_count(&[a_mul[i], carry[i]]); - range_tuple_chip.add_count(&[a[i], carry[RV32_REGISTER_NUM_LIMBS + i]]); - } + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, Some(b), Some(c)); + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut MulHCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - cols.a_mul = a_mul.map(F::from_canonical_u32); - cols.b_ext = F::from_canonical_u32(b_ext); - cols.c_ext = F::from_canonical_u32(c_ext); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + cols.a_mul = prank_a_mul.map(F::from_canonical_u32); + cols.b_ext = F::from_canonical_u32(prank_b_ext); + cols.c_ext = F::from_canonical_u32(prank_c_ext); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -204,17 +182,13 @@ fn run_rv32_mulh_negative_test( .load(bitwise_chip) .load(range_tuple_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_mulh_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [130, 9, 135, 241], [197, 85, 150, 32], [51, 109, 78, 142], @@ -227,8 +201,8 @@ fn rv32_mulh_wrong_a_mul_negative_test() { #[test] fn rv32_mulh_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [130, 9, 135, 242], [197, 85, 150, 32], [51, 109, 78, 142], @@ -241,8 +215,8 @@ fn rv32_mulh_wrong_a_negative_test() { #[test] fn rv32_mulh_wrong_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [1, 0, 0, 0], [0, 0, 0, 128], [2, 0, 0, 0], @@ -255,8 +229,8 @@ fn rv32_mulh_wrong_ext_negative_test() { #[test] fn rv32_mulh_invalid_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [3, 2, 2, 2], [0, 0, 0, 128], [2, 0, 0, 0], @@ -269,8 +243,8 @@ fn rv32_mulh_invalid_ext_negative_test() { #[test] fn rv32_mulhsu_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [174, 40, 246, 202], [197, 85, 150, 160], [51, 109, 78, 142], @@ -283,8 +257,8 @@ fn rv32_mulhsu_wrong_a_mul_negative_test() { #[test] fn rv32_mulhsu_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [174, 40, 246, 201], [197, 85, 150, 160], [51, 109, 78, 142], @@ -297,8 +271,8 @@ fn rv32_mulhsu_wrong_a_negative_test() { #[test] fn rv32_mulhsu_wrong_b_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [1, 0, 0, 0], [0, 0, 0, 128], [2, 0, 0, 0], @@ -311,8 +285,8 @@ fn rv32_mulhsu_wrong_b_ext_negative_test() { #[test] fn rv32_mulhsu_wrong_c_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [0, 0, 0, 64], [0, 0, 0, 128], [0, 0, 0, 128], @@ -325,8 +299,8 @@ fn rv32_mulhsu_wrong_c_ext_negative_test() { #[test] fn rv32_mulhu_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [130, 9, 135, 241], [197, 85, 150, 32], [51, 109, 78, 142], @@ -339,8 +313,8 @@ fn rv32_mulhu_wrong_a_mul_negative_test() { #[test] fn rv32_mulhu_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [130, 9, 135, 240], [197, 85, 150, 32], [51, 109, 78, 142], @@ -353,8 +327,8 @@ fn rv32_mulhu_wrong_a_negative_test() { #[test] fn rv32_mulhu_wrong_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [255, 255, 255, 255], [0, 0, 0, 128], [2, 0, 0, 0], @@ -380,7 +354,7 @@ fn run_mulh_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [303, 375, 449, 463]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULH, &x, &y); + run_mulh::(MULH, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -400,7 +374,7 @@ fn run_mulhu_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [107, 93, 18, 0]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHU, &x, &y); + run_mulh::(MULHU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -420,7 +394,7 @@ fn run_mulhsu_pos_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [107, 93, 18, 0]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHSU, &x, &y); + run_mulh::(MULHSU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -440,7 +414,7 @@ fn run_mulhsu_neg_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [212, 292, 326, 379]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 231]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHSU, &x, &y); + run_mulh::(MULHSU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); diff --git a/extensions/rv32im/circuit/src/shift/core.rs b/extensions/rv32im/circuit/src/shift/core.rs index cada97685e..1c9c9f59dc 100644 --- a/extensions/rv32im/circuit/src/shift/core.rs +++ b/extensions/rv32im/circuit/src/shift/core.rs @@ -3,17 +3,30 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + get_record_from_slice, AdapterAirContext, AdapterTraceFiller, AdapterTraceStep, + E2PreCompute, EmptyAdapterCoreLayout, ExecuteFunc, + ExecutionError::InvalidInstruction, + MinimalInstruction, RecordArena, Result, StepExecutorE1, StepExecutorE2, TraceFiller, + TraceStep, VmAdapterInterface, VmCoreAir, VmSegmentState, VmStateMut, + }, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + AlignedBytesBorrow, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_IMM_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::ShiftOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -21,10 +34,10 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; +use crate::adapters::imm_to_bytes; + #[repr(C)] #[derive(AlignedBorrow, Clone, Copy, Debug)] pub struct ShiftCoreCols { @@ -51,7 +64,10 @@ pub struct ShiftCoreCols { pub bit_shift_carry: [T; NUM_LIMBS], } -#[derive(Copy, Clone, Debug)] +/// RV32 shift AIR. +/// Note: when the shift amount from operand is greater than the number of bits, only shift +/// `shift_amount % num_bits` bits. This matches the RV32 specs for SLL/SRL/SRA. +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct ShiftCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_bus: VariableRangeCheckerBus, @@ -238,154 +254,360 @@ where } #[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct ShiftCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - pub b_sign: T, - #[serde(with = "BigArray")] - pub bit_shift_carry: [u32; NUM_LIMBS], - pub bit_shift: usize, - pub limb_shift: usize, - pub opcode: ShiftOpcode, +#[derive(AlignedBytesBorrow, Debug)] +pub struct ShiftCoreRecord { + pub b: [u8; NUM_LIMBS], + pub c: [u8; NUM_LIMBS], + pub local_opcode: u8, } -pub struct ShiftCoreChip { - pub air: ShiftCoreAir, +pub struct ShiftStep { + adapter: A, + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl ShiftCoreChip { +impl ShiftStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker_chip: SharedVariableRangeCheckerChip, offset: usize, ) -> Self { assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2"); Self { - air: ShiftCoreAir { - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - range_bus: range_checker_chip.bus(), - offset, - }, + adapter, + offset, bitwise_lookup_chip, range_checker_chip, } } } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for ShiftCoreChip +impl TraceStep + for ShiftStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, { - type Record = ShiftCoreRecord; - type Air = ShiftCoreAir; + type RecordLayout = EmptyAdapterCoreLayout; + type RecordMut<'a> = ( + A::RecordMut<'a>, + &'a mut ShiftCoreRecord, + ); - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", ShiftOpcode::from_usize(opcode - self.offset)) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { let Instruction { opcode, .. } = instruction; - let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (a, limb_shift, bit_shift) = run_shift::(shift_opcode, &b, &c); + let local_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let bit_shift_carry = array::from_fn(|i| match shift_opcode { - ShiftOpcode::SLL => b[i] >> (LIMB_BITS - bit_shift), - _ => b[i] % (1 << bit_shift), - }); + let (mut adapter_record, core_record) = arena.alloc(EmptyAdapterCoreLayout::new()); - let mut b_sign = 0; - if shift_opcode == ShiftOpcode::SRA { - b_sign = b[NUM_LIMBS - 1] >> (LIMB_BITS - 1); - self.bitwise_lookup_chip - .request_xor(b[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1)); - } + A::start(*state.pc, state.memory, &mut adapter_record); - for i in 0..(NUM_LIMBS / 2) { - self.bitwise_lookup_chip - .request_range(a[i * 2], a[i * 2 + 1]); - } + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, &mut adapter_record) + .into(); - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = ShiftCoreRecord { - opcode: shift_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - bit_shift_carry, - bit_shift, - limb_shift, - b_sign: F::from_canonical_u32(b_sign), - }; + let (output, _, _) = run_shift::(local_opcode, &rs1, &rs2); - Ok((output, record)) - } + core_record.b = rs1; + core_record.c = rs2; + core_record.local_opcode = local_opcode as u8; - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", ShiftOpcode::from_usize(opcode - self.air.offset)) + self.adapter.write( + state.memory, + instruction, + [output].into(), + &mut adapter_record, + ); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } +} + +impl TraceFiller + for ShiftStep +where + F: PrimeField32, + A: 'static + AdapterTraceFiller, +{ + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + self.adapter.fill_trace_row(mem_helper, adapter_row); + + let record: &ShiftCoreRecord = + unsafe { get_record_from_slice(&mut core_row, ()) }; + + let core_row: &mut ShiftCoreCols = core_row.borrow_mut(); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - for carry_val in record.bit_shift_carry { - self.range_checker_chip - .add_count(carry_val, record.bit_shift); + let opcode = ShiftOpcode::from_usize(record.local_opcode as usize); + let (a, limb_shift, bit_shift) = + run_shift::(opcode, &record.b, &record.c); + + for pair in a.chunks_exact(2) { + self.bitwise_lookup_chip + .request_range(pair[0] as u32, pair[1] as u32); } let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); self.range_checker_chip.add_count( - (((record.c[0].as_canonical_u32() as usize) - - record.bit_shift - - record.limb_shift * LIMB_BITS) - >> num_bits_log) as u32, + ((record.c[0] as usize - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u32, LIMB_BITS - num_bits_log as usize, ); - let row_slice: &mut ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.bit_multiplier_left = match record.opcode { - ShiftOpcode::SLL => F::from_canonical_usize(1 << record.bit_shift), - _ => F::ZERO, + core_row.bit_shift_carry = if bit_shift == 0 { + for _ in 0..NUM_LIMBS { + self.range_checker_chip.add_count(0, 0); + } + [F::ZERO; NUM_LIMBS] + } else { + array::from_fn(|i| { + let carry = match opcode { + ShiftOpcode::SLL => record.b[i] >> (LIMB_BITS - bit_shift), + _ => record.b[i] % (1 << bit_shift), + }; + self.range_checker_chip.add_count(carry as u32, bit_shift); + F::from_canonical_u8(carry) + }) }; - row_slice.bit_multiplier_right = match record.opcode { + + core_row.limb_shift_marker = [F::ZERO; NUM_LIMBS]; + core_row.limb_shift_marker[limb_shift] = F::ONE; + core_row.bit_shift_marker = [F::ZERO; LIMB_BITS]; + core_row.bit_shift_marker[bit_shift] = F::ONE; + + core_row.b_sign = F::ZERO; + if opcode == ShiftOpcode::SRA { + core_row.b_sign = F::from_canonical_u8(record.b[NUM_LIMBS - 1] >> (LIMB_BITS - 1)); + self.bitwise_lookup_chip + .request_xor(record.b[NUM_LIMBS - 1] as u32, 1 << (LIMB_BITS - 1)); + } + + core_row.bit_multiplier_right = match opcode { ShiftOpcode::SLL => F::ZERO, - _ => F::from_canonical_usize(1 << record.bit_shift), + _ => F::from_canonical_usize(1 << bit_shift), + }; + core_row.bit_multiplier_left = match opcode { + ShiftOpcode::SLL => F::from_canonical_usize(1 << bit_shift), + _ => F::ZERO, + }; + + core_row.opcode_sra_flag = F::from_bool(opcode == ShiftOpcode::SRA); + core_row.opcode_srl_flag = F::from_bool(opcode == ShiftOpcode::SRL); + core_row.opcode_sll_flag = F::from_bool(opcode == ShiftOpcode::SLL); + + core_row.c = record.c.map(F::from_canonical_u8); + core_row.b = record.b.map(F::from_canonical_u8); + core_row.a = a.map(F::from_canonical_u8); + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct ShiftPreCompute { + c: u32, + a: u8, + b: u8, +} + +impl StepExecutorE1 + for ShiftStep +where + F: PrimeField32, +{ + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let data: &mut ShiftPreCompute = data.borrow_mut(); + let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, data)?; + // `d` is always expected to be RV32_REGISTER_AS. + let fn_ptr = match (is_imm, shift_opcode) { + (true, ShiftOpcode::SLL) => execute_e1_impl::<_, _, true, SllOp>, + (false, ShiftOpcode::SLL) => execute_e1_impl::<_, _, false, SllOp>, + (true, ShiftOpcode::SRL) => execute_e1_impl::<_, _, true, SrlOp>, + (false, ShiftOpcode::SRL) => execute_e1_impl::<_, _, false, SrlOp>, + (true, ShiftOpcode::SRA) => execute_e1_impl::<_, _, true, SraOp>, + (false, ShiftOpcode::SRA) => execute_e1_impl::<_, _, false, SraOp>, }; - row_slice.b_sign = record.b_sign; - row_slice.bit_shift_marker = array::from_fn(|i| F::from_bool(i == record.bit_shift)); - row_slice.limb_shift_marker = array::from_fn(|i| F::from_bool(i == record.limb_shift)); - row_slice.bit_shift_carry = record.bit_shift_carry.map(F::from_canonical_u32); - row_slice.opcode_sll_flag = F::from_bool(record.opcode == ShiftOpcode::SLL); - row_slice.opcode_srl_flag = F::from_bool(record.opcode == ShiftOpcode::SRL); - row_slice.opcode_sra_flag = F::from_bool(record.opcode == ShiftOpcode::SRA); + Ok(fn_ptr) } +} - fn air(&self) -> &Self::Air { - &self.air +impl StepExecutorE2 + for ShiftStep +where + F: PrimeField32, +{ + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + #[inline(always)] + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + let (is_imm, shift_opcode) = self.pre_compute_impl(pc, inst, &mut data.data)?; + // `d` is always expected to be RV32_REGISTER_AS. + let fn_ptr = match (is_imm, shift_opcode) { + (true, ShiftOpcode::SLL) => execute_e2_impl::<_, _, true, SllOp>, + (false, ShiftOpcode::SLL) => execute_e2_impl::<_, _, false, SllOp>, + (true, ShiftOpcode::SRL) => execute_e2_impl::<_, _, true, SrlOp>, + (false, ShiftOpcode::SRL) => execute_e2_impl::<_, _, false, SrlOp>, + (true, ShiftOpcode::SRA) => execute_e2_impl::<_, _, true, SraOp>, + (false, ShiftOpcode::SRA) => execute_e2_impl::<_, _, false, SraOp>, + }; + Ok(fn_ptr) } } +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + const IS_IMM: bool, + OP: ShiftOp, +>( + pre_compute: &ShiftPreCompute, + state: &mut VmSegmentState, +) { + let rs1 = state.vm_read::(RV32_REGISTER_AS, pre_compute.b as u32); + let rs2 = if IS_IMM { + pre_compute.c.to_le_bytes() + } else { + state.vm_read::(RV32_REGISTER_AS, pre_compute.c) + }; + let rs2 = u32::from_le_bytes(rs2); + + // Execute the shift operation + let rd = ::compute(rs1, rs2); + // Write the result back to memory + state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &rd); + + state.instret += 1; + state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + state: &mut VmSegmentState, +) { + let pre_compute: &ShiftPreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, state); +} + +unsafe fn execute_e2_impl( + pre_compute: &[u8], + state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + state.ctx.on_height_change(pre_compute.chip_idx as usize, 1); + execute_e12_impl::(&pre_compute.data, state); +} + +impl ShiftStep { + #[inline(always)] + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut ShiftPreCompute, + ) -> Result<(bool, ShiftOpcode)> { + let Instruction { + opcode, a, b, c, e, .. + } = inst; + let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let e_u32 = e.as_canonical_u32(); + if inst.d.as_canonical_u32() != RV32_REGISTER_AS + || !(e_u32 == RV32_IMM_AS || e_u32 == RV32_REGISTER_AS) + { + return Err(InvalidInstruction(pc)); + } + let is_imm = e_u32 == RV32_IMM_AS; + let c_u32 = c.as_canonical_u32(); + *data = ShiftPreCompute { + c: if is_imm { + u32::from_le_bytes(imm_to_bytes(c_u32)) + } else { + c_u32 + }, + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + }; + // `d` is always expected to be RV32_REGISTER_AS. + Ok((is_imm, shift_opcode)) + } +} + +trait ShiftOp { + fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4]; +} +struct SllOp; +struct SrlOp; +struct SraOp; +impl ShiftOp for SllOp { + fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] { + let rs1 = u32::from_le_bytes(rs1); + // `rs2`'s other bits are ignored. + (rs1 << (rs2 & 0x1F)).to_le_bytes() + } +} +impl ShiftOp for SrlOp { + fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] { + let rs1 = u32::from_le_bytes(rs1); + // `rs2`'s other bits are ignored. + (rs1 >> (rs2 & 0x1F)).to_le_bytes() + } +} +impl ShiftOp for SraOp { + fn compute(rs1: [u8; 4], rs2: u32) -> [u8; 4] { + let rs1 = i32::from_le_bytes(rs1); + // `rs2`'s other bits are ignored. + (rs1 >> (rs2 & 0x1F)).to_le_bytes() + } +} + +// Returns (result, limb_shift, bit_shift) +#[inline(always)] pub(super) fn run_shift( opcode: ShiftOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], usize, usize) { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], usize, usize) { match opcode { ShiftOpcode::SLL => run_shift_left::(x, y), ShiftOpcode::SRL => run_shift_right::(x, y, true), @@ -393,53 +615,60 @@ pub(super) fn run_shift( } } +#[inline(always)] fn run_shift_left( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], usize, usize) { - let mut result = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], usize, usize) { + let mut result = [0u8; NUM_LIMBS]; let (limb_shift, bit_shift) = get_shift::(y); for i in limb_shift..NUM_LIMBS { result[i] = if i > limb_shift { - ((x[i - limb_shift] << bit_shift) + (x[i - limb_shift - 1] >> (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) + (((x[i - limb_shift] as u16) << bit_shift) + | ((x[i - limb_shift - 1] as u16) >> (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) } else { - (x[i - limb_shift] << bit_shift) % (1 << LIMB_BITS) - }; + ((x[i - limb_shift] as u16) << bit_shift) % (1u16 << LIMB_BITS) + } as u8; } (result, limb_shift, bit_shift) } +#[inline(always)] fn run_shift_right( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], logical: bool, -) -> ([u32; NUM_LIMBS], usize, usize) { +) -> ([u8; NUM_LIMBS], usize, usize) { let fill = if logical { 0 } else { - ((1 << LIMB_BITS) - 1) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) + (((1u16 << LIMB_BITS) - 1) as u8) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) }; let mut result = [fill; NUM_LIMBS]; let (limb_shift, bit_shift) = get_shift::(y); for i in 0..(NUM_LIMBS - limb_shift) { - result[i] = if i + limb_shift + 1 < NUM_LIMBS { - ((x[i + limb_shift] >> bit_shift) + (x[i + limb_shift + 1] << (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) + let res = if i + limb_shift + 1 < NUM_LIMBS { + (((x[i + limb_shift] >> bit_shift) as u16) + | ((x[i + limb_shift + 1] as u16) << (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) } else { - ((x[i + limb_shift] >> bit_shift) + (fill << (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) - } + (((x[i + limb_shift] >> bit_shift) as u16) | ((fill as u16) << (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) + }; + result[i] = res as u8; } (result, limb_shift, bit_shift) } -fn get_shift(y: &[u32]) -> (usize, usize) { - // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS` so so the shift is defined +#[inline(always)] +fn get_shift(y: &[u8]) -> (usize, usize) { + debug_assert!(NUM_LIMBS * LIMB_BITS <= (1 << LIMB_BITS)); + // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS` so the shift is defined // entirely in y[0]. let shift = (y[0] as usize) % (NUM_LIMBS * LIMB_BITS); (shift / LIMB_BITS, shift % LIMB_BITS) diff --git a/extensions/rv32im/circuit/src/shift/mod.rs b/extensions/rv32im/circuit/src/shift/mod.rs index 58d5ad022b..daf3c8b73b 100644 --- a/extensions/rv32im/circuit/src/shift/mod.rs +++ b/extensions/rv32im/circuit/src/shift/mod.rs @@ -1,6 +1,8 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{MatrixRecordArena, NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -8,8 +10,8 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32ShiftChip = VmChipWrapper< - F, - Rv32BaseAluAdapterChip, - ShiftCoreChip, ->; +pub type Rv32ShiftAir = + VmAirWrapper>; +pub type Rv32ShiftStep = + ShiftStep, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>; +pub type Rv32ShiftChip = NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/shift/tests.rs b/extensions/rv32im/circuit/src/shift/tests.rs index 7a3ef6e72c..3ca5818d61 100644 --- a/extensions/rv32im/circuit/src/shift/tests.rs +++ b/extensions/rv32im/circuit/src/shift/tests.rs @@ -1,17 +1,14 @@ use std::{array, borrow::BorrowMut}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, - }, - utils::generate_long_number, +use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + InstructionExecutor, VmAirWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::ShiftOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::ShiftOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -20,108 +17,129 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_shift, Rv32ShiftChip, ShiftCoreChip}; +use super::{core::run_shift, Rv32ShiftChip, ShiftCoreAir, ShiftCoreCols, ShiftStep}; use crate::{ - adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - shift::ShiftCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// - -fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32ShiftChip, + SharedBitwiseOperationLookupChip, +) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32ShiftChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + ShiftCoreAir::new( + bitwise_bus, + tester.range_checker().bus(), + ShiftOpcode::CLASS_OFFSET, + ), ), - ShiftCoreChip::new( + ShiftStep::new( + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), bitwise_chip.clone(), - tester.memory_controller().borrow().range_checker.clone(), + tester.range_checker().clone(), ShiftOpcode::CLASS_OFFSET, ), - tester.offline_memory_mutex_arc(), + tester.memory_helper(), ); + chip.set_trace_buffer_height(MAX_INS_CAPACITY); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) + (chip, bitwise_chip) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: ShiftOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) + generate_rv32_is_type_immediate(rng) }; - - let (instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let (a, _, _) = run_shift::(opcode, &b, &c); - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), ) - } + }; + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); + let (a, _, _) = run_shift::(opcode, &b, &c); + assert_eq!( + a.map(F::from_canonical_u8), + tester.read::(1, rd) + ) } -#[test] -fn rv32_shift_sll_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SLL, 100); -} +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(SLL, 100)] +#[test_case(SRL, 100)] +#[test_case(SRA, 100)] +fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, bitwise_chip) = create_test_chip(&tester); -#[test] -fn rv32_shift_srl_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SRL, 100); -} + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); + } -#[test] -fn rv32_shift_sra_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SRA, 100); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32ShiftTestChip = - VmChipWrapper, ShiftCoreChip>; - #[derive(Clone, Copy, Default, PartialEq)] struct ShiftPrankValues { pub bit_shift: Option, @@ -134,63 +152,35 @@ struct ShiftPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_shift_negative_test( +fn run_negative_shift_test( opcode: ShiftOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], prank_vals: ShiftPrankValues, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = Rv32ShiftTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - ShiftCoreChip::new( - bitwise_chip.clone(), - range_checker_chip.clone(), - ShiftOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + &mut rng, + opcode, + Some(b), + Some(false), + Some(c), ); - let bit_shift = prank_vals - .bit_shift - .unwrap_or(c[0] % (RV32_CELL_BITS as u32)); - let bit_shift_carry = prank_vals - .bit_shift_carry - .unwrap_or(array::from_fn(|i| match opcode { - ShiftOpcode::SLL => b[i] >> ((RV32_CELL_BITS as u32) - bit_shift), - _ => b[i] % (1 << bit_shift), - })); - - range_checker_chip.clear(); - range_checker_chip.add_count(bit_shift, RV32_CELL_BITS.ilog2() as usize); - for (a_val, carry_val) in a.iter().zip(bit_shift_carry.iter()) { - range_checker_chip.add_count(*a_val, RV32_CELL_BITS); - range_checker_chip.add_count(*carry_val, bit_shift as usize); - } - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut ShiftCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); + cols.a = prank_a.map(F::from_canonical_u32); if let Some(bit_multiplier_left) = prank_vals.bit_multiplier_left { cols.bit_multiplier_left = F::from_canonical_u32(bit_multiplier_left); } @@ -210,21 +200,16 @@ fn run_rv32_shift_negative_test( cols.bit_shift_carry = bit_shift_carry.map(F::from_canonical_u32); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -233,9 +218,9 @@ fn rv32_shift_wrong_negative_test() { let b = [1, 0, 0, 0]; let c = [1, 0, 0, 0]; let prank_vals = Default::default(); - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SLL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -249,7 +234,7 @@ fn rv32_sll_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -261,7 +246,7 @@ fn rv32_sll_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 0, 1, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -273,7 +258,7 @@ fn rv32_sll_wrong_bit_carry_negative_test() { bit_shift_carry: Some([0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -286,7 +271,7 @@ fn rv32_sll_wrong_bit_mult_side_negative_test() { bit_multiplier_right: Some(1), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, false); + run_negative_shift_test(SLL, a, b, c, prank_vals, false); } #[test] @@ -300,7 +285,7 @@ fn rv32_srl_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); } #[test] @@ -312,7 +297,7 @@ fn rv32_srl_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 1, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); } #[test] @@ -325,8 +310,8 @@ fn rv32_srx_wrong_bit_mult_side_negative_test() { bit_multiplier_right: Some(0), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -340,7 +325,7 @@ fn rv32_sra_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -352,7 +337,7 @@ fn rv32_sra_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 1, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -364,7 +349,7 @@ fn rv32_sra_wrong_sign_negative_test() { b_sign: Some(0), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, true); + run_negative_shift_test(SRA, a, b, c, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -375,11 +360,11 @@ fn rv32_sra_wrong_sign_negative_test() { #[test] fn run_sll_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 7, 61, 186]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [91, 0, 100, 0]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [0, 0, 0, 104]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 7, 61, 186]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [91, 0, 100, 0]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [0, 0, 0, 104]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SLL, &x, &y); + run_shift::(SLL, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -390,11 +375,11 @@ fn run_sll_sanity_test() { #[test] fn run_srl_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [49, 190, 190, 190]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [110, 100, 0, 0]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [49, 190, 190, 190]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [110, 100, 0, 0]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SRL, &x, &y); + run_shift::(SRL, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -405,11 +390,11 @@ fn run_srl_sanity_test() { #[test] fn run_sra_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [113, 20, 50, 80]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [110, 228, 255, 255]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [113, 20, 50, 80]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [110, 228, 255, 255]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SRA, &x, &y); + run_shift::(SRA, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } diff --git a/extensions/rv32im/circuit/src/test_utils.rs b/extensions/rv32im/circuit/src/test_utils.rs index 8a105ff990..f018b0d845 100644 --- a/extensions/rv32im/circuit/src/test_utils.rs +++ b/extensions/rv32im/circuit/src/test_utils.rs @@ -1,6 +1,6 @@ use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, VmOpcode}; -use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_backend::{p3_field::FieldAlgebra, verifier::VerificationError}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use rand::{rngs::StdRng, Rng}; @@ -10,8 +10,8 @@ use super::adapters::{RV32_REGISTER_NUM_LIMBS, RV_IS_TYPE_IMM_BITS}; #[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] pub fn rv32_rand_write_register_or_imm( tester: &mut VmChipTestBuilder, - rs1_writes: [u32; NUM_LIMBS], - rs2_writes: [u32; NUM_LIMBS], + rs1_writes: [u8; NUM_LIMBS], + rs2_writes: [u8; NUM_LIMBS], imm: Option, opcode_with_offset: usize, rng: &mut StdRng, @@ -22,9 +22,9 @@ pub fn rv32_rand_write_register_or_imm( let rs2 = imm.unwrap_or_else(|| gen_pointer(rng, NUM_LIMBS)); let rd = gen_pointer(rng, NUM_LIMBS); - tester.write::(1, rs1, rs1_writes.map(BabyBear::from_canonical_u32)); + tester.write::(1, rs1, rs1_writes.map(BabyBear::from_canonical_u8)); if !rs2_is_imm { - tester.write::(1, rs2, rs2_writes.map(BabyBear::from_canonical_u32)); + tester.write::(1, rs2, rs2_writes.map(BabyBear::from_canonical_u8)); } ( @@ -37,9 +37,7 @@ pub fn rv32_rand_write_register_or_imm( } #[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] -pub fn generate_rv32_is_type_immediate( - rng: &mut StdRng, -) -> (usize, [u32; RV32_REGISTER_NUM_LIMBS]) { +pub fn generate_rv32_is_type_immediate(rng: &mut StdRng) -> (usize, [u8; RV32_REGISTER_NUM_LIMBS]) { let mut imm: u32 = rng.gen_range(0..(1 << RV_IS_TYPE_IMM_BITS)); if (imm & 0x800) != 0 { imm |= !0xFFF @@ -51,7 +49,17 @@ pub fn generate_rv32_is_type_immediate( (imm >> 8) as u8, (imm >> 16) as u8, (imm >> 16) as u8, - ] - .map(|x| x as u32), + ], ) } + +/// Returns the corresponding verification error based on whether +/// an interaction error or a constraint error is expected +#[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] +pub fn get_verification_error(is_interaction_error: bool) -> VerificationError { + if is_interaction_error { + VerificationError::ChallengePhaseError + } else { + VerificationError::OodEvaluationMismatch + } +} diff --git a/extensions/rv32im/tests/src/lib.rs b/extensions/rv32im/tests/src/lib.rs index a4de516462..38ecc9f345 100644 --- a/extensions/rv32im/tests/src/lib.rs +++ b/extensions/rv32im/tests/src/lib.rs @@ -4,9 +4,13 @@ mod tests { use eyre::Result; use openvm_circuit::{ - arch::{hasher::poseidon2::vm_poseidon2_hasher, ExecutionError, Streams, VmExecutor}, - system::memory::tree::public_values::UserPublicValuesProof, - utils::{air_test, air_test_with_min_segments}, + arch::{ + execution_mode::e1::E1Ctx, hasher::poseidon2::vm_poseidon2_hasher, + interpreter::InterpretedInstance, ExecutionError, Streams, VirtualMachine, VmConfig, + VmExecutor, + }, + system::memory::merkle::public_values::UserPublicValuesProof, + utils::{air_test, air_test_with_min_segments, test_system_config_with_continuations}, }; use openvm_instructions::exe::VmExe; use openvm_rv32im_circuit::{Rv32IConfig, Rv32ImConfig}; @@ -14,7 +18,10 @@ mod tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_stark_sdk::{openvm_stark_backend::p3_field::FieldAlgebra, p3_baby_bear::BabyBear}; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::default_engine, openvm_stark_backend::p3_field::FieldAlgebra, + p3_baby_bear::BabyBear, + }; use openvm_toolchain_tests::{ build_example_program_at_path, build_example_program_at_path_with_features, get_programs_dir, @@ -24,6 +31,17 @@ mod tests { type F = BabyBear; + #[cfg(test)] + fn test_rv32im_config() -> Rv32ImConfig { + Rv32ImConfig { + rv32i: Rv32IConfig { + system: test_system_config_with_continuations(), + ..Default::default() + }, + ..Default::default() + } + } + #[test_case("fibonacci", 1)] fn test_rv32i(example_name: &str, min_segments: usize) -> Result<()> { let config = Rv32IConfig::default(); @@ -35,13 +53,14 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; + let config = test_rv32im_config(); air_test_with_min_segments(config, exe, vec![], min_segments); Ok(()) } #[test_case("collatz", 1)] fn test_rv32im(example_name: &str, min_segments: usize) -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), example_name, &config)?; let exe = VmExe::from_elf( elf, @@ -54,10 +73,10 @@ mod tests { Ok(()) } - // #[test_case("fibonacci", 1)] + #[test_case("fibonacci", 1)] #[test_case("collatz", 1)] fn test_rv32im_std(example_name: &str, min_segments: usize) -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path_with_features( get_programs_dir!(), example_name, @@ -77,7 +96,7 @@ mod tests { #[test] fn test_read_vec() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "hint", &config)?; let exe = VmExe::from_elf( elf, @@ -93,7 +112,7 @@ mod tests { #[test] fn test_hint_load_by_key() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "hint_load_by_key", &config)?; let exe = VmExe::from_elf( elf, @@ -116,7 +135,7 @@ mod tests { #[test] fn test_read() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "read", &config)?; let exe = VmExe::from_elf( elf, @@ -147,7 +166,7 @@ mod tests { #[test] fn test_reveal() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "reveal", &config)?; let exe = VmExe::from_elf( elf, @@ -156,11 +175,21 @@ mod tests { .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension), )?; - let executor = VmExecutor::::new(config.clone()); - let final_memory = executor.execute(exe, vec![])?.unwrap(); - let hasher = vm_poseidon2_hasher(); + + let vm = VirtualMachine::new(default_engine(), config.clone()); + let pk = vm.keygen(); + let vk = pk.get_vk(); + let segments = vm + .executor + .execute_metered(exe.clone(), vec![], &vk.num_interactions()) + .unwrap(); + + let final_memory = vm.executor.execute(exe, vec![], &segments)?.unwrap(); + let hasher = vm_poseidon2_hasher::(); let pv_proof = UserPublicValuesProof::compute( - config.system.memory_config.memory_dimensions(), + VmConfig::::system(&config) + .memory_config + .memory_dimensions(), 64, &hasher, &final_memory, @@ -186,7 +215,7 @@ mod tests { #[test] fn test_print() -> Result<()> { - let config = Rv32IConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "print", &config)?; let exe = VmExe::from_elf( elf, @@ -201,7 +230,7 @@ mod tests { #[test] fn test_heap_overflow() -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "heap_overflow", &config)?; let exe = VmExe::from_elf( elf, @@ -211,8 +240,9 @@ mod tests { .with_extension(Rv32IoTranspilerExtension), )?; - let executor = VmExecutor::::new(config.clone()); - match executor.execute(exe, vec![[0, 0, 0, 1].map(F::from_canonical_u8).to_vec()]) { + let executor = VmExecutor::new(config); + let input = vec![[0, 0, 0, 1].map(F::from_canonical_u8).to_vec()]; + match executor.execute_e1(exe.clone(), input.clone(), None) { Err(ExecutionError::FailedWithExitCode(_)) => Ok(()), Err(_) => panic!("should fail with `FailedWithExitCode`"), Ok(_) => panic!("should fail"), @@ -221,7 +251,7 @@ mod tests { #[test] fn test_hashmap() -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "hashmap", @@ -241,7 +271,7 @@ mod tests { #[test] fn test_tiny_mem_test() -> Result<()> { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path_with_features( get_programs_dir!(), "tiny-mem-test", @@ -262,7 +292,7 @@ mod tests { #[test] #[should_panic] fn test_load_x0() { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path(get_programs_dir!(), "load_x0", &config).unwrap(); let exe = VmExe::from_elf( elf, @@ -272,8 +302,8 @@ mod tests { .with_extension(Rv32IoTranspilerExtension), ) .unwrap(); - let executor = VmExecutor::::new(config.clone()); - executor.execute(exe, vec![]).unwrap(); + let interpreter = InterpretedInstance::new(config, exe); + interpreter.execute(E1Ctx::new(None), vec![]).unwrap(); } #[test_case("getrandom", vec!["getrandom", "getrandom-unsupported"])] @@ -281,7 +311,7 @@ mod tests { #[test_case("getrandom_v02", vec!["getrandom-v02", "getrandom-unsupported"])] #[test_case("getrandom_v02", vec!["getrandom-v02/custom"])] fn test_getrandom_unsupported(program: &str, features: Vec<&str>) { - let config = Rv32ImConfig::default(); + let config = test_rv32im_config(); let elf = build_example_program_at_path_with_features( get_programs_dir!(), program, diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha2/circuit/Cargo.toml similarity index 80% rename from extensions/sha256/circuit/Cargo.toml rename to extensions/sha2/circuit/Cargo.toml index 95c87b0871..213965c0cb 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha2/circuit/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "openvm-sha256-circuit" +name = "openvm-sha2-circuit" version.workspace = true authors.workspace = true edition.workspace = true -description = "OpenVM circuit extension for sha256" +description = "OpenVM circuit extension for SHA-2" [dependencies] openvm-stark-backend = { workspace = true } @@ -13,16 +13,16 @@ openvm-circuit-primitives-derive = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-circuit = { workspace = true } openvm-instructions = { workspace = true } -openvm-sha256-transpiler = { workspace = true } +openvm-sha2-transpiler = { workspace = true } openvm-rv32im-circuit = { workspace = true } -openvm-sha256-air = { workspace = true } +openvm-sha2-air = { workspace = true } derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true serde.workspace = true sha2 = { version = "0.10", default-features = false } -strum = { workspace = true } +ndarray = { workspace = true, default-features = false } [dev-dependencies] openvm-stark-sdk = { workspace = true } @@ -37,3 +37,6 @@ mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] jemalloc-prof = ["openvm-circuit/jemalloc-prof"] nightly-features = ["openvm-circuit/nightly-features"] + +[package.metadata.cargo-shear] +ignored = ["ndarray"] \ No newline at end of file diff --git a/extensions/sha256/circuit/README.md b/extensions/sha2/circuit/README.md similarity index 56% rename from extensions/sha256/circuit/README.md rename to extensions/sha2/circuit/README.md index 1e794cd35c..de2100b261 100644 --- a/extensions/sha256/circuit/README.md +++ b/extensions/sha2/circuit/README.md @@ -1,28 +1,43 @@ -# SHA256 VM Extension +# SHA-2 VM Extension -This crate contains the circuit for the SHA256 VM extension. +This crate contains circuits for the SHA-2 family of hash functions. +We support SHA-256, SHA-512, and SHA-384. -## SHA-256 Algorithm Summary +## SHA-2 Algorithms Summary -See the [FIPS standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf), in particular, section 6.2 for reference. +The SHA-256, SHA-512, and SHA-384 algorithms are similar in structure. +We will first describe the SHA-256 algorithm, and then describe the differences between the three algorithms. + +See the [FIPS standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for reference. In particular, sections 6.2, 6.4, and 6.5. In short the SHA-256 algorithm works as follows. 1. Pad the message to 512 bits and split it into 512-bit 'blocks'. -2. Initialize a hash state consisting of eight 32-bit words. +2. Initialize a hash state consisting of eight 32-bit words to a specific constant value. 3. For each block, - 1. split the message into 16 32-bit words and produce 48 more 'message schedule' words based on them. - 2. apply 64 'rounds' to update the hash state based on the message schedule. - 3. add the previous block's final hash state to the current hash state (modulo `2^32`). + 1. split the message into 16 32-bit words and produce 48 more words based on them. The 16 message words together with the 48 additional words are called the 'message schedule'. + 2. apply a scrambling function 64 times to the hash state to update it based on the message schedule. We call each update a 'round'. + 3. add the previous block's final hash state to the current hash state (modulo $2^{32}$). 4. The output is the final hash state +The differences with the SHA-512 algorithm are that: +- SHA-512 uses 64-bit words, 1024-bit blocks, performs 80 rounds, and produces a 512-bit output. +- all the arithmetic is done modulo $2^{64}$. +- the initial hash state is different. + +The SHA-384 algorithm is a truncation of the SHA-512 output to 384 bits, and the only difference is that the initial hash state is different. + ## Design Overview -This chip produces an AIR that consists of 17 rows for each block (512 bits) in the message, and no more rows. -The first 16 rows of each block are called 'round rows', and each of them represents four rounds of the SHA-256 algorithm. -Each row constrains updates to the working variables on each round, and it also constrains the message schedule words based on previous rounds. -The final row is called a 'digest row' and it produces a final hash for the block, computed as the sum of the working variables and the previous block's final hash. +We reuse the same AIR code to produce circuits for all three algorithms. +To achieve this, we parameterize the AIR by constants (such as the word size, number of rounds, and block size) that are specific to each algorithm. + +This chip produces an AIR that consists of $R+1$ rows for each block of the message, and no more rows +(for SHA-256, $R = 16$ and for SHA-512 and SHA-384, $R = 20$). +The first $R$ rows of each block are called 'round rows', and each of them constrains four rounds of the hash algorithm. +Each row constrains updates to the working variables on each round, and also constrains the message schedule words based on previous rounds. +The final row of each block is called a 'digest row' and it produces a final hash for the block, computed as the sum of the working variables and the previous block's final hash. -Note that this chip only supports messages of length less than `2^29` bytes. +Note that this chip only supports messages of length less than $2^{29}$ bytes. ### Storing working variables @@ -50,7 +65,7 @@ Since we can reliably constrain values from four rounds ago, we can build up `in The last block of every message should have the `is_last_block` flag set to `1`. Note that `is_last_block` is not constrained to be true for the last block of every message, instead it *defines* what the last block of a message is. -For instance, if we produce an air with 10 blocks and only the last block has `is_last_block = 1` then the constraints will interpret it as a single message of length 10 blocks. +For instance, if we produce a trace with 10 blocks and only the last block has `is_last_block = 1` then the constraints will interpret it as a single message of length 10 blocks. If, however, we set `is_last_block` to true for the 6th block, the trace will be interpreted as hashing two messages, each of length 5 blocks. Note that we do constrain, however, that the very last block of the trace has `is_last_block = 1`. @@ -63,11 +78,11 @@ We use this trick in several places in this chip. ### Block index counter variables -There are two "block index" counter variables in each row of the air named `global_block_idx` and `local_block_idx`. -Both of these variables take on the same value on all 17 rows in a block. +There are two "block index" counter variables in each row named `global_block_idx` and `local_block_idx`. +Both of these variables take on the same value on all $R+1$ rows in a block. The `global_block_idx` is the index of the block in the entire trace. -The very first 17 rows in the trace will have `global_block_idx = 1` and the counter will increment by 1 between blocks. +The very first block in the trace will have `global_block_idx = 1` on each row and the counter will increment by 1 between blocks. The padding rows will all have `global_block_idx = 0`. The `global_block_idx` is used in interaction constraints to constrain the value of `hash` between blocks. @@ -79,15 +94,16 @@ The `local_block_idx` is used to calculate the length of the message processed s ### VM air vs SubAir -The SHA-256 VM extension chip uses the `Sha256Air` SubAir to help constrain the SHA-256 hash. -The VM extension air constrains the correctness of the SHA message padding, while the SubAir adds all other constraints related to the hash algorithm. -The VM extension air also constrains memory reads and writes. +The SHA-2 VM extension chip uses the `Sha2Air` SubAir to help constrain the appropriate SHA-2 hash algorithm. +The SubAir is also parameterized by the specific SHA-2 variant's constants. +The VM extension AIR constrains the correctness of the message padding, while the SubAir adds all other constraints related to the hash algorithm. +The VM extension AIR also constrains memory reads and writes. ### A gotcha about padding rows There are two senses of the word padding used in the context of this chip and this can be confusing. -First, we use padding to refer to the extra bits added to the message that is input to the SHA-256 algorithm in order to make the input's length a multiple of 512 bits. -So, we may use the term 'padding rows' to refer to round rows that correspond to the padded bits of a message (as in `Sha256VmAir::eval_padding_row`). +First, we use padding to refer to the extra bits added to the message that is input to the hash algorithm in order to make the input's length a multiple of the block size. +So, we may use the term 'padding rows' to refer to round rows that correspond to the padded bits of a message (as in `Sha2VmAir::eval_padding_row`). Second, the dummy rows that are added to the trace to make the trace height a power of 2 are also called padding rows (see the `is_padding_row` flag). In the SubAir, padding row probably means dummy row. -In the VM air, it probably refers to SHA-256 padding. \ No newline at end of file +In the VM air, it probably refers to the message padding. \ No newline at end of file diff --git a/extensions/sha2/circuit/src/extension.rs b/extensions/sha2/circuit/src/extension.rs new file mode 100644 index 0000000000..b05e4412e0 --- /dev/null +++ b/extensions/sha2/circuit/src/extension.rs @@ -0,0 +1,148 @@ +use derive_more::derive::From; +use openvm_circuit::{ + arch::{ + InitFileGenerator, SystemConfig, VmExtension, VmInventory, VmInventoryBuilder, + VmInventoryError, + }, + system::phantom::PhantomChip, +}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InsExecutorE2, InstructionExecutor, VmConfig}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +}; +use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::*; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, + Rv32MExecutor, Rv32MPeriphery, +}; +use openvm_sha2_air::{Sha256Config, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; +use openvm_stark_backend::p3_field::PrimeField32; +use serde::{Deserialize, Serialize}; + +use crate::*; + +// TODO: this should be decided after e2 execution + +#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] +pub struct Sha2Rv32Config { + #[system] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub sha2: Sha2, +} + +impl Default for Sha2Rv32Config { + fn default() -> Self { + Self { + system: SystemConfig::default().with_continuations(), + rv32i: Rv32I, + rv32m: Rv32M::default(), + io: Rv32Io, + sha2: Sha2, + } + } +} + +// Default implementation uses no init file +impl InitFileGenerator for Sha2Rv32Config {} + +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub struct Sha2; + +#[derive( + ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, InsExecutorE1, InsExecutorE2, +)] +pub enum Sha2Executor { + Sha256(Sha2VmChip), + Sha512(Sha2VmChip), + Sha384(Sha2VmChip), +} + +#[derive(From, ChipUsageGetter, Chip, AnyEnum)] +pub enum Sha2Periphery { + BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), + Phantom(PhantomChip), +} + +impl VmExtension for Sha2 { + type Executor = Sha2Executor; + type Periphery = Sha2Periphery; + + fn build( + &self, + builder: &mut VmInventoryBuilder, + ) -> Result, VmInventoryError> { + let mut inventory = VmInventory::new(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + + let bitwise_lu_chip = if let Some(&chip) = builder + .find_chip::>() + .first() + { + chip.clone() + } else { + let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); + let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); + inventory.add_periphery_chip(chip.clone()); + chip + }; + + let sha256_chip = Sha2VmChip::::new( + Sha2VmAir::new( + builder.system_port(), + bitwise_lu_chip.bus(), + pointer_max_bits, + builder.new_bus_idx(), + ), + Sha2VmStep::new( + bitwise_lu_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + pointer_max_bits, + ), + builder.system_base().memory_controller.helper(), + ); + inventory.add_executor(sha256_chip, vec![Rv32Sha2Opcode::SHA256.global_opcode()])?; + + let sha512_chip = Sha2VmChip::::new( + Sha2VmAir::new( + builder.system_port(), + bitwise_lu_chip.bus(), + pointer_max_bits, + builder.new_bus_idx(), + ), + Sha2VmStep::new( + bitwise_lu_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + pointer_max_bits, + ), + builder.system_base().memory_controller.helper(), + ); + inventory.add_executor(sha512_chip, vec![Rv32Sha2Opcode::SHA512.global_opcode()])?; + + let sha384_chip = Sha2VmChip::::new( + Sha2VmAir::new( + builder.system_port(), + bitwise_lu_chip.bus(), + pointer_max_bits, + builder.new_bus_idx(), + ), + Sha2VmStep::new( + bitwise_lu_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + pointer_max_bits, + ), + builder.system_base().memory_controller.helper(), + ); + inventory.add_executor(sha384_chip, vec![Rv32Sha2Opcode::SHA384.global_opcode()])?; + + Ok(inventory) + } +} diff --git a/extensions/sha2/circuit/src/lib.rs b/extensions/sha2/circuit/src/lib.rs new file mode 100644 index 0000000000..cc51aaaf20 --- /dev/null +++ b/extensions/sha2/circuit/src/lib.rs @@ -0,0 +1,5 @@ +mod sha2_chip; +pub use sha2_chip::*; + +mod extension; +pub use extension::*; diff --git a/extensions/sha2/circuit/src/sha2_chip/air.rs b/extensions/sha2/circuit/src/sha2_chip/air.rs new file mode 100644 index 0000000000..600d483e63 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/air.rs @@ -0,0 +1,777 @@ +use std::{cmp::min, convert::TryInto}; + +use openvm_circuit::{ + arch::{ExecutionBridge, SystemPort}, + system::memory::{ + offline_checker::{MemoryBridge, MemoryWriteAuxCols}, + MemoryAddress, + }, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir, +}; +use openvm_instructions::{ + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_sha2_air::{compose, Sha256Config, Sha2Air, Sha2Variant, Sha512Config}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use super::{Sha2ChipConfig, Sha2VmDigestColsRef, Sha2VmRoundColsRef}; + +/// Sha2VmAir does all constraints related to message padding and +/// the Sha2Air subair constrains the actual hash +#[derive(Clone, Debug)] +pub struct Sha2VmAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, + /// Bus to send byte checks to + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + /// Maximum number of bits allowed for an address pointer + /// Must be at least 24 + pub ptr_max_bits: usize, + pub(super) sha_subair: Sha2Air, + pub(super) padding_encoder: Encoder, +} + +impl Sha2VmAir { + pub fn new( + SystemPort { + execution_bus, + program_bus, + memory_bridge, + }: SystemPort, + bitwise_lookup_bus: BitwiseOperationLookupBus, + ptr_max_bits: usize, + self_bus_idx: BusIndex, + ) -> Self { + Self { + execution_bridge: ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lookup_bus, + ptr_max_bits, + sha_subair: Sha2Air::::new(bitwise_lookup_bus, self_bus_idx), + // optimization opportunity: we use fewer encoder cells for sha256 than sha512 or sha384 + padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), + } + } +} + +impl BaseAirWithPublicValues for Sha2VmAir {} +impl PartitionedBaseAir for Sha2VmAir {} +impl BaseAir for Sha2VmAir { + fn width(&self) -> usize { + C::VM_WIDTH + } +} + +impl Air for Sha2VmAir { + fn eval(&self, builder: &mut AB) { + self.eval_padding(builder); + self.eval_transitions(builder); + self.eval_reads(builder); + self.eval_last_row(builder); + + self.sha_subair.eval(builder, C::VM_CONTROL_WIDTH); + } +} + +#[allow(dead_code, non_camel_case_types)] +pub(super) enum PaddingFlags { + /// Not considered for padding - W's are not constrained + NotConsidered, + /// Not padding - W's should be equal to the message + NotPadding, + /// FIRST_PADDING_i: it is the first row with padding and there are i cells of non-padding + FirstPadding0, + FirstPadding1, + FirstPadding2, + FirstPadding3, + FirstPadding4, + FirstPadding5, + FirstPadding6, + FirstPadding7, + FirstPadding8, + FirstPadding9, + FirstPadding10, + FirstPadding11, + FirstPadding12, + FirstPadding13, + FirstPadding14, + FirstPadding15, + FirstPadding16, + FirstPadding17, + FirstPadding18, + FirstPadding19, + FirstPadding20, + FirstPadding21, + FirstPadding22, + FirstPadding23, + FirstPadding24, + FirstPadding25, + FirstPadding26, + FirstPadding27, + FirstPadding28, + FirstPadding29, + FirstPadding30, + FirstPadding31, + /// FIRST_PADDING_i_LastRow: it is the first row with padding and there are i cells of + /// non-padding AND it is the last reading row of the message + /// NOTE: if the Last row has padding it has to be at least: + /// - 9 cells since the last 8 cells are padded with the message length (for SHA-256) + /// - 17 cells since the last 16 cells are padded with the message length (for SHA-512) + FirstPadding0_LastRow, + FirstPadding1_LastRow, + FirstPadding2_LastRow, + FirstPadding3_LastRow, + FirstPadding4_LastRow, + FirstPadding5_LastRow, + FirstPadding6_LastRow, + FirstPadding7_LastRow, + FirstPadding8_LastRow, + FirstPadding9_LastRow, + FirstPadding10_LastRow, + FirstPadding11_LastRow, + FirstPadding12_LastRow, + FirstPadding13_LastRow, + FirstPadding14_LastRow, + FirstPadding15_LastRow, + + /// The entire row is padding AND it is not the first row with padding + /// AND it is the 4th row of the last block of the message + EntirePaddingLastRow, + /// The entire row is padding AND it is not the first row with padding + EntirePadding, +} + +impl PaddingFlags { + /// The number of padding flags (including NotConsidered) + pub const COUNT: usize = EntirePadding as usize + 1; +} + +use PaddingFlags::*; +impl Sha2VmAir { + /// Implement all necessary constraints for the padding + fn eval_padding(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local_cols = Sha2VmRoundColsRef::::from::(&local[..C::VM_ROUND_WIDTH]); + let next_cols = Sha2VmRoundColsRef::::from::(&next[..C::VM_ROUND_WIDTH]); + + // Constrain the sanity of the padding flags + self.padding_encoder + .eval(builder, local_cols.control.pad_flags.as_slice().unwrap()); + + builder.assert_one(self.padding_encoder.contains_flag_range::( + local_cols.control.pad_flags.as_slice().unwrap(), + NotConsidered as usize..=EntirePadding as usize, + )); + + Self::eval_padding_transitions(self, builder, &local_cols, &next_cols); + Self::eval_padding_row(self, builder, &local_cols); + } + + fn eval_padding_transitions( + &self, + builder: &mut AB, + local: &Sha2VmRoundColsRef, + next: &Sha2VmRoundColsRef, + ) { + let next_is_last_row = *next.inner.flags.is_digest_row * *next.inner.flags.is_last_block; + + // Constrain that `padding_occured` is 1 on a suffix of rows in each message, excluding the + // last digest row, and 0 everywhere else. Furthermore, the suffix starts in the + // first 4 rows of some block. + + builder.assert_bool(*local.control.padding_occurred); + // Last round row in the last block has padding_occurred = 1 + // This is the end of the suffix + builder + .when(next_is_last_row.clone()) + .assert_one(*local.control.padding_occurred); + + // Digest row in the last block has padding_occurred = 0 + builder + .when(next_is_last_row.clone()) + .assert_zero(*next.control.padding_occurred); + + // If padding_occurred = 1 in the current row, then padding_occurred = 1 in the next row, + // unless next is the last digest row + builder + .when(*local.control.padding_occurred - next_is_last_row.clone()) + .assert_one(*next.control.padding_occurred); + + // If next row is not first 4 rows of a block, then next.padding_occurred = + // local.padding_occurred. So padding_occurred only changes in the first 4 rows of a + // block. + builder + .when_transition() + .when(not(*next.inner.flags.is_first_4_rows) - next_is_last_row) + .assert_eq( + *next.control.padding_occurred, + *local.control.padding_occurred, + ); + + // Constrain the that the start of the padding is correct + let next_is_first_padding_row = + *next.control.padding_occurred - *local.control.padding_occurred; + // Row index if its between 0..4, else 0 + let next_row_idx = self.sha_subair.row_idx_encoder.flag_with_val::( + next.inner.flags.row_idx.as_slice().unwrap(), + &(0..C::MESSAGE_ROWS).map(|x| (x, x)).collect::>(), + ); + // How many non-padding cells there are in the next row. + // Will be 0 on non-padding rows. + let next_padding_offset = self.padding_encoder.flag_with_val::( + next.control.pad_flags.as_slice().unwrap(), + &(0..C::MAX_FIRST_PADDING + 1) + .map(|i| (FirstPadding0 as usize + i, i)) + .collect::>(), + ) + self.padding_encoder.flag_with_val::( + next.control.pad_flags.as_slice().unwrap(), + &(0..C::MAX_FIRST_PADDING_LAST_ROW + 1) + .map(|i| (FirstPadding0_LastRow as usize + i, i)) + .collect::>(), + ); + + // Will be 0 on last digest row since: + // - padding_occurred = 0 is constrained above + // - next_row_idx = 0 since row_idx is not in 0..4 + // - and next_padding_offset = 0 since `pad_flags = NotConsidered` + let expected_len = *next.inner.flags.local_block_idx + * *next.control.padding_occurred + * AB::Expr::from_canonical_usize(C::BLOCK_U8S) + + next_row_idx * AB::Expr::from_canonical_usize(C::READ_SIZE) + + next_padding_offset; + + // Note: `next_is_first_padding_row` is either -1,0,1 + // If 1, then this constrains the length of message + // If -1, then `next` must be the last digest row and so this constraint will be 0 == 0 + builder.when(next_is_first_padding_row).assert_eq( + expected_len, + *next.control.len * *next.control.padding_occurred, + ); + + // Constrain the padding flags are of correct type (eg is not padding or first padding) + let is_next_first_padding = self.padding_encoder.contains_flag_range::( + next.control.pad_flags.as_slice().unwrap(), + FirstPadding0 as usize..=(FirstPadding15_LastRow as usize), + ); + + let is_next_last_padding = self.padding_encoder.contains_flag_range::( + next.control.pad_flags.as_slice().unwrap(), + FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, + ); + + let is_next_entire_padding = self.padding_encoder.contains_flag_range::( + next.control.pad_flags.as_slice().unwrap(), + EntirePaddingLastRow as usize..=EntirePadding as usize, + ); + + let is_next_not_considered = self.padding_encoder.contains_flag::( + next.control.pad_flags.as_slice().unwrap(), + &[NotConsidered as usize], + ); + + let is_next_not_padding = self.padding_encoder.contains_flag::( + next.control.pad_flags.as_slice().unwrap(), + &[NotPadding as usize], + ); + + let is_next_4th_row = self + .sha_subair + .row_idx_encoder + .contains_flag::(next.inner.flags.row_idx.as_slice().unwrap(), &[3]); + + // `pad_flags` is `NotConsidered` on all rows except the first 4 rows of a block + builder.assert_eq( + not(*next.inner.flags.is_first_4_rows), + is_next_not_considered, + ); + + // `pad_flags` is `EntirePadding` if the previous row is padding + builder.when(*next.inner.flags.is_first_4_rows).assert_eq( + *local.control.padding_occurred * *next.control.padding_occurred, + is_next_entire_padding, + ); + + // `pad_flags` is `FirstPadding*` if current row is padding and the previous row is not + // padding + builder.when(*next.inner.flags.is_first_4_rows).assert_eq( + not(*local.control.padding_occurred) * *next.control.padding_occurred, + is_next_first_padding, + ); + + // `pad_flags` is `NotPadding` if current row is not padding + builder + .when(*next.inner.flags.is_first_4_rows) + .assert_eq(not(*next.control.padding_occurred), is_next_not_padding); + + // `pad_flags` is `*LastRow` on the row that contains the last four words of the message + builder + .when(*next.inner.flags.is_last_block) + .assert_eq(is_next_4th_row, is_next_last_padding); + } + + fn eval_padding_row( + &self, + builder: &mut AB, + local: &Sha2VmRoundColsRef, + ) { + let message = (0..C::READ_SIZE) + .map(|i| { + local.inner.message_schedule.carry_or_buffer[[i / (C::WORD_U8S), i % (C::WORD_U8S)]] + }) + .collect::>(); + + let get_ith_byte = |i: usize| { + let word_idx = i / C::WORD_U8S; + let word = local + .inner + .message_schedule + .w + .row(word_idx) + .mapv(|x| x.into()); + // Need to reverse the byte order to match the endianness of the memory + let byte_idx = C::WORD_U8S - i % C::WORD_U8S - 1; + compose::( + &word.as_slice().unwrap()[byte_idx * 8..(byte_idx + 1) * 8], + 1, + ) + }; + + let is_not_padding = self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[NotPadding as usize], + ); + + // Check the `w`s on case by case basis + for (i, message_byte) in message.iter().enumerate() { + let w = get_ith_byte(i); + let should_be_message = is_not_padding.clone() + + if i < C::MAX_FIRST_PADDING { + self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0 as usize + i + 1 + ..=FirstPadding0 as usize + C::MAX_FIRST_PADDING, + ) + } else { + AB::Expr::ZERO + } + + if i < C::MAX_FIRST_PADDING_LAST_ROW { + self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0_LastRow as usize + i + 1 + ..=FirstPadding0_LastRow as usize + C::MAX_FIRST_PADDING_LAST_ROW, + ) + } else { + AB::Expr::ZERO + }; + + builder + .when(should_be_message) + .assert_eq(w.clone(), *message_byte); + + let should_be_zero = self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[EntirePadding as usize], + ) + + // - 4 because the last 4 bytes are the padded length + if i < C::CELLS_PER_ROW - 4 { + self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[EntirePaddingLastRow as usize], + ) + if i > 0 { + self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0_LastRow as usize + ..=min( + FirstPadding0_LastRow as usize + i - 1, + FirstPadding0_LastRow as usize + C::MAX_FIRST_PADDING_LAST_ROW, + ), + ) + } else { + AB::Expr::ZERO + } + } else { + AB::Expr::ZERO + } + if i > 0 { + self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0 as usize..=FirstPadding0 as usize + i - 1, + ) + } else { + AB::Expr::ZERO + }; + builder.when(should_be_zero).assert_zero(w.clone()); + + // Assumes bit-length of message is a multiple of 8 (message is bytes) + // This is true because the message is given as &[u8] + let should_be_128 = self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[FirstPadding0 as usize + i], + ) + if i < 8 { + self.padding_encoder.contains_flag::( + local.control.pad_flags.as_slice().unwrap(), + &[FirstPadding0_LastRow as usize + i], + ) + } else { + AB::Expr::ZERO + }; + + builder + .when(should_be_128) + .assert_eq(AB::Expr::from_canonical_u32(1 << 7), w); + + // should be len is handled outside of the loop + } + let appended_len = compose::( + &[ + get_ith_byte(C::CELLS_PER_ROW - 1), + get_ith_byte(C::CELLS_PER_ROW - 2), + get_ith_byte(C::CELLS_PER_ROW - 3), + get_ith_byte(C::CELLS_PER_ROW - 4), + ], + RV32_CELL_BITS, + ); + + let actual_len = *local.control.len; + + let is_last_padding_row = self.padding_encoder.contains_flag_range::( + local.control.pad_flags.as_slice().unwrap(), + FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, + ); + + builder.when(is_last_padding_row.clone()).assert_eq( + appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(), // bit to byte conversion + actual_len, + ); + + // We constrain that the appended length is in bytes + builder.when(is_last_padding_row.clone()).assert_zero( + local.inner.message_schedule.w[[3, 0]] + + local.inner.message_schedule.w[[3, 1]] + + local.inner.message_schedule.w[[3, 2]], + ); + + // We can't support messages longer than 2^29 bytes because the length has to fit in a + // field element. So, constrain that the first few bytes of the length are 0 (so only the + // last 4 bytes of the length can be nonzero). Thus, the bit-length is < 2^32 so the message + // is < 2^29 bytes. + // For SHA-256, assert bytes 8..12 are 0, because the message length is 8 bytes, and each + // row has 16 bytes. + // For SHA-512 and SHA-384, assert bytes 16..28 are 0, because the + // message length is 16 bytes and each row has 32 bytes. + for i in C::CELLS_PER_ROW - C::MESSAGE_LENGTH_BITS / 8..C::CELLS_PER_ROW - 4 { + builder + .when(is_last_padding_row.clone()) + .assert_zero(get_ith_byte(i)); + } + } + /// Implement constraints on `len`, `read_ptr` and `cur_timestamp` + fn eval_transitions(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local_cols = Sha2VmRoundColsRef::::from::(&local[..C::VM_ROUND_WIDTH]); + let next_cols = Sha2VmRoundColsRef::::from::(&next[..C::VM_ROUND_WIDTH]); + + let is_last_row = + *local_cols.inner.flags.is_last_block * *local_cols.inner.flags.is_digest_row; + // Len should be the same for the entire message + builder + .when_transition() + .when(not::(is_last_row.clone())) + .assert_eq(*next_cols.control.len, *local_cols.control.len); + + // Read ptr should increment by [C::READ_SIZE] for the first 4 rows and stay the same + // otherwise + let read_ptr_delta = + *local_cols.inner.flags.is_first_4_rows * AB::Expr::from_canonical_usize(C::READ_SIZE); + builder + .when_transition() + .when(not::(is_last_row.clone())) + .assert_eq( + *next_cols.control.read_ptr, + *local_cols.control.read_ptr + read_ptr_delta, + ); + + // Timestamp should increment by 1 for the first 4 rows and stay the same otherwise + let timestamp_delta = *local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE; + builder + .when_transition() + .when(not::(is_last_row.clone())) + .assert_eq( + *next_cols.control.cur_timestamp, + *local_cols.control.cur_timestamp + timestamp_delta, + ); + } + + /// Implement the reads for the first 4 rows of a block + fn eval_reads(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local_cols = Sha2VmRoundColsRef::::from::(&local[..C::VM_ROUND_WIDTH]); + + let message: Vec = (0..C::READ_SIZE) + .map(|i| { + local_cols.inner.message_schedule.carry_or_buffer + [[i / (C::WORD_U16S * 2), i % (C::WORD_U16S * 2)]] + }) + .collect(); + + match C::VARIANT { + Sha2Variant::Sha256 => { + let message: [AB::Var; Sha256Config::READ_SIZE] = + message.try_into().unwrap_or_else(|_| { + panic!("message is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + *local_cols.control.read_ptr, + ), + message, + *local_cols.control.cur_timestamp, + local_cols.read_aux, + ) + .eval(builder, *local_cols.inner.flags.is_first_4_rows); + } + // Sha512 and Sha384 have the same read size so we put them together + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let message: [AB::Var; Sha512Config::READ_SIZE] = + message.try_into().unwrap_or_else(|_| { + panic!("message is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + *local_cols.control.read_ptr, + ), + message, + *local_cols.control.cur_timestamp, + local_cols.read_aux, + ) + .eval(builder, *local_cols.inner.flags.is_first_4_rows); + } + } + } + /// Implement the constraints for the last row of a message + fn eval_last_row(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local_cols = Sha2VmDigestColsRef::::from::(&local[..C::VM_DIGEST_WIDTH]); + + let timestamp: AB::Var = local_cols.from_state.timestamp; + let mut timestamp_delta: usize = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) + }; + + let is_last_row = + *local_cols.inner.flags.is_last_block * *local_cols.inner.flags.is_digest_row; + + let dst_ptr: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.dst_ptr.to_vec().try_into().unwrap_or_else(|_| { + panic!("dst_ptr is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rd_ptr, + ), + dst_ptr, + timestamp_pp(), + &local_cols.register_reads_aux[0], + ) + .eval(builder, is_last_row.clone()); + + let src_ptr: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.src_ptr.to_vec().try_into().unwrap_or_else(|_| { + panic!("src_ptr is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rs1_ptr, + ), + src_ptr, + timestamp_pp(), + &local_cols.register_reads_aux[1], + ) + .eval(builder, is_last_row.clone()); + + let len_data: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.len_data.to_vec().try_into().unwrap_or_else(|_| { + panic!("len_data is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rs2_ptr, + ), + len_data, + timestamp_pp(), + &local_cols.register_reads_aux[2], + ) + .eval(builder, is_last_row.clone()); + // range check that the memory pointers don't overflow + // Note: no need to range check the length since we read from memory step by step and + // the memory bus will catch any memory accesses beyond ptr_max_bits + let shift = AB::Expr::from_canonical_usize( + 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), + ); + // This only works if self.ptr_max_bits >= 24 which is typically the case + self.bitwise_lookup_bus + .send_range( + // It is fine to shift like this since we already know that dst_ptr and src_ptr + // have [RV32_CELL_BITS] bits + local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), + local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), + ) + .eval(builder, is_last_row.clone()); + + // the number of reads that happened to read the entire message: we do 4 reads per block + let time_delta = (*local_cols.inner.flags.local_block_idx + AB::Expr::ONE) + * AB::Expr::from_canonical_usize(4); + // Every time we read the message we increment the read pointer by C::READ_SIZE + let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(C::READ_SIZE); + + let result: Vec = (0..C::HASH_SIZE) + .map(|i| { + // The limbs are written in big endian order to the memory so need to be reversed + local_cols.inner.final_hash[[i / C::WORD_U8S, C::WORD_U8S - i % C::WORD_U8S - 1]] + }) + .collect(); + + let dst_ptr_val = compose::( + local_cols.dst_ptr.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + + match C::VARIANT { + Sha2Variant::Sha256 => { + debug_assert_eq!(C::NUM_WRITES, 1); + debug_assert_eq!(local_cols.writes_aux_base.len(), 1); + debug_assert_eq!(local_cols.writes_aux_prev_data.nrows(), 1); + let prev_data: [AB::Var; Sha256Config::HASH_SIZE] = local_cols + .writes_aux_prev_data + .row(0) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("writes_aux_prev_data is not the correct size"); + }); + // Note: revisit in the future to do 2 block writes of 16 cells instead of 1 block + // write of 32 cells. This could be beneficial as the output is often an input for + // another hash + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val, + ), + result.try_into().unwrap_or_else(|_| { + panic!("result is not the correct size"); + }), + timestamp_pp() + time_delta.clone(), + &MemoryWriteAuxCols::from_base(local_cols.writes_aux_base[0], prev_data), + ) + .eval(builder, is_last_row.clone()); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + debug_assert_eq!(C::NUM_WRITES, 2); + debug_assert_eq!(local_cols.writes_aux_base.len(), 2); + debug_assert_eq!(local_cols.writes_aux_prev_data.nrows(), 2); + + // For Sha384, set the last 16 cells to 0 + let mut truncated_result: Vec = + result.iter().map(|x| (*x).into()).collect(); + for x in truncated_result.iter_mut().skip(C::DIGEST_SIZE) { + *x = AB::Expr::ZERO; + } + + // write the digest in two halves because we only support writes up to 32 bytes + for i in 0..Sha512Config::NUM_WRITES { + let prev_data: [AB::Var; Sha512Config::WRITE_SIZE] = local_cols + .writes_aux_prev_data + .row(i) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("writes_aux_prev_data is not the correct size"); + }); + + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val.clone() + + AB::Expr::from_canonical_usize(i * Sha512Config::WRITE_SIZE), + ), + truncated_result + [i * Sha512Config::WRITE_SIZE..(i + 1) * Sha512Config::WRITE_SIZE] + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("result is not the correct size"); + }), + timestamp_pp() + time_delta.clone(), + &MemoryWriteAuxCols::from_base( + local_cols.writes_aux_base[i], + prev_data, + ), + ) + .eval(builder, is_last_row.clone()); + } + } + } + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(C::OPCODE.global_opcode().as_usize()), + [ + >::into(*local_cols.rd_ptr), + >::into(*local_cols.rs1_ptr), + >::into(*local_cols.rs2_ptr), + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + ], + *local_cols.from_state, + AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(), + ) + .eval(builder, is_last_row.clone()); + + // Assert that we read the correct length of the message + let len_val = compose::( + local_cols.len_data.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + builder + .when(is_last_row.clone()) + .assert_eq(*local_cols.control.len, len_val); + // Assert that we started reading from the correct pointer initially + let src_val = compose::( + local_cols.src_ptr.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + builder + .when(is_last_row.clone()) + .assert_eq(*local_cols.control.read_ptr, src_val + read_ptr_delta); + // Assert that we started reading from the correct timestamp + builder.when(is_last_row.clone()).assert_eq( + *local_cols.control.cur_timestamp, + local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta, + ); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chip/columns.rs b/extensions/sha2/circuit/src/sha2_chip/columns.rs new file mode 100644 index 0000000000..20a2080860 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/columns.rs @@ -0,0 +1,106 @@ +//! WARNING: the order of fields in the structs is important, do not change it + +use openvm_circuit::{ + arch::ExecutionState, + system::memory::offline_checker::{MemoryBaseAuxCols, MemoryReadAuxCols}, +}; +use openvm_circuit_primitives_derive::ColsRef; +use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; +use openvm_sha2_air::{ + ShaDigestCols, ShaDigestColsRef, ShaDigestColsRefMut, ShaRoundCols, ShaRoundColsRef, + ShaRoundColsRefMut, +}; + +use super::SHA_REGISTER_READS; +use crate::Sha2ChipConfig; + +/// the first C::ROUND_ROWS rows of every SHA block will be of type ShaVmRoundCols and the last row +/// will be of type ShaVmDigestCols +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2VmRoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub control: Sha2VmControlCols, + pub inner: ShaRoundCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, + #[aligned_borrow] + pub read_aux: MemoryReadAuxCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2VmDigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, + const NUM_WRITES: usize, + const WRITE_SIZE: usize, +> { + pub control: Sha2VmControlCols, + pub inner: ShaDigestCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + HASH_WORDS, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, + #[aligned_borrow] + pub from_state: ExecutionState, + /// It is counter intuitive, but we will constrain the register reads on the very last row of + /// every message + pub rd_ptr: T, + pub rs1_ptr: T, + pub rs2_ptr: T, + pub dst_ptr: [T; RV32_REGISTER_NUM_LIMBS], + pub src_ptr: [T; RV32_REGISTER_NUM_LIMBS], + pub len_data: [T; RV32_REGISTER_NUM_LIMBS], + #[aligned_borrow] + pub register_reads_aux: [MemoryReadAuxCols; SHA_REGISTER_READS], + // We store the fields of MemoryWriteAuxCols here because the length of prev_data depends on + // the sha variant + #[aligned_borrow] + pub writes_aux_base: [MemoryBaseAuxCols; NUM_WRITES], + pub writes_aux_prev_data: [[T; WRITE_SIZE]; NUM_WRITES], +} + +/// These are the columns that are used on both round and digest rows +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2VmControlCols { + /// Note: We will use the buffer in `inner.message_schedule` as the message data + /// This is the length of the entire message in bytes + pub len: T, + /// Need to keep timestamp and read_ptr since block reads don't have the necessary information + pub cur_timestamp: T, + pub read_ptr: T, + /// Padding flags which will be used to encode the the number of non-padding cells in the + /// current row + pub pad_flags: [T; 9], + /// A boolean flag that indicates whether a padding already occurred + pub padding_occurred: T, +} diff --git a/extensions/sha2/circuit/src/sha2_chip/config.rs b/extensions/sha2/circuit/src/sha2_chip/config.rs new file mode 100644 index 0000000000..7dfed9610a --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/config.rs @@ -0,0 +1,99 @@ +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_sha2_air::{Sha256Config, Sha2Config, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; + +use super::{Sha2VmControlColsRef, Sha2VmDigestColsRef, Sha2VmRoundColsRef}; + +pub trait Sha2ChipConfig: Sha2Config { + // Name of the opcode + const OPCODE_NAME: &'static str; + /// Width of the ShaVmControlCols + const VM_CONTROL_WIDTH: usize = Sha2VmControlColsRef::::width::(); + /// Width of the ShaVmRoundCols + const VM_ROUND_WIDTH: usize = Sha2VmRoundColsRef::::width::(); + /// Width of the ShaVmDigestCols + const VM_DIGEST_WIDTH: usize = Sha2VmDigestColsRef::::width::(); + /// Width of the ShaVmCols + const VM_WIDTH: usize = if Self::VM_ROUND_WIDTH > Self::VM_DIGEST_WIDTH { + Self::VM_ROUND_WIDTH + } else { + Self::VM_DIGEST_WIDTH + }; + /// Number of bits to use when padding the message length. Given by the SHA-2 spec. + const MESSAGE_LENGTH_BITS: usize; + /// Maximum i such that `FirstPadding_i` is a valid padding flag + const MAX_FIRST_PADDING: usize = Self::CELLS_PER_ROW - 1; + /// Maximum i such that `FirstPadding_i_LastRow` is a valid padding flag + const MAX_FIRST_PADDING_LAST_ROW: usize = + Self::CELLS_PER_ROW - Self::MESSAGE_LENGTH_BITS / 8 - 1; + /// OpenVM Opcode for the instruction + const OPCODE: Rv32Sha2Opcode; + + // ==== Constants for register/memory adapter ==== + /// Number of rv32 cells read in a block + const BLOCK_CELLS: usize = Self::BLOCK_BITS / RV32_CELL_BITS; + /// Number of rows we will do a read on for each block + const NUM_READ_ROWS: usize = Self::MESSAGE_ROWS; + + /// Number of cells to read in a single memory access + const READ_SIZE: usize = Self::WORD_U8S * Self::ROUNDS_PER_ROW; + /// Number of cells in the digest before truncation (Sha384 truncates the digest) + const HASH_SIZE: usize = Self::WORD_U8S * Self::HASH_WORDS; + /// Number of cells in the digest after truncation + const DIGEST_SIZE: usize; + + /// Number of parts to write the hash in + const NUM_WRITES: usize = Self::HASH_SIZE / Self::WRITE_SIZE; + /// Size of each write. Must divide Self::HASH_SIZE + const WRITE_SIZE: usize; +} + +/// Register reads to get dst, src, len +pub const SHA_REGISTER_READS: usize = 3; + +impl Sha2ChipConfig for Sha256Config { + const OPCODE_NAME: &'static str = "SHA256"; + const MESSAGE_LENGTH_BITS: usize = 64; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA256; + // no truncation + const DIGEST_SIZE: usize = Self::HASH_SIZE; +} + +impl Sha2ChipConfig for Sha512Config { + const OPCODE_NAME: &'static str = "SHA512"; + const MESSAGE_LENGTH_BITS: usize = 128; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA512; + // no truncation + const DIGEST_SIZE: usize = Self::HASH_SIZE; +} + +impl Sha2ChipConfig for Sha384Config { + const OPCODE_NAME: &'static str = "SHA384"; + const MESSAGE_LENGTH_BITS: usize = 128; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA384; + // Sha384 truncates the output to 48 cells + const DIGEST_SIZE: usize = 48; +} + +// We use the same write size for all variants to simplify tracegen record storage. +// In particular, each memory write aux record will have the same size, which is useful for +// defining Sha2VmRecordHeader in a repr(C) way. +pub const SHA_WRITE_SIZE: usize = 32; + +pub const MAX_SHA_NUM_WRITES: usize = if Sha256Config::NUM_WRITES > Sha512Config::NUM_WRITES { + if Sha256Config::NUM_WRITES > Sha384Config::NUM_WRITES { + Sha256Config::NUM_WRITES + } else { + Sha384Config::NUM_WRITES + } +} else if Sha512Config::NUM_WRITES > Sha384Config::NUM_WRITES { + Sha512Config::NUM_WRITES +} else { + Sha384Config::NUM_WRITES +}; + +/// Maximum message length that this chip supports in bytes +pub const SHA_MAX_MESSAGE_LEN: usize = 1 << 29; diff --git a/extensions/sha2/circuit/src/sha2_chip/mod.rs b/extensions/sha2/circuit/src/sha2_chip/mod.rs new file mode 100644 index 0000000000..7525ec8435 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/mod.rs @@ -0,0 +1,284 @@ +//! Sha256 hasher. Handles full sha256 hashing with padding. +//! variable length inputs read from VM memory. +use std::{ + borrow::{Borrow, BorrowMut}, + iter, +}; + +use openvm_circuit::arch::{ + execution_mode::{E1ExecutionCtx, E2ExecutionCtx}, + E2PreCompute, ExecuteFunc, + ExecutionError::InvalidInstruction, + MatrixRecordArena, NewVmChipWrapper, Result, StepExecutorE1, StepExecutorE2, VmSegmentState, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, +}; +use openvm_circuit_primitives_derive::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_sha2_air::{Sha256Config, Sha2StepHelper, Sha2Variant, Sha384Config, Sha512Config}; +use openvm_stark_backend::p3_field::PrimeField32; +use sha2::{Digest, Sha256, Sha384, Sha512}; + +mod air; +mod columns; +mod config; +mod trace; +mod utils; + +pub use air::*; +pub use columns::*; +pub use config::*; +pub use utils::get_sha2_num_blocks; + +#[cfg(test)] +mod tests; + +pub type Sha2VmChip = NewVmChipWrapper, Sha2VmStep, MatrixRecordArena>; + +pub struct Sha2VmStep { + pub inner: Sha2StepHelper, + pub padding_encoder: Encoder, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, + pub pointer_max_bits: usize, +} + +impl Sha2VmStep { + pub fn new( + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + offset: usize, + pointer_max_bits: usize, + ) -> Self { + Self { + inner: Sha2StepHelper::::new(), + padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), + bitwise_lookup_chip, + offset, + pointer_max_bits, + } + } +} + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct Sha2PreCompute { + a: u8, + b: u8, + c: u8, +} + +impl StepExecutorE1 for Sha2VmStep { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + fn pre_compute_e1( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E1ExecutionCtx, + { + let data: &mut Sha2PreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _, C>) + } +} +impl StepExecutorE2 for Sha2VmStep { + fn e2_pre_compute_size(&self) -> usize { + size_of::>() + } + + fn pre_compute_e2( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result> + where + Ctx: E2ExecutionCtx, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _, C>) + } +} + +unsafe fn execute_e12_impl< + F: PrimeField32, + CTX: E1ExecutionCtx, + C: Sha2ChipConfig, + const IS_E1: bool, +>( + pre_compute: &Sha2PreCompute, + vm_state: &mut VmSegmentState, +) -> u32 { + let dst = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32); + let src = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let len = vm_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let dst_u32 = u32::from_le_bytes(dst); + let src_u32 = u32::from_le_bytes(src); + let len_u32 = u32::from_le_bytes(len); + + let (output, height) = if IS_E1 { + // SAFETY: RV32_MEMORY_AS is memory address space of type u8 + let message = vm_state.vm_read_slice(RV32_MEMORY_AS, src_u32, len_u32 as usize); + let output = sha2_solve::(message); + (output, 0) + } else { + let num_blocks = get_sha2_num_blocks::(len_u32); + let mut message = Vec::with_capacity(len_u32 as usize); + for block_idx in 0..num_blocks as usize { + // Reads happen on the first 4 rows of each block + for row in 0..C::NUM_READ_ROWS { + let read_idx = block_idx * C::NUM_READ_ROWS + row; + match C::VARIANT { + Sha2Variant::Sha256 => { + let row_input: [u8; Sha256Config::READ_SIZE] = vm_state + .vm_read(RV32_MEMORY_AS, src_u32 + (read_idx * C::READ_SIZE) as u32); + message.extend_from_slice(&row_input); + } + Sha2Variant::Sha512 => { + let row_input: [u8; Sha512Config::READ_SIZE] = vm_state + .vm_read(RV32_MEMORY_AS, src_u32 + (read_idx * C::READ_SIZE) as u32); + message.extend_from_slice(&row_input); + } + Sha2Variant::Sha384 => { + let row_input: [u8; Sha384Config::READ_SIZE] = vm_state + .vm_read(RV32_MEMORY_AS, src_u32 + (read_idx * C::READ_SIZE) as u32); + message.extend_from_slice(&row_input); + } + } + } + } + let output = sha2_solve::(&message[..len_u32 as usize]); + let height = num_blocks * C::ROWS_PER_BLOCK as u32; + (output, height) + }; + match C::VARIANT { + Sha2Variant::Sha256 => { + let output: [u8; Sha256Config::WRITE_SIZE] = output.try_into().unwrap(); + vm_state.vm_write(RV32_MEMORY_AS, dst_u32, &output); + } + Sha2Variant::Sha512 => { + for i in 0..C::NUM_WRITES { + let output: [u8; Sha512Config::WRITE_SIZE] = output + [i * Sha512Config::WRITE_SIZE..(i + 1) * Sha512Config::WRITE_SIZE] + .try_into() + .unwrap(); + vm_state.vm_write( + RV32_MEMORY_AS, + dst_u32 + (i * Sha512Config::WRITE_SIZE) as u32, + &output, + ); + } + } + Sha2Variant::Sha384 => { + // Pad the output with zeros to 64 bytes + let output = output + .into_iter() + .chain(iter::repeat(0).take(16)) + .collect::>(); + for i in 0..C::NUM_WRITES { + let output: [u8; Sha384Config::WRITE_SIZE] = output + [i * Sha384Config::WRITE_SIZE..(i + 1) * Sha384Config::WRITE_SIZE] + .try_into() + .unwrap(); + vm_state.vm_write( + RV32_MEMORY_AS, + dst_u32 + (i * Sha384Config::WRITE_SIZE) as u32, + &output, + ); + } + } + } + + vm_state.pc = vm_state.pc.wrapping_add(DEFAULT_PC_STEP); + vm_state.instret += 1; + + height +} + +unsafe fn execute_e1_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &Sha2PreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, vm_state); +} +unsafe fn execute_e2_impl( + pre_compute: &[u8], + vm_state: &mut VmSegmentState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl::(&pre_compute.data, vm_state); + vm_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +impl Sha2VmStep { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut Sha2PreCompute, + ) -> Result<()> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(InvalidInstruction(pc)); + } + *data = Sha2PreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + assert_eq!(&C::OPCODE.global_opcode(), opcode); + Ok(()) + } +} + +pub fn sha2_solve(input_message: &[u8]) -> Vec { + match C::VARIANT { + Sha2Variant::Sha256 => { + let mut hasher = Sha256::new(); + hasher.update(input_message); + let mut output = vec![0u8; C::DIGEST_SIZE]; + output.copy_from_slice(hasher.finalize().as_ref()); + output + } + Sha2Variant::Sha512 => { + let mut hasher = Sha512::new(); + hasher.update(input_message); + let mut output = vec![0u8; C::DIGEST_SIZE]; + output.copy_from_slice(hasher.finalize().as_ref()); + output + } + Sha2Variant::Sha384 => { + let mut hasher = Sha384::new(); + hasher.update(input_message); + let mut output = vec![0u8; C::DIGEST_SIZE]; + output.copy_from_slice(hasher.finalize().as_ref()); + output + } + } +} diff --git a/extensions/sha2/circuit/src/sha2_chip/tests.rs b/extensions/sha2/circuit/src/sha2_chip/tests.rs new file mode 100644 index 0000000000..9cba30b34b --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/tests.rs @@ -0,0 +1,321 @@ +use std::array; + +use openvm_circuit::{ + arch::{ + testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + DenseRecordArena, InsExecutorE1, InstructionExecutor, NewVmChipWrapper, + }, + utils::get_random_message, +}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +}; +use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_sha2_air::{Sha256Config, Sha2Variant, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; +use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; +use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; + +use super::{Sha2ChipConfig, Sha2VmAir, Sha2VmChip, Sha2VmStep}; +use crate::{ + sha2_chip::trace::Sha2VmRecordLayout, sha2_solve, Sha2VmDigestColsRef, Sha2VmRoundColsRef, +}; + +type F = BabyBear; +const SELF_BUS_IDX: BusIndex = 28; +const MAX_INS_CAPACITY: usize = 8192; +type Sha2VmChipDense = NewVmChipWrapper, Sha2VmStep, DenseRecordArena>; + +fn create_test_chips( + tester: &mut VmChipTestBuilder, +) -> ( + Sha2VmChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let mut chip = Sha2VmChip::::new( + Sha2VmAir::new( + tester.system_port(), + bitwise_bus, + tester.address_bits(), + SELF_BUS_IDX, + ), + Sha2VmStep::new( + bitwise_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + tester.address_bits(), + ), + tester.memory_helper(), + ); + chip.set_trace_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + +fn set_and_execute, C: Sha2ChipConfig>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: Rv32Sha2Opcode, + message: Option<&[u8]>, + len: Option, +) { + let len = len.unwrap_or(rng.gen_range(1..3000)); + let tmp = get_random_message(rng, len); + let message: &[u8] = message.unwrap_or(&tmp); + let len = message.len(); + + let rd = gen_pointer(rng, 4); + let rs1 = gen_pointer(rng, 4); + let rs2 = gen_pointer(rng, 4); + + let max_mem_ptr: u32 = 1 << tester.address_bits(); + let dst_ptr = rng.gen_range(0..max_mem_ptr - C::DIGEST_SIZE as u32); + let dst_ptr = dst_ptr ^ (dst_ptr & 3); + tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); + let src_ptr = rng.gen_range(0..(max_mem_ptr - len as u32)); + let src_ptr = src_ptr ^ (src_ptr & 3); + tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); + + message.chunks(4).enumerate().for_each(|(i, chunk)| { + let chunk: [&u8; 4] = array::from_fn(|i| chunk.get(i).unwrap_or(&0)); + tester.write( + 2, + src_ptr as usize + i * 4, + chunk.map(|&x| F::from_canonical_u8(x)), + ); + }); + + tester.execute( + chip, + &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), + ); + + let output = sha2_solve::(message); + match C::VARIANT { + Sha2Variant::Sha256 => { + assert_eq!( + output + .into_iter() + .map(F::from_canonical_u8) + .collect::>(), + tester.read::<{ Sha256Config::DIGEST_SIZE }>(2, dst_ptr as usize) + ); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let mut output = output; + output.extend(std::iter::repeat(0u8).take(C::HASH_SIZE)); + let output = output + .into_iter() + .map(F::from_canonical_u8) + .collect::>(); + for i in 0..C::NUM_WRITES { + assert_eq!( + output[i * C::WRITE_SIZE..(i + 1) * C::WRITE_SIZE], + tester.read::<{ Sha512Config::WRITE_SIZE }>( + 2, + dst_ptr as usize + i * C::WRITE_SIZE + ) + ); + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. +/////////////////////////////////////////////////////////////////////////////////////// +fn rand_sha_test() { + setup_tracing(); + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, bitwise_chip) = create_test_chips::(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute::<_, C>(&mut tester, &mut chip, &mut rng, C::OPCODE, None, None); + } + + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn rand_sha256_test() { + rand_sha_test::(); +} + +#[test] +fn rand_sha512_test() { + rand_sha_test::(); +} + +#[test] +fn rand_sha384_test() { + rand_sha_test::(); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. +/////////////////////////////////////////////////////////////////////////////////////// +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, _) = create_test_chips::(&mut tester); + + println!( + "Sha2VmDigestColsRef::::width::(): {}", + Sha2VmDigestColsRef::::width::() + ); + println!( + "Sha2VmRoundColsRef::::width::(): {}", + Sha2VmRoundColsRef::::width::() + ); + + let num_tests: usize = 1; + for _ in 0..num_tests { + set_and_execute::<_, C>(&mut tester, &mut chip, &mut rng, C::OPCODE, None, None); + } +} + +#[test] +fn sha256_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha512_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha384_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha256_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = sha2_solve::(input); + let expected: [u8; 32] = [ + 99, 196, 61, 185, 226, 212, 131, 80, 154, 248, 97, 108, 157, 55, 200, 226, 160, 73, 207, + 46, 245, 169, 94, 255, 42, 136, 193, 15, 40, 133, 173, 22, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha512_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = sha2_solve::(input); + // verified manually against the sha512 command line tool + let expected: [u8; 64] = [ + 0, 8, 195, 142, 70, 71, 97, 208, 132, 132, 243, 53, 179, 186, 8, 162, 71, 75, 126, 21, 130, + 203, 245, 126, 207, 65, 119, 60, 64, 79, 200, 2, 194, 17, 189, 137, 164, 213, 107, 197, + 152, 11, 242, 165, 146, 80, 96, 105, 249, 27, 139, 14, 244, 21, 118, 31, 94, 87, 32, 145, + 149, 98, 235, 75, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha384_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = sha2_solve::(input); + let expected: [u8; 48] = [ + 134, 227, 167, 229, 35, 110, 115, 174, 10, 27, 197, 116, 56, 144, 150, 36, 152, 190, 212, + 120, 26, 243, 125, 4, 2, 60, 164, 195, 218, 219, 255, 143, 240, 75, 158, 126, 102, 105, 8, + 202, 142, 240, 230, 161, 162, 152, 111, 71, + ]; + assert_eq!(output, expected); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// DENSE TESTS +/// +/// Ensure that the chip works as expected with dense records. +/// We first execute some instructions with a [DenseRecordArena] and transfer the records +/// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +/// all the constraints pass. +/////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip_dense( + tester: &mut VmChipTestBuilder, +) -> Sha2VmChipDense { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Sha2VmChipDense::::new( + Sha2VmAir::::new( + tester.system_port(), + bitwise_chip.bus(), + tester.address_bits(), + SELF_BUS_IDX, + ), + Sha2VmStep::::new( + bitwise_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + tester.address_bits(), + ), + tester.memory_helper(), + ); + + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + chip +} + +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_chip, bitwise_chip) = create_test_chips::(&mut tester); + + { + let mut dense_chip = create_test_chip_dense::(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute::<_, C>( + &mut tester, + &mut dense_chip, + &mut rng, + C::OPCODE, + None, + None, + ); + } + + let mut record_interpreter = dense_chip + .arena + .get_record_seeker::<_, Sha2VmRecordLayout>(); + record_interpreter.transfer_to_matrix_arena(&mut sparse_chip.arena); + } + + let tester = tester + .build() + .load(sparse_chip) + .load(bitwise_chip) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn sha256_dense_record_arena_test() { + dense_record_arena_test::(); +} + +#[test] +fn sha512_dense_record_arena_test() { + dense_record_arena_test::(); +} + +#[test] +fn sha384_dense_record_arena_test() { + dense_record_arena_test::(); +} diff --git a/extensions/sha2/circuit/src/sha2_chip/trace.rs b/extensions/sha2/circuit/src/sha2_chip/trace.rs new file mode 100644 index 0000000000..b8a59413e5 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/trace.rs @@ -0,0 +1,719 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + cmp::min, + iter, + marker::PhantomData, +}; + +use openvm_circuit::{ + arch::{ + get_record_from_slice, CustomBorrow, MultiRowLayout, MultiRowMetadata, RecordArena, Result, + SizedRecord, TraceFiller, TraceStep, VmStateMut, + }, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; +use openvm_sha2_air::{ + be_limbs_into_word, get_flag_pt_array, Sha256Config, Sha2StepHelper, Sha384Config, Sha512Config, +}; +use openvm_stark_backend::{ + p3_field::PrimeField32, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::*, +}; + +use super::{ + Sha2ChipConfig, Sha2Variant, Sha2VmDigestColsRefMut, Sha2VmRoundColsRefMut, Sha2VmStep, +}; +use crate::{ + get_sha2_num_blocks, sha2_chip::PaddingFlags, sha2_solve, Sha2VmControlColsRefMut, + MAX_SHA_NUM_WRITES, SHA_MAX_MESSAGE_LEN, SHA_REGISTER_READS, SHA_WRITE_SIZE, +}; + +#[derive(Clone, Copy)] +pub struct Sha2VmMetadata { + pub num_blocks: u32, + _phantom: PhantomData, +} + +impl MultiRowMetadata for Sha2VmMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_blocks as usize * C::ROWS_PER_BLOCK + } +} + +pub(crate) type Sha2VmRecordLayout = MultiRowLayout>; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Sha2VmRecordHeader { + pub from_pc: u32, + pub timestamp: u32, + pub rd_ptr: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub dst_ptr: u32, + pub src_ptr: u32, + pub len: u32, + + pub register_reads_aux: [MemoryReadAuxRecord; SHA_REGISTER_READS], + // Note: MAX_SHA_NUM_WRITES = 2 because SHA-256 uses 1 write, while SHA-512 and SHA-384 use 2 + // writes. We just use the same array for all variants to simplify record storage. + pub writes_aux: [MemoryWriteBytesAuxRecord; MAX_SHA_NUM_WRITES], +} + +pub struct Sha2VmRecordMut<'a> { + pub inner: &'a mut Sha2VmRecordHeader, + // Having a continuous slice of the input is useful for fast hashing in `execute` + pub input: &'a mut [u8], + pub read_aux: &'a mut [MemoryReadAuxRecord], +} + +/// Custom borrowing that splits the buffer into a fixed `Sha2VmRecord` header +/// followed by a slice of `u8`'s of length `C::BLOCK_CELLS * num_blocks` where `num_blocks` is +/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length +/// `C::NUM_READ_ROWS * num_blocks`. Uses `align_to_mut()` to make sure the slice is properly +/// aligned to `MemoryReadAuxRecord`. Has debug assertions that check the size and alignment of the +/// slices. +impl<'a, C: Sha2ChipConfig> CustomBorrow<'a, Sha2VmRecordMut<'a>, Sha2VmRecordLayout> + for [u8] +{ + fn custom_borrow(&'a mut self, layout: Sha2VmRecordLayout) -> Sha2VmRecordMut<'a> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let header: &mut Sha2VmRecordHeader = header_buf.borrow_mut(); + + // Using `split_at_mut_unchecked` for perf reasons + // input is a slice of `u8`'s of length `C::BLOCK_CELLS * num_blocks`, so the alignment + // is always satisfied + let (input, rest) = unsafe { + rest.split_at_mut_unchecked((layout.metadata.num_blocks as usize) * C::BLOCK_CELLS) + }; + + // Using `align_to_mut` to make sure the returned slice is properly aligned to + // `MemoryReadAuxRecord` Additionally, Rust's subslice operation (a few lines below) + // will verify that the buffer has enough capacity + let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; + Sha2VmRecordMut { + inner: header, + input, + read_aux: &mut read_aux_buf[..(layout.metadata.num_blocks as usize) * C::NUM_READ_ROWS], + } + } + + unsafe fn extract_layout(&self) -> Sha2VmRecordLayout { + let header: &Sha2VmRecordHeader = self.borrow(); + + Sha2VmRecordLayout { + metadata: Sha2VmMetadata { + num_blocks: get_sha2_num_blocks::(header.len), + _phantom: PhantomData::, + }, + } + } +} + +impl SizedRecord> for Sha2VmRecordMut<'_> { + fn size(layout: &Sha2VmRecordLayout) -> usize { + let mut total_len = size_of::(); + total_len += layout.metadata.num_blocks as usize * C::BLOCK_CELLS; + // Align the pointer to the alignment of `MemoryReadAuxRecord` + total_len = total_len.next_multiple_of(align_of::()); + total_len += layout.metadata.num_blocks as usize + * C::NUM_READ_ROWS + * size_of::(); + total_len + } + + fn alignment(_layout: &Sha2VmRecordLayout) -> usize { + align_of::() + } +} + +impl TraceStep for Sha2VmStep { + type RecordLayout = Sha2VmRecordLayout; + type RecordMut<'a> = Sha2VmRecordMut<'a>; + + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", C::OPCODE) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + debug_assert_eq!(*opcode, C::OPCODE.global_opcode()); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + // Reading the length first to allocate a record of correct size + let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()); + + let num_blocks = get_sha2_num_blocks::(len); + let record = arena.alloc(MultiRowLayout { + metadata: Sha2VmMetadata { + num_blocks, + _phantom: PhantomData::, + }, + }); + + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp(); + record.inner.rd_ptr = a.as_canonical_u32(); + record.inner.rs1_ptr = b.as_canonical_u32(); + record.inner.rs2_ptr = c.as_canonical_u32(); + + record.inner.dst_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rd_ptr, + &mut record.inner.register_reads_aux[0].prev_timestamp, + )); + record.inner.src_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs1_ptr, + &mut record.inner.register_reads_aux[1].prev_timestamp, + )); + record.inner.len = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs2_ptr, + &mut record.inner.register_reads_aux[2].prev_timestamp, + )); + + // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used + debug_assert!( + record.inner.src_ptr as usize + num_blocks as usize * C::BLOCK_CELLS + <= (1 << self.pointer_max_bits) + ); + debug_assert!( + record.inner.dst_ptr as usize + C::WRITE_SIZE <= (1 << self.pointer_max_bits) + ); + // We don't support messages longer than 2^29 bytes + debug_assert!(record.inner.len < SHA_MAX_MESSAGE_LEN as u32); + + for block_idx in 0..num_blocks as usize { + // Reads happen on the first 4 rows of each block + for row in 0..C::NUM_READ_ROWS { + let read_idx = block_idx * C::NUM_READ_ROWS + row; + match C::VARIANT { + Sha2Variant::Sha256 => { + let row_input: [u8; Sha256Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + Sha2Variant::Sha512 => { + let row_input: [u8; Sha512Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + Sha2Variant::Sha384 => { + let row_input: [u8; Sha384Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + } + } + } + + let mut output = sha2_solve::(&record.input[..len as usize]); + match C::VARIANT { + Sha2Variant::Sha256 => { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr, + output.try_into().unwrap(), + &mut record.inner.writes_aux[0].prev_timestamp, + &mut record.inner.writes_aux[0].prev_data, + ); + } + Sha2Variant::Sha512 => { + debug_assert!(output.len() % Sha512Config::WRITE_SIZE == 0); + output + .chunks_exact(Sha512Config::WRITE_SIZE) + .enumerate() + .for_each(|(i, chunk)| { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (i * Sha512Config::WRITE_SIZE) as u32, + chunk.try_into().unwrap(), + &mut record.inner.writes_aux[i].prev_timestamp, + &mut record.inner.writes_aux[i].prev_data, + ); + }); + } + Sha2Variant::Sha384 => { + // output is a truncated 48-byte digest, so we will append 16 bytes of zeros + output.extend(iter::repeat(0).take(16)); + debug_assert!(output.len() % Sha384Config::WRITE_SIZE == 0); + output + .chunks_exact(Sha384Config::WRITE_SIZE) + .enumerate() + .for_each(|(i, chunk)| { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (i * Sha384Config::WRITE_SIZE) as u32, + chunk.try_into().unwrap(), + &mut record.inner.writes_aux[i].prev_timestamp, + &mut record.inner.writes_aux[i].prev_data, + ); + }); + } + } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for Sha2VmStep { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace_matrix: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + + let mut chunks = Vec::with_capacity(trace_matrix.height() / C::ROWS_PER_BLOCK); + let mut sizes = Vec::with_capacity(trace_matrix.height() / C::ROWS_PER_BLOCK); + let mut trace = &mut trace_matrix.values[..]; + let mut num_blocks_so_far = 0; + + // First pass over the trace to get the number of blocks for each instruction + // and divide the matrix into chunks of needed sizes + loop { + if num_blocks_so_far * C::ROWS_PER_BLOCK >= rows_used { + // Push all the padding rows as a single chunk and break + chunks.push(trace); + sizes.push((0, num_blocks_so_far)); + break; + } else { + let record: &Sha2VmRecordHeader = unsafe { get_record_from_slice(&mut trace, ()) }; + let num_blocks = get_sha2_num_blocks::(record.len) as usize; + let (chunk, rest) = + trace.split_at_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK * num_blocks); + chunks.push(chunk); + sizes.push((num_blocks, num_blocks_so_far)); + num_blocks_so_far += num_blocks; + trace = rest; + } + } + + // During the first pass we will fill out most of the matrix + // But there are some cells that can't be generated by the first pass so we will do a second + // pass over the matrix later + chunks.par_iter_mut().zip(sizes.par_iter()).for_each( + |(slice, (num_blocks, global_block_offset))| { + if global_block_offset * C::ROWS_PER_BLOCK >= rows_used { + // Fill in the invalid rows + slice.par_chunks_mut(C::VM_WIDTH).for_each(|row| { + // Need to get rid of the accidental garbage data that might overflow the + // F's prime field. Unfortunately, there is no good way around this + unsafe { + std::ptr::write_bytes( + row.as_mut_ptr() as *mut u8, + 0, + C::VM_WIDTH * size_of::(), + ); + } + let cols = Sha2VmRoundColsRefMut::::from::( + row[..C::VM_ROUND_WIDTH].borrow_mut(), + ); + self.inner.generate_default_row(cols.inner); + }); + return; + } + + let record: Sha2VmRecordMut = unsafe { + get_record_from_slice( + slice, + Sha2VmRecordLayout { + metadata: Sha2VmMetadata { + num_blocks: *num_blocks as u32, + _phantom: PhantomData::, + }, + }, + ) + }; + + let mut input: Vec = Vec::with_capacity(C::BLOCK_CELLS * num_blocks); + input.extend_from_slice(record.input); + let mut padded_input = input.clone(); + let len = record.inner.len as usize; + let padded_input_len = padded_input.len(); + padded_input[len] = 1 << (RV32_CELL_BITS - 1); + padded_input[len + 1..padded_input_len - 4].fill(0); + padded_input[padded_input_len - 4..] + .copy_from_slice(&((len as u32) << 3).to_be_bytes()); + + let mut prev_hashes = Vec::with_capacity(*num_blocks); + prev_hashes.push(C::get_h().to_vec()); + for i in 0..*num_blocks - 1 { + prev_hashes.push(Sha2StepHelper::::get_block_hash( + &prev_hashes[i], + padded_input[i * C::BLOCK_CELLS..(i + 1) * C::BLOCK_CELLS].into(), + )); + } + // Copy the read aux records and input to another place to safely fill in the trace + // matrix without overwriting the record + let mut read_aux_records = Vec::with_capacity(C::NUM_READ_ROWS * num_blocks); + read_aux_records.extend_from_slice(record.read_aux); + let vm_record = record.inner.clone(); + + slice + .par_chunks_exact_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK) + .enumerate() + .for_each(|(block_idx, block_slice)| { + // Need to get rid of the accidental garbage data that might overflow the + // F's prime field. Unfortunately, there is no good way around this + unsafe { + std::ptr::write_bytes( + block_slice.as_mut_ptr() as *mut u8, + 0, + C::ROWS_PER_BLOCK * C::VM_WIDTH * size_of::(), + ); + } + self.fill_block_trace::( + block_slice, + &vm_record, + &read_aux_records + [block_idx * C::NUM_READ_ROWS..(block_idx + 1) * C::NUM_READ_ROWS], + &input[block_idx * C::BLOCK_CELLS..(block_idx + 1) * C::BLOCK_CELLS], + &padded_input + [block_idx * C::BLOCK_CELLS..(block_idx + 1) * C::BLOCK_CELLS], + block_idx == *num_blocks - 1, + *global_block_offset + block_idx, + block_idx, + prev_hashes[block_idx].as_slice(), + mem_helper, + ); + }); + }, + ); + + // Do a second pass over the trace to fill in the missing values + // Note, we need to skip the very first row + trace_matrix.values[C::VM_WIDTH..] + .par_chunks_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK) + .take(rows_used / C::ROWS_PER_BLOCK) + .for_each(|chunk| { + self.inner + .generate_missing_cells(chunk, C::VM_WIDTH, C::VM_CONTROL_WIDTH); + }); + } +} + +impl Sha2VmStep { + #[allow(clippy::too_many_arguments)] + fn fill_block_trace( + &self, + block_slice: &mut [F], + record: &Sha2VmRecordHeader, + read_aux_records: &[MemoryReadAuxRecord], + input: &[u8], + padded_input: &[u8], + is_last_block: bool, + global_block_idx: usize, + local_block_idx: usize, + prev_hash: &[C::Word], + mem_helper: &MemoryAuxColsFactory, + ) { + debug_assert_eq!(input.len(), C::BLOCK_CELLS); + debug_assert_eq!(padded_input.len(), C::BLOCK_CELLS); + debug_assert_eq!(read_aux_records.len(), C::NUM_READ_ROWS); + debug_assert_eq!(prev_hash.len(), C::HASH_WORDS); + + let padded_input = (0..C::BLOCK_WORDS) + .map(|i| { + be_limbs_into_word::( + &padded_input[i * C::WORD_U8S..(i + 1) * C::WORD_U8S] + .iter() + .map(|x| *x as u32) + .collect::>(), + ) + }) + .collect::>(); + + let block_start_timestamp = + record.timestamp + (SHA_REGISTER_READS + C::NUM_READ_ROWS * local_block_idx) as u32; + + let read_cells = (C::BLOCK_CELLS * local_block_idx) as u32; + let block_start_read_ptr = record.src_ptr + read_cells; + + let message_left = if record.len <= read_cells { + 0 + } else { + (record.len - read_cells) as usize + }; + + // -1 means that padding occurred before the start of the block + // C::ROWS_PER_BLOCK + 1 means that no padding occurred on this block + let first_padding_row = if record.len < read_cells { + -1 + } else if message_left < C::BLOCK_CELLS { + (message_left / C::READ_SIZE) as i32 + } else { + (C::ROWS_PER_BLOCK + 1) as i32 + }; + + // Fill in the VM columns first because the inner `carry_or_buffer` needs to be filled in + block_slice + .par_chunks_exact_mut(C::VM_WIDTH) + .enumerate() + .for_each(|(row_idx, row_slice)| { + // Handle round rows and digest row separately + if row_idx == C::ROWS_PER_BLOCK - 1 { + // This is a digest row + let mut digest_cols = Sha2VmDigestColsRefMut::::from::( + row_slice[..C::VM_DIGEST_WIDTH].borrow_mut(), + ); + digest_cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); + digest_cols.from_state.pc = F::from_canonical_u32(record.from_pc); + *digest_cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + *digest_cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + *digest_cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); + digest_cols + .dst_ptr + .iter_mut() + .zip(record.dst_ptr.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + digest_cols + .src_ptr + .iter_mut() + .zip(record.src_ptr.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + digest_cols + .len_data + .iter_mut() + .zip(record.len.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + if is_last_block { + digest_cols + .register_reads_aux + .iter_mut() + .zip(record.register_reads_aux.iter()) + .enumerate() + .for_each(|(idx, (cols_read, record_read))| { + mem_helper.fill( + record_read.prev_timestamp, + record.timestamp + idx as u32, + cols_read.as_mut(), + ); + }); + for i in 0..C::NUM_WRITES { + digest_cols + .writes_aux_prev_data + .row_mut(i) + .iter_mut() + .zip(record.writes_aux[i].prev_data.map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + + // In the last block we do `C::NUM_READ_ROWS` reads and then write the + // result thus the timestamp of the write is + // `block_start_timestamp + C::NUM_READ_ROWS` + mem_helper.fill( + record.writes_aux[i].prev_timestamp, + block_start_timestamp + C::NUM_READ_ROWS as u32 + i as u32, + &mut digest_cols.writes_aux_base[i], + ); + } + // Need to range check the destination and source pointers + let msl_rshift: u32 = + ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; + let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS + - self.pointer_max_bits) + as u32; + self.bitwise_lookup_chip.request_range( + (record.dst_ptr >> msl_rshift) << msl_lshift, + (record.src_ptr >> msl_rshift) << msl_lshift, + ); + } else { + // Filling in zeros to make sure the accidental garbage data doesn't + // overflow the prime + digest_cols.register_reads_aux.iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); + for i in 0..C::NUM_WRITES { + digest_cols.writes_aux_prev_data.row_mut(i).fill(F::ZERO); + mem_helper.fill_zero(&mut digest_cols.writes_aux_base[i]); + } + } + *digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); + *digest_cols.inner.flags.is_digest_row = F::from_bool(true); + } else { + // This is a round row + let mut round_cols = Sha2VmRoundColsRefMut::::from::( + row_slice[..C::VM_ROUND_WIDTH].borrow_mut(), + ); + // Take care of the first 4 round rows (aka read rows) + if row_idx < C::NUM_READ_ROWS { + round_cols + .inner + .message_schedule + .carry_or_buffer + .iter_mut() + .zip(input[row_idx * C::READ_SIZE..(row_idx + 1) * C::READ_SIZE].iter()) + .for_each(|(cell, data)| { + *cell = F::from_canonical_u8(*data); + }); + mem_helper.fill( + read_aux_records[row_idx].prev_timestamp, + block_start_timestamp + row_idx as u32, + round_cols.read_aux.as_mut(), + ); + } else { + mem_helper.fill_zero(round_cols.read_aux.as_mut()); + } + } + // Fill in the control cols, doesn't matter if it is a round or digest row + let mut control_cols = Sha2VmControlColsRefMut::::from::( + row_slice[..C::VM_CONTROL_WIDTH].borrow_mut(), + ); + *control_cols.len = F::from_canonical_u32(record.len); + // Only the first `SHA256_NUM_READ_ROWS` rows increment the timestamp and read ptr + *control_cols.cur_timestamp = F::from_canonical_u32( + block_start_timestamp + min(row_idx, C::NUM_READ_ROWS) as u32, + ); + *control_cols.read_ptr = F::from_canonical_u32( + block_start_read_ptr + (C::READ_SIZE * min(row_idx, C::NUM_READ_ROWS)) as u32, + ); + + // Fill in the padding flags + if row_idx < C::NUM_READ_ROWS { + #[allow(clippy::comparison_chain)] + if (row_idx as i32) < first_padding_row { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotPadding as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } else if row_idx as i32 == first_padding_row { + let len = message_left - row_idx * C::READ_SIZE; + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + if row_idx == 3 && is_last_block { + PaddingFlags::FirstPadding0_LastRow + } else { + PaddingFlags::FirstPadding0 + } as usize + + len, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } else { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + if row_idx == 3 && is_last_block { + PaddingFlags::EntirePaddingLastRow + } else { + PaddingFlags::EntirePadding + } as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } else { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotConsidered as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + if is_last_block && row_idx == C::ROWS_PER_BLOCK - 1 { + // If last digest row, then we set padding_occurred = 0 + *control_cols.padding_occurred = F::ZERO; + } else { + *control_cols.padding_occurred = + F::from_bool((row_idx as i32) >= first_padding_row); + } + }); + + // Fill in the inner trace when the `carry_or_buffer` is filled in + self.inner.generate_block_trace::( + block_slice, + C::VM_WIDTH, + C::VM_CONTROL_WIDTH, + &padded_input, + self.bitwise_lookup_chip.clone(), + prev_hash, + is_last_block, + global_block_idx as u32 + 1, // global block index is 1-indexed + local_block_idx as u32, + ); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chip/utils.rs b/extensions/sha2/circuit/src/sha2_chip/utils.rs new file mode 100644 index 0000000000..d3c78345ad --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chip/utils.rs @@ -0,0 +1,8 @@ +use crate::Sha2ChipConfig; + +/// Returns the number of blocks required to hash a message of length `len` +pub fn get_sha2_num_blocks(len: u32) -> u32 { + // need to pad with one 1 bit, 64 bits for the message length and then pad until the length + // is divisible by [C::BLOCK_BITS] + ((len << 3) as usize + 1 + C::MESSAGE_LENGTH_BITS).div_ceil(C::BLOCK_BITS) as u32 +} diff --git a/extensions/sha256/guest/Cargo.toml b/extensions/sha2/guest/Cargo.toml similarity index 69% rename from extensions/sha256/guest/Cargo.toml rename to extensions/sha2/guest/Cargo.toml index e9d28292b8..1c6503002e 100644 --- a/extensions/sha256/guest/Cargo.toml +++ b/extensions/sha2/guest/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "openvm-sha256-guest" +name = "openvm-sha2-guest" version.workspace = true authors.workspace = true edition.workspace = true -description = "Guest extension for Sha256" +description = "Guest extension for SHA-2" [dependencies] openvm-platform = { workspace = true } diff --git a/extensions/sha2/guest/src/lib.rs b/extensions/sha2/guest/src/lib.rs new file mode 100644 index 0000000000..567f60d4da --- /dev/null +++ b/extensions/sha2/guest/src/lib.rs @@ -0,0 +1,193 @@ +#![no_std] + +#[cfg(target_os = "zkvm")] +use openvm_platform::alloc::AlignedBuf; + +/// This is custom-0 defined in RISC-V spec document +pub const OPCODE: u8 = 0x0b; +pub const SHA2_FUNCT3: u8 = 0b100; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(u8)] +pub enum Sha2BaseFunct7 { + Sha256 = 0x1, + Sha512 = 0x2, + Sha384 = 0x3, +} + +/// zkvm native implementation of sha256 +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// +/// [`sha2-256`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +#[cfg(target_os = "zkvm")] +#[inline(always)] +#[no_mangle] +pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + // The minimum alignment required for the input and output buffers + const MIN_ALIGN: usize = 4; + // The preferred alignment for the input buffer, since the input is read in chunks of 16 bytes + const INPUT_ALIGN: usize = 16; + // The preferred alignment for the output buffer, since the output is written in chunks of 32 + // bytes + const OUTPUT_ALIGN: usize = 32; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); + __native_sha256(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_sha256(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); + __native_sha256(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_sha256(bytes, len, output); + } + }; + } +} + +/// zkvm native implementation of sha512 +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 64-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 64-bytes long. +/// +/// [`sha2-512`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +#[cfg(target_os = "zkvm")] +#[inline(always)] +#[no_mangle] +pub extern "C" fn zkvm_sha512_impl(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + // The minimum alignment required for the input and output buffers + const MIN_ALIGN: usize = 4; + // The preferred alignment for the input buffer, since the input is read in chunks of 32 bytes + const INPUT_ALIGN: usize = 32; + // The preferred alignment for the output buffer, since the output is written in chunks of 32 + // bytes + const OUTPUT_ALIGN: usize = 32; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(64, OUTPUT_ALIGN); + __native_sha512(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 64); + } else { + __native_sha512(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(64, OUTPUT_ALIGN); + __native_sha512(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 64); + } else { + __native_sha512(bytes, len, output); + } + }; + } +} + +/// zkvm native implementation of sha384 +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 48-byte hash followed by 16-bytes of zeros. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 64-bytes long. +/// +/// [`sha2-512`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +#[cfg(target_os = "zkvm")] +#[inline(always)] +#[no_mangle] +pub extern "C" fn zkvm_sha384_impl(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + // The minimum alignment required for the input and output buffers + const MIN_ALIGN: usize = 4; + // The preferred alignment for the input buffer, since the input is read in chunks of 32 bytes + const INPUT_ALIGN: usize = 32; + // The preferred alignment for the output buffer, since the output is written in chunks of 32 + // bytes + const OUTPUT_ALIGN: usize = 32; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(64, OUTPUT_ALIGN); + __native_sha384(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 64); + } else { + __native_sha384(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(64, OUTPUT_ALIGN); + __native_sha384(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 64); + } else { + __native_sha384(bytes, len, output); + } + }; + } +} + +/// sha256 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha256(bytes: *const u8, len: usize, output: *mut u8) { + openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA2_FUNCT3, funct7 = Sha2BaseFunct7::Sha256 as u8, rd = In output, rs1 = In bytes, rs2 = In len); +} + +/// sha512 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 64-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 64-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha512(bytes: *const u8, len: usize, output: *mut u8) { + openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA2_FUNCT3, funct7 = Sha2BaseFunct7::Sha512 as u8, rd = In output, rs1 = In bytes, rs2 = In len); +} + +/// sha384 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 48-byte hash followed by 16-bytes of zeros. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 64-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha384(bytes: *const u8, len: usize, output: *mut u8) { + openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA2_FUNCT3, funct7 = Sha2BaseFunct7::Sha384 as u8, rd = In output, rs1 = In bytes, rs2 = In len); +} diff --git a/extensions/sha256/transpiler/Cargo.toml b/extensions/sha2/transpiler/Cargo.toml similarity index 73% rename from extensions/sha256/transpiler/Cargo.toml rename to extensions/sha2/transpiler/Cargo.toml index 933859f3a8..9eff76a3db 100644 --- a/extensions/sha256/transpiler/Cargo.toml +++ b/extensions/sha2/transpiler/Cargo.toml @@ -1,15 +1,15 @@ [package] -name = "openvm-sha256-transpiler" +name = "openvm-sha2-transpiler" version.workspace = true authors.workspace = true edition.workspace = true -description = "Transpiler extension for sha256" +description = "Transpiler extension for SHA-2" [dependencies] openvm-stark-backend = { workspace = true } openvm-instructions = { workspace = true } openvm-transpiler = { workspace = true } rrs-lib = { workspace = true } -openvm-sha256-guest = { workspace = true } +openvm-sha2-guest = { workspace = true } openvm-instructions-derive = { workspace = true } strum = { workspace = true } diff --git a/extensions/sha2/transpiler/src/lib.rs b/extensions/sha2/transpiler/src/lib.rs new file mode 100644 index 0000000000..89249ee026 --- /dev/null +++ b/extensions/sha2/transpiler/src/lib.rs @@ -0,0 +1,65 @@ +use openvm_instructions::{riscv::RV32_MEMORY_AS, LocalOpcode}; +use openvm_instructions_derive::LocalOpcode; +use openvm_sha2_guest::{Sha2BaseFunct7, OPCODE, SHA2_FUNCT3}; +use openvm_stark_backend::p3_field::PrimeField32; +use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput}; +use rrs_lib::instruction_formats::RType; +use strum::{EnumCount, EnumIter, FromRepr}; + +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x320] +#[repr(usize)] +pub enum Rv32Sha2Opcode { + SHA256, + SHA512, + SHA384, +} + +#[derive(Default)] +pub struct Sha2TranspilerExtension; + +impl TranspilerExtension for Sha2TranspilerExtension { + fn process_custom(&self, instruction_stream: &[u32]) -> Option> { + if instruction_stream.is_empty() { + return None; + } + let instruction_u32 = instruction_stream[0]; + let opcode = (instruction_u32 & 0x7f) as u8; + let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; + + if (opcode, funct3) != (OPCODE, SHA2_FUNCT3) { + return None; + } + let dec_insn = RType::new(instruction_u32); + + if dec_insn.funct7 == Sha2BaseFunct7::Sha256 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA256.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else if dec_insn.funct7 == Sha2BaseFunct7::Sha512 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA512.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else if dec_insn.funct7 == Sha2BaseFunct7::Sha384 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA384.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else { + None + } + } +} diff --git a/extensions/sha256/circuit/src/extension.rs b/extensions/sha256/circuit/src/extension.rs deleted file mode 100644 index 783bc54f63..0000000000 --- a/extensions/sha256/circuit/src/extension.rs +++ /dev/null @@ -1,105 +0,0 @@ -use derive_more::derive::From; -use openvm_circuit::{ - arch::{ - InitFileGenerator, SystemConfig, VmExtension, VmInventory, VmInventoryBuilder, - VmInventoryError, - }, - system::phantom::PhantomChip, -}; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_instructions::*; -use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, - Rv32MExecutor, Rv32MPeriphery, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::p3_field::PrimeField32; -use serde::{Deserialize, Serialize}; -use strum::IntoEnumIterator; - -use crate::*; - -#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Sha256Rv32Config { - #[system] - pub system: SystemConfig, - #[extension] - pub rv32i: Rv32I, - #[extension] - pub rv32m: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub sha256: Sha256, -} - -impl Default for Sha256Rv32Config { - fn default() -> Self { - Self { - system: SystemConfig::default().with_continuations(), - rv32i: Rv32I, - rv32m: Rv32M::default(), - io: Rv32Io, - sha256: Sha256, - } - } -} - -// Default implementation uses no init file -impl InitFileGenerator for Sha256Rv32Config {} - -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -pub struct Sha256; - -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] -pub enum Sha256Executor { - Sha256(Sha256VmChip), -} - -#[derive(From, ChipUsageGetter, Chip, AnyEnum)] -pub enum Sha256Periphery { - BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>), - Phantom(PhantomChip), -} - -impl VmExtension for Sha256 { - type Executor = Sha256Executor; - type Periphery = Sha256Periphery; - - fn build( - &self, - builder: &mut VmInventoryBuilder, - ) -> Result, VmInventoryError> { - let mut inventory = VmInventory::new(); - let bitwise_lu_chip = if let Some(&chip) = builder - .find_chip::>() - .first() - { - chip.clone() - } else { - let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx()); - let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus); - inventory.add_periphery_chip(chip.clone()); - chip - }; - - let sha256_chip = Sha256VmChip::new( - builder.system_port(), - builder.system_config().memory_config.pointer_max_bits, - bitwise_lu_chip, - builder.new_bus_idx(), - Rv32Sha256Opcode::CLASS_OFFSET, - builder.system_base().offline_memory(), - ); - inventory.add_executor( - sha256_chip, - Rv32Sha256Opcode::iter().map(|x| x.global_opcode()), - )?; - - Ok(inventory) - } -} diff --git a/extensions/sha256/circuit/src/lib.rs b/extensions/sha256/circuit/src/lib.rs deleted file mode 100644 index fe0844f902..0000000000 --- a/extensions/sha256/circuit/src/lib.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod sha256_chip; -pub use sha256_chip::*; - -mod extension; -pub use extension::*; diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs deleted file mode 100644 index f4f1df34eb..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ /dev/null @@ -1,599 +0,0 @@ -use std::{array, borrow::Borrow, cmp::min}; - -use openvm_circuit::{ - arch::ExecutionBridge, - system::memory::{offline_checker::MemoryBridge, MemoryAddress}, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir, -}; -use openvm_instructions::{ - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, - LocalOpcode, -}; -use openvm_sha256_air::{ - compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - interaction::InteractionBuilder, - p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra}, - p3_matrix::Matrix, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, -}; - -use super::{ - Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH, - SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE, -}; - -/// Sha256VmAir does all constraints related to message padding and -/// the Sha256Air subair constrains the actual hash -#[derive(Clone, Debug, derive_new::new)] -pub struct Sha256VmAir { - pub execution_bridge: ExecutionBridge, - pub memory_bridge: MemoryBridge, - /// Bus to send byte checks to - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - /// Maximum number of bits allowed for an address pointer - /// Must be at least 24 - pub ptr_max_bits: usize, - pub(super) sha256_subair: Sha256Air, - pub(super) padding_encoder: Encoder, -} - -impl BaseAirWithPublicValues for Sha256VmAir {} -impl PartitionedBaseAir for Sha256VmAir {} -impl BaseAir for Sha256VmAir { - fn width(&self) -> usize { - SHA256VM_WIDTH - } -} - -impl Air for Sha256VmAir { - fn eval(&self, builder: &mut AB) { - self.eval_padding(builder); - self.eval_transitions(builder); - self.eval_reads(builder); - self.eval_last_row(builder); - - self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH); - } -} - -#[allow(dead_code, non_camel_case_types)] -pub(super) enum PaddingFlags { - /// Not considered for padding - W's are not constrained - NotConsidered, - /// Not padding - W's should be equal to the message - NotPadding, - /// FIRST_PADDING_i: it is the first row with padding and there are i cells of non-padding - FirstPadding0, - FirstPadding1, - FirstPadding2, - FirstPadding3, - FirstPadding4, - FirstPadding5, - FirstPadding6, - FirstPadding7, - FirstPadding8, - FirstPadding9, - FirstPadding10, - FirstPadding11, - FirstPadding12, - FirstPadding13, - FirstPadding14, - FirstPadding15, - /// FIRST_PADDING_i_LastRow: it is the first row with padding and there are i cells of - /// non-padding AND it is the last reading row of the message - /// NOTE: if the Last row has padding it has to be at least 9 cells since the last 8 cells are - /// padded with the message length - FirstPadding0_LastRow, - FirstPadding1_LastRow, - FirstPadding2_LastRow, - FirstPadding3_LastRow, - FirstPadding4_LastRow, - FirstPadding5_LastRow, - FirstPadding6_LastRow, - FirstPadding7_LastRow, - /// The entire row is padding AND it is not the first row with padding - /// AND it is the 4th row of the last block of the message - EntirePaddingLastRow, - /// The entire row is padding AND it is not the first row with padding - EntirePadding, -} - -impl PaddingFlags { - /// The number of padding flags (including NotConsidered) - pub const COUNT: usize = EntirePadding as usize + 1; -} - -use PaddingFlags::*; -impl Sha256VmAir { - /// Implement all necessary constraints for the padding - fn eval_padding(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - let next_cols: &Sha256VmRoundCols = next[..SHA256VM_ROUND_WIDTH].borrow(); - - // Constrain the sanity of the padding flags - self.padding_encoder - .eval(builder, &local_cols.control.pad_flags); - - builder.assert_one(self.padding_encoder.contains_flag_range::( - &local_cols.control.pad_flags, - NotConsidered as usize..=EntirePadding as usize, - )); - - Self::eval_padding_transitions(self, builder, local_cols, next_cols); - Self::eval_padding_row(self, builder, local_cols); - } - - fn eval_padding_transitions( - &self, - builder: &mut AB, - local: &Sha256VmRoundCols, - next: &Sha256VmRoundCols, - ) { - let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block; - - // Constrain that `padding_occured` is 1 on a suffix of rows in each message, excluding the - // last digest row, and 0 everywhere else. Furthermore, the suffix starts in the - // first 4 rows of some block. - - builder.assert_bool(local.control.padding_occurred); - // Last round row in the last block has padding_occurred = 1 - // This is the end of the suffix - builder - .when(next_is_last_row.clone()) - .assert_one(local.control.padding_occurred); - - // Digest row in the last block has padding_occurred = 0 - builder - .when(next_is_last_row.clone()) - .assert_zero(next.control.padding_occurred); - - // If padding_occurred = 1 in the current row, then padding_occurred = 1 in the next row, - // unless next is the last digest row - builder - .when(local.control.padding_occurred - next_is_last_row.clone()) - .assert_one(next.control.padding_occurred); - - // If next row is not first 4 rows of a block, then next.padding_occurred = - // local.padding_occurred. So padding_occurred only changes in the first 4 rows of a - // block. - builder - .when_transition() - .when(not(next.inner.flags.is_first_4_rows) - next_is_last_row) - .assert_eq( - next.control.padding_occurred, - local.control.padding_occurred, - ); - - // Constrain the that the start of the padding is correct - let next_is_first_padding_row = - next.control.padding_occurred - local.control.padding_occurred; - // Row index if its between 0..4, else 0 - let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::( - &next.inner.flags.row_idx, - &(0..4).map(|x| (x, x)).collect::>(), - ); - // How many non-padding cells there are in the next row. - // Will be 0 on non-padding rows. - let next_padding_offset = self.padding_encoder.flag_with_val::( - &next.control.pad_flags, - &(0..16) - .map(|i| (FirstPadding0 as usize + i, i)) - .collect::>(), - ) + self.padding_encoder.flag_with_val::( - &next.control.pad_flags, - &(0..8) - .map(|i| (FirstPadding0_LastRow as usize + i, i)) - .collect::>(), - ); - - // Will be 0 on last digest row since: - // - padding_occurred = 0 is constrained above - // - next_row_idx = 0 since row_idx is not in 0..4 - // - and next_padding_offset = 0 since `pad_flags = NotConsidered` - let expected_len = next.inner.flags.local_block_idx - * next.control.padding_occurred - * AB::Expr::from_canonical_usize(SHA256_BLOCK_U8S) - + next_row_idx * AB::Expr::from_canonical_usize(SHA256_READ_SIZE) - + next_padding_offset; - - // Note: `next_is_first_padding_row` is either -1,0,1 - // If 1, then this constrains the length of message - // If -1, then `next` must be the last digest row and so this constraint will be 0 == 0 - builder.when(next_is_first_padding_row).assert_eq( - expected_len, - next.control.len * next.control.padding_occurred, - ); - - // Constrain the padding flags are of correct type (eg is not padding or first padding) - let is_next_first_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - FirstPadding0 as usize..=FirstPadding7_LastRow as usize, - ); - - let is_next_last_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, - ); - - let is_next_entire_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - EntirePaddingLastRow as usize..=EntirePadding as usize, - ); - - let is_next_not_considered = self - .padding_encoder - .contains_flag::(&next.control.pad_flags, &[NotConsidered as usize]); - - let is_next_not_padding = self - .padding_encoder - .contains_flag::(&next.control.pad_flags, &[NotPadding as usize]); - - let is_next_4th_row = self - .sha256_subair - .row_idx_encoder - .contains_flag::(&next.inner.flags.row_idx, &[3]); - - // `pad_flags` is `NotConsidered` on all rows except the first 4 rows of a block - builder.assert_eq( - not(next.inner.flags.is_first_4_rows), - is_next_not_considered, - ); - - // `pad_flags` is `EntirePadding` if the previous row is padding - builder.when(next.inner.flags.is_first_4_rows).assert_eq( - local.control.padding_occurred * next.control.padding_occurred, - is_next_entire_padding, - ); - - // `pad_flags` is `FirstPadding*` if current row is padding and the previous row is not - // padding - builder.when(next.inner.flags.is_first_4_rows).assert_eq( - not(local.control.padding_occurred) * next.control.padding_occurred, - is_next_first_padding, - ); - - // `pad_flags` is `NotPadding` if current row is not padding - builder - .when(next.inner.flags.is_first_4_rows) - .assert_eq(not(next.control.padding_occurred), is_next_not_padding); - - // `pad_flags` is `*LastRow` on the row that contains the last four words of the message - builder - .when(next.inner.flags.is_last_block) - .assert_eq(is_next_4th_row, is_next_last_padding); - } - - fn eval_padding_row( - &self, - builder: &mut AB, - local: &Sha256VmRoundCols, - ) { - let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| { - local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)] - [i % (SHA256_WORD_U8S)] - }); - - let get_ith_byte = |i: usize| { - let word_idx = i / SHA256_ROUNDS_PER_ROW; - let word = local.inner.message_schedule.w[word_idx].map(|x| x.into()); - // Need to reverse the byte order to match the endianness of the memory - let byte_idx = 4 - i % 4 - 1; - compose::(&word[byte_idx * 8..(byte_idx + 1) * 8], 1) - }; - - let is_not_padding = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[NotPadding as usize]); - - // Check the `w`s on case by case basis - for (i, message_byte) in message.iter().enumerate() { - let w = get_ith_byte(i); - let should_be_message = is_not_padding.clone() - + if i < 15 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0 as usize + i + 1..=FirstPadding15 as usize, - ) - } else { - AB::Expr::ZERO - } - + if i < 7 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize, - ) - } else { - AB::Expr::ZERO - }; - builder - .when(should_be_message) - .assert_eq(w.clone(), *message_byte); - - let should_be_zero = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[EntirePadding as usize]) - + if i < 12 { - self.padding_encoder.contains_flag::( - &local.control.pad_flags, - &[EntirePaddingLastRow as usize], - ) + if i > 0 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize - ..=min( - FirstPadding0_LastRow as usize + i - 1, - FirstPadding7_LastRow as usize, - ), - ) - } else { - AB::Expr::ZERO - } - } else { - AB::Expr::ZERO - } - + if i > 0 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0 as usize..=FirstPadding0 as usize + i - 1, - ) - } else { - AB::Expr::ZERO - }; - builder.when(should_be_zero).assert_zero(w.clone()); - - // Assumes bit-length of message is a multiple of 8 (message is bytes) - // This is true because the message is given as &[u8] - let should_be_128 = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[FirstPadding0 as usize + i]) - + if i < 8 { - self.padding_encoder.contains_flag::( - &local.control.pad_flags, - &[FirstPadding0_LastRow as usize + i], - ) - } else { - AB::Expr::ZERO - }; - - builder - .when(should_be_128) - .assert_eq(AB::Expr::from_canonical_u32(1 << 7), w); - - // should be len is handled outside of the loop - } - let appended_len = compose::( - &[ - get_ith_byte(15), - get_ith_byte(14), - get_ith_byte(13), - get_ith_byte(12), - ], - RV32_CELL_BITS, - ); - - let actual_len = local.control.len; - - let is_last_padding_row = self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, - ); - - builder.when(is_last_padding_row.clone()).assert_eq( - appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(), // bit to byte conversion - actual_len, - ); - - // We constrain that the appended length is in bytes - builder.when(is_last_padding_row.clone()).assert_zero( - local.inner.message_schedule.w[3][0] - + local.inner.message_schedule.w[3][1] - + local.inner.message_schedule.w[3][2], - ); - - // We can't support messages longer than 2^30 bytes because the length has to fit in a field - // element. So, constrain that the first 4 bytes of the length are 0. - // Thus, the bit-length is < 2^32 so the message is < 2^29 bytes. - for i in 8..12 { - builder - .when(is_last_padding_row.clone()) - .assert_zero(get_ith_byte(i)); - } - } - /// Implement constraints on `len`, `read_ptr` and `cur_timestamp` - fn eval_transitions(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - let next_cols: &Sha256VmRoundCols = next[..SHA256VM_ROUND_WIDTH].borrow(); - - let is_last_row = - local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row; - - // Len should be the same for the entire message - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq(next_cols.control.len, local_cols.control.len); - - // Read ptr should increment by [SHA256_READ_SIZE] for the first 4 rows and stay the same - // otherwise - let read_ptr_delta = local_cols.inner.flags.is_first_4_rows - * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq( - next_cols.control.read_ptr, - local_cols.control.read_ptr + read_ptr_delta, - ); - - // Timestamp should increment by 1 for the first 4 rows and stay the same otherwise - let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE; - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq( - next_cols.control.cur_timestamp, - local_cols.control.cur_timestamp + timestamp_delta, - ); - } - - /// Implement the reads for the first 4 rows of a block - fn eval_reads(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - - let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| { - local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)] - [i % (SHA256_WORD_U16S * 2)] - }); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - local_cols.control.read_ptr, - ), - message, - local_cols.control.cur_timestamp, - &local_cols.read_aux, - ) - .eval(builder, local_cols.inner.flags.is_first_4_rows); - } - /// Implement the constraints for the last row of a message - fn eval_last_row(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_cols: &Sha256VmDigestCols = local[..SHA256VM_DIGEST_WIDTH].borrow(); - - let timestamp: AB::Var = local_cols.from_state.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) - }; - - let is_last_row = - local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row; - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rd_ptr, - ), - local_cols.dst_ptr, - timestamp_pp(), - &local_cols.register_reads_aux[0], - ) - .eval(builder, is_last_row.clone()); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rs1_ptr, - ), - local_cols.src_ptr, - timestamp_pp(), - &local_cols.register_reads_aux[1], - ) - .eval(builder, is_last_row.clone()); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rs2_ptr, - ), - local_cols.len_data, - timestamp_pp(), - &local_cols.register_reads_aux[2], - ) - .eval(builder, is_last_row.clone()); - - // range check that the memory pointers don't overflow - // Note: no need to range check the length since we read from memory step by step and - // the memory bus will catch any memory accesses beyond ptr_max_bits - let shift = AB::Expr::from_canonical_usize( - 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), - ); - // This only works if self.ptr_max_bits >= 24 which is typically the case - self.bitwise_lookup_bus - .send_range( - // It is fine to shift like this since we already know that dst_ptr and src_ptr - // have [RV32_CELL_BITS] bits - local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), - local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), - ) - .eval(builder, is_last_row.clone()); - - // the number of reads that happened to read the entire message: we do 4 reads per block - let time_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE) - * AB::Expr::from_canonical_usize(4); - // Every time we read the message we increment the read pointer by SHA256_READ_SIZE - let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); - - let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| { - // The limbs are written in big endian order to the memory so need to be reversed - local_cols.inner.final_hash[i / SHA256_WORD_U8S] - [SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1] - }); - - let dst_ptr_val = - compose::(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS); - - // Note: revisit in the future to do 2 block writes of 16 cells instead of 1 block write of - // 32 cells This could be beneficial as the output is often an input for - // another hash - self.memory_bridge - .write( - MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), dst_ptr_val), - result, - timestamp_pp() + time_delta.clone(), - &local_cols.writes_aux, - ) - .eval(builder, is_last_row.clone()); - - self.execution_bridge - .execute_and_increment_pc( - AB::Expr::from_canonical_usize(Rv32Sha256Opcode::SHA256.global_opcode().as_usize()), - [ - local_cols.rd_ptr.into(), - local_cols.rs1_ptr.into(), - local_cols.rs2_ptr.into(), - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - ], - local_cols.from_state, - AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(), - ) - .eval(builder, is_last_row.clone()); - - // Assert that we read the correct length of the message - let len_val = compose::(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS); - builder - .when(is_last_row.clone()) - .assert_eq(local_cols.control.len, len_val); - // Assert that we started reading from the correct pointer initially - let src_val = compose::(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS); - builder - .when(is_last_row.clone()) - .assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta); - // Assert that we started reading from the correct timestamp - builder.when(is_last_row.clone()).assert_eq( - local_cols.control.cur_timestamp, - local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta, - ); - } -} diff --git a/extensions/sha256/circuit/src/sha256_chip/columns.rs b/extensions/sha256/circuit/src/sha256_chip/columns.rs deleted file mode 100644 index 38c13a0f73..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/columns.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! WARNING: the order of fields in the structs is important, do not change it - -use openvm_circuit::{ - arch::ExecutionState, - system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; -use openvm_circuit_primitives::AlignedBorrow; -use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; -use openvm_sha256_air::{Sha256DigestCols, Sha256RoundCols}; - -use super::{SHA256_REGISTER_READS, SHA256_WRITE_SIZE}; - -/// the first 16 rows of every SHA256 block will be of type Sha256VmRoundCols and the last row will -/// be of type Sha256VmDigestCols -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmRoundCols { - pub control: Sha256VmControlCols, - pub inner: Sha256RoundCols, - pub read_aux: MemoryReadAuxCols, -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmDigestCols { - pub control: Sha256VmControlCols, - pub inner: Sha256DigestCols, - - pub from_state: ExecutionState, - /// It is counter intuitive, but we will constrain the register reads on the very last row of - /// every message - pub rd_ptr: T, - pub rs1_ptr: T, - pub rs2_ptr: T, - pub dst_ptr: [T; RV32_REGISTER_NUM_LIMBS], - pub src_ptr: [T; RV32_REGISTER_NUM_LIMBS], - pub len_data: [T; RV32_REGISTER_NUM_LIMBS], - pub register_reads_aux: [MemoryReadAuxCols; SHA256_REGISTER_READS], - pub writes_aux: MemoryWriteAuxCols, -} - -/// These are the columns that are used on both round and digest rows -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmControlCols { - /// Note: We will use the buffer in `inner.message_schedule` as the message data - /// This is the length of the entire message in bytes - pub len: T, - /// Need to keep timestamp and read_ptr since block reads don't have the necessary information - pub cur_timestamp: T, - pub read_ptr: T, - /// Padding flags which will be used to encode the the number of non-padding cells in the - /// current row - pub pad_flags: [T; 6], - /// A boolean flag that indicates whether a padding already occurred - pub padding_occurred: T, -} - -/// Width of the Sha256VmControlCols -pub const SHA256VM_CONTROL_WIDTH: usize = Sha256VmControlCols::::width(); -/// Width of the Sha256VmRoundCols -pub const SHA256VM_ROUND_WIDTH: usize = Sha256VmRoundCols::::width(); -/// Width of the Sha256VmDigestCols -pub const SHA256VM_DIGEST_WIDTH: usize = Sha256VmDigestCols::::width(); -/// Width of the Sha256Cols -pub const SHA256VM_WIDTH: usize = if SHA256VM_ROUND_WIDTH > SHA256VM_DIGEST_WIDTH { - SHA256VM_ROUND_WIDTH -} else { - SHA256VM_DIGEST_WIDTH -}; diff --git a/extensions/sha256/circuit/src/sha256_chip/mod.rs b/extensions/sha256/circuit/src/sha256_chip/mod.rs deleted file mode 100644 index 4c40eca5d8..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/mod.rs +++ /dev/null @@ -1,206 +0,0 @@ -//! Sha256 hasher. Handles full sha256 hashing with padding. -//! variable length inputs read from VM memory. -use std::{ - array, - cmp::{max, min}, - sync::{Arc, Mutex}, -}; - -use openvm_circuit::arch::{ - ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, SystemPort, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, -}; -use openvm_instructions::{ - instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, - LocalOpcode, -}; -use openvm_rv32im_circuit::adapters::read_rv32_register; -use openvm_sha256_air::{Sha256Air, SHA256_BLOCK_BITS}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{interaction::BusIndex, p3_field::PrimeField32}; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; - -mod air; -mod columns; -mod trace; - -pub use air::*; -pub use columns::*; -use openvm_circuit::system::memory::{MemoryController, OfflineMemory, RecordId}; - -#[cfg(test)] -mod tests; - -// ==== Constants for register/memory adapter ==== -/// Register reads to get dst, src, len -const SHA256_REGISTER_READS: usize = 3; -/// Number of cells to read in a single memory access -const SHA256_READ_SIZE: usize = 16; -/// Number of cells to write in a single memory access -const SHA256_WRITE_SIZE: usize = 32; -/// Number of rv32 cells read in a SHA256 block -pub const SHA256_BLOCK_CELLS: usize = SHA256_BLOCK_BITS / RV32_CELL_BITS; -/// Number of rows we will do a read on for each SHA256 block -pub const SHA256_NUM_READ_ROWS: usize = SHA256_BLOCK_CELLS / SHA256_READ_SIZE; -pub struct Sha256VmChip { - pub air: Sha256VmAir, - /// IO and memory data necessary for each opcode call - pub records: Vec>, - pub offline_memory: Arc>>, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - - offset: usize, -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Sha256Record { - pub from_state: ExecutionState, - pub dst_read: RecordId, - pub src_read: RecordId, - pub len_read: RecordId, - pub input_records: Vec<[RecordId; SHA256_NUM_READ_ROWS]>, - pub input_message: Vec<[[u8; SHA256_READ_SIZE]; SHA256_NUM_READ_ROWS]>, - pub digest_write: RecordId, -} - -impl Sha256VmChip { - pub fn new( - SystemPort { - execution_bus, - program_bus, - memory_bridge, - }: SystemPort, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - self_bus_idx: BusIndex, - offset: usize, - offline_memory: Arc>>, - ) -> Self { - Self { - air: Sha256VmAir::new( - ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_chip.bus(), - address_bits, - Sha256Air::new(bitwise_lookup_chip.bus(), self_bus_idx), - Encoder::new(PaddingFlags::COUNT, 2, false), - ), - bitwise_lookup_chip, - records: Vec::new(), - offset, - offline_memory, - } - } -} - -impl InstructionExecutor for Sha256VmChip { - fn execute( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - let &Instruction { - opcode, - a, - b, - c, - d, - e, - .. - } = instruction; - let local_opcode = opcode.local_opcode_idx(self.offset); - debug_assert_eq!(local_opcode, Rv32Sha256Opcode::SHA256.local_usize()); - debug_assert_eq!(d, F::from_canonical_u32(RV32_REGISTER_AS)); - debug_assert_eq!(e, F::from_canonical_u32(RV32_MEMORY_AS)); - - debug_assert_eq!(from_state.timestamp, memory.timestamp()); - - let (dst_read, dst) = read_rv32_register(memory, d, a); - let (src_read, src) = read_rv32_register(memory, d, b); - let (len_read, len) = read_rv32_register(memory, d, c); - - #[cfg(debug_assertions)] - { - assert!(dst < (1 << self.air.ptr_max_bits)); - assert!(src < (1 << self.air.ptr_max_bits)); - assert!(len < (1 << self.air.ptr_max_bits)); - } - - // need to pad with one 1 bit, 64 bits for the message length and then pad until the length - // is divisible by [SHA256_BLOCK_BITS] - let num_blocks = ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS); - - // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used - debug_assert!( - src as usize + num_blocks * SHA256_BLOCK_CELLS <= (1 << self.air.ptr_max_bits) - ); - let mut hasher = Sha256::new(); - let mut input_records = Vec::with_capacity(num_blocks * SHA256_NUM_READ_ROWS); - let mut input_message = Vec::with_capacity(num_blocks * SHA256_NUM_READ_ROWS); - let mut read_ptr = src; - for _ in 0..num_blocks { - let block_reads_records = array::from_fn(|i| { - memory.read( - e, - F::from_canonical_u32(read_ptr + (i * SHA256_READ_SIZE) as u32), - ) - }); - let block_reads_bytes = array::from_fn(|i| { - // we add to the hasher only the bytes that are part of the message - let num_reads = min( - SHA256_READ_SIZE, - (max(read_ptr, src + len) - read_ptr) as usize, - ); - let row_input = block_reads_records[i] - .1 - .map(|x| x.as_canonical_u32().try_into().unwrap()); - hasher.update(&row_input[..num_reads]); - read_ptr += SHA256_READ_SIZE as u32; - row_input - }); - input_records.push(block_reads_records.map(|x| x.0)); - input_message.push(block_reads_bytes); - } - - let mut digest = [0u8; SHA256_WRITE_SIZE]; - digest.copy_from_slice(hasher.finalize().as_ref()); - let (digest_write, _) = memory.write( - e, - F::from_canonical_u32(dst), - digest.map(|b| F::from_canonical_u8(b)), - ); - - self.records.push(Sha256Record { - from_state: from_state.map(F::from_canonical_u32), - dst_read, - src_read, - len_read, - input_records, - input_message, - digest_write, - }); - - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) - } - - fn get_opcode_name(&self, _: usize) -> String { - "SHA256".to_string() - } -} - -pub fn sha256_solve(input_message: &[u8]) -> [u8; SHA256_WRITE_SIZE] { - let mut hasher = Sha256::new(); - hasher.update(input_message); - let mut output = [0u8; SHA256_WRITE_SIZE]; - output.copy_from_slice(hasher.finalize().as_ref()); - output -} diff --git a/extensions/sha256/circuit/src/sha256_chip/tests.rs b/extensions/sha256/circuit/src/sha256_chip/tests.rs deleted file mode 100644 index 55bc076e2c..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/tests.rs +++ /dev/null @@ -1,149 +0,0 @@ -use openvm_circuit::arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - SystemPort, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::{instruction::Instruction, riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_sha256_air::get_random_message; -use openvm_sha256_transpiler::Rv32Sha256Opcode::{self, *}; -use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::{rngs::StdRng, Rng}; - -use super::Sha256VmChip; -use crate::{sha256_solve, Sha256VmDigestCols, Sha256VmRoundCols}; - -type F = BabyBear; -const BUS_IDX: BusIndex = 28; -fn set_and_execute( - tester: &mut VmChipTestBuilder, - chip: &mut Sha256VmChip, - rng: &mut StdRng, - opcode: Rv32Sha256Opcode, - message: Option<&[u8]>, - len: Option, -) { - let len = len.unwrap_or(rng.gen_range(1..100000)); - let tmp = get_random_message(rng, len); - let message: &[u8] = message.unwrap_or(&tmp); - let len = message.len(); - - let rd = gen_pointer(rng, 4); - let rs1 = gen_pointer(rng, 4); - let rs2 = gen_pointer(rng, 4); - - let max_mem_ptr: u32 = 1 - << tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits; - let dst_ptr = rng.gen_range(0..max_mem_ptr); - let dst_ptr = dst_ptr ^ (dst_ptr & 3); - tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); - let src_ptr = rng.gen_range(0..(max_mem_ptr - len as u32)); - let src_ptr = src_ptr ^ (src_ptr & 3); - tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); - tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); - - for (i, &byte) in message.iter().enumerate() { - tester.write(2, src_ptr as usize + i, [F::from_canonical_u8(byte)]); - } - - tester.execute( - chip, - &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), - ); - - let output = sha256_solve(message); - assert_eq!( - output.map(F::from_canonical_u8), - tester.read::<32>(2, dst_ptr as usize) - ); -} - -/////////////////////////////////////////////////////////////////////////////////////// -/// POSITIVE TESTS -/// -/// Randomly generate computations and execute, ensuring that the generated trace -/// passes all constraints. -/////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_sha256_test() { - setup_tracing(); - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Sha256VmChip::new( - SystemPort { - execution_bus: tester.execution_bus(), - program_bus: tester.program_bus(), - memory_bridge: tester.memory_bridge(), - }, - tester.address_bits(), - bitwise_chip.clone(), - BUS_IDX, - Rv32Sha256Opcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let num_tests: usize = 3; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, SHA256, None, None); - } - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -/////////////////////////////////////////////////////////////////////////////////////// -/// SANITY TESTS -/// -/// Ensure that solve functions produce the correct results. -/////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Sha256VmChip::new( - SystemPort { - execution_bus: tester.execution_bus(), - program_bus: tester.program_bus(), - memory_bridge: tester.memory_bridge(), - }, - tester.address_bits(), - bitwise_chip.clone(), - BUS_IDX, - Rv32Sha256Opcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - println!( - "Sha256VmDigestCols::width(): {}", - Sha256VmDigestCols::::width() - ); - println!( - "Sha256VmRoundCols::width(): {}", - Sha256VmRoundCols::::width() - ); - let num_tests: usize = 1; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, SHA256, None, None); - } -} - -#[test] -fn sha256_solve_sanity_check() { - let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; - let output = sha256_solve(input); - let expected: [u8; 32] = [ - 99, 196, 61, 185, 226, 212, 131, 80, 154, 248, 97, 108, 157, 55, 200, 226, 160, 73, 207, - 46, 245, 169, 94, 255, 42, 136, 193, 15, 40, 133, 173, 22, - ]; - assert_eq!(output, expected); -} diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs deleted file mode 100644 index c02cd00dd8..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ /dev/null @@ -1,351 +0,0 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; - -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; -use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use openvm_rv32im_circuit::adapters::compose; -use openvm_sha256_air::{ - get_flag_pt_array, limbs_into_u32, Sha256Air, SHA256_BLOCK_WORDS, SHA256_BUFFER_SIZE, SHA256_H, - SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S, -}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_air::BaseAir, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::get_air_name, - AirRef, Chip, ChipUsageGetter, -}; - -use super::{ - Sha256VmChip, Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, - SHA256VM_DIGEST_WIDTH, SHA256VM_ROUND_WIDTH, -}; -use crate::{ - sha256_chip::{PaddingFlags, SHA256_READ_SIZE}, - SHA256_BLOCK_CELLS, -}; - -impl Chip for Sha256VmChip> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let non_padded_height = self.current_trace_height(); - let height = next_power_of_two_or_zero(non_padded_height); - let width = self.trace_width(); - let mut values = Val::::zero_vec(height * width); - if height == 0 { - return AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)); - } - let records = self.records; - let offline_memory = self.offline_memory.lock().unwrap(); - let memory_aux_cols_factory = offline_memory.aux_cols_factory(); - - let mem_ptr_shift: u32 = - 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.air.ptr_max_bits); - - let mut states = Vec::with_capacity(height.div_ceil(SHA256_ROWS_PER_BLOCK)); - let mut global_block_idx = 0; - for (record_idx, record) in records.iter().enumerate() { - let dst_read = offline_memory.record_by_id(record.dst_read); - let src_read = offline_memory.record_by_id(record.src_read); - let len_read = offline_memory.record_by_id(record.len_read); - - self.bitwise_lookup_chip.request_range( - dst_read - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - * mem_ptr_shift, - src_read - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - * mem_ptr_shift, - ); - let len = compose(len_read.data_slice().try_into().unwrap()); - let mut state = &None; - for (i, input_message) in record.input_message.iter().enumerate() { - let input_message = input_message - .iter() - .flatten() - .copied() - .collect::>() - .try_into() - .unwrap(); - states.push(Some(Self::generate_state( - state, - input_message, - record_idx, - len, - i == record.input_records.len() - 1, - ))); - state = &states[global_block_idx]; - global_block_idx += 1; - } - } - states.extend(std::iter::repeat_n( - None, - (height - non_padded_height).div_ceil(SHA256_ROWS_PER_BLOCK), - )); - - // During the first pass we will fill out most of the matrix - // But there are some cells that can't be generated by the first pass so we will do a second - // pass over the matrix - values - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .zip(states.into_par_iter().enumerate()) - .for_each(|(block, (global_block_idx, state))| { - // Fill in a valid block - if let Some(state) = state { - let mut has_padding_occurred = - state.local_block_idx * SHA256_BLOCK_CELLS > state.message_len as usize; - let message_left = if has_padding_occurred { - 0 - } else { - state.message_len as usize - state.local_block_idx * SHA256_BLOCK_CELLS - }; - let is_last_block = state.is_last_block; - let buffer: [[Val; SHA256_BUFFER_SIZE]; 4] = array::from_fn(|j| { - array::from_fn(|k| { - Val::::from_canonical_u8( - state.block_input_message[j * SHA256_BUFFER_SIZE + k], - ) - }) - }); - - let padded_message: [u32; SHA256_BLOCK_WORDS] = array::from_fn(|j| { - limbs_into_u32::(array::from_fn(|k| { - state.block_padded_message[(j + 1) * SHA256_WORD_U8S - k - 1] as u32 - })) - }); - - self.air.sha256_subair.generate_block_trace::>( - block, - width, - SHA256VM_CONTROL_WIDTH, - &padded_message, - self.bitwise_lookup_chip.clone(), - &state.hash, - is_last_block, - global_block_idx as u32 + 1, - state.local_block_idx as u32, - &buffer, - ); - - let block_reads = records[state.message_idx].input_records - [state.local_block_idx] - .map(|record_id| offline_memory.record_by_id(record_id)); - - let mut read_ptr = block_reads[0].pointer; - let mut cur_timestamp = Val::::from_canonical_u32(block_reads[0].timestamp); - - let read_size = Val::::from_canonical_usize(SHA256_READ_SIZE); - for row in 0..SHA256_ROWS_PER_BLOCK { - let row_slice = &mut block[row * width..(row + 1) * width]; - if row < 16 { - let cols: &mut Sha256VmRoundCols> = - row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); - cols.control.len = Val::::from_canonical_u32(state.message_len); - cols.control.read_ptr = read_ptr; - cols.control.cur_timestamp = cur_timestamp; - if row < 4 { - read_ptr += read_size; - cur_timestamp += Val::::ONE; - memory_aux_cols_factory - .generate_read_aux(block_reads[row], &mut cols.read_aux); - - if (row + 1) * SHA256_READ_SIZE <= message_left { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotPadding as usize, - ) - .map(Val::::from_canonical_u32); - } else if !has_padding_occurred { - has_padding_occurred = true; - let len = message_left - row * SHA256_READ_SIZE; - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - if row == 3 && is_last_block { - PaddingFlags::FirstPadding0_LastRow - } else { - PaddingFlags::FirstPadding0 - } as usize - + len, - ) - .map(Val::::from_canonical_u32); - } else { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - if row == 3 && is_last_block { - PaddingFlags::EntirePaddingLastRow - } else { - PaddingFlags::EntirePadding - } as usize, - ) - .map(Val::::from_canonical_u32); - } - } else { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotConsidered as usize, - ) - .map(Val::::from_canonical_u32); - } - cols.control.padding_occurred = - Val::::from_bool(has_padding_occurred); - } else { - if is_last_block { - has_padding_occurred = false; - } - let cols: &mut Sha256VmDigestCols> = - row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut(); - cols.control.len = Val::::from_canonical_u32(state.message_len); - cols.control.read_ptr = read_ptr; - cols.control.cur_timestamp = cur_timestamp; - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotConsidered as usize, - ) - .map(Val::::from_canonical_u32); - if is_last_block { - let record = &records[state.message_idx]; - let dst_read = offline_memory.record_by_id(record.dst_read); - let src_read = offline_memory.record_by_id(record.src_read); - let len_read = offline_memory.record_by_id(record.len_read); - let digest_write = offline_memory.record_by_id(record.digest_write); - cols.from_state = record.from_state; - cols.rd_ptr = dst_read.pointer; - cols.rs1_ptr = src_read.pointer; - cols.rs2_ptr = len_read.pointer; - cols.dst_ptr.copy_from_slice(dst_read.data_slice()); - cols.src_ptr.copy_from_slice(src_read.data_slice()); - cols.len_data.copy_from_slice(len_read.data_slice()); - memory_aux_cols_factory - .generate_read_aux(dst_read, &mut cols.register_reads_aux[0]); - memory_aux_cols_factory - .generate_read_aux(src_read, &mut cols.register_reads_aux[1]); - memory_aux_cols_factory - .generate_read_aux(len_read, &mut cols.register_reads_aux[2]); - memory_aux_cols_factory - .generate_write_aux(digest_write, &mut cols.writes_aux); - } - cols.control.padding_occurred = - Val::::from_bool(has_padding_occurred); - } - } - } - // Fill in the invalid rows - else { - block.par_chunks_mut(width).for_each(|row| { - let cols: &mut Sha256VmRoundCols> = row.borrow_mut(); - self.air.sha256_subair.generate_default_row(&mut cols.inner); - }) - } - }); - - // Do a second pass over the trace to fill in the missing values - // Note, we need to skip the very first row - values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - self.air - .sha256_subair - .generate_missing_cells(chunk, width, SHA256VM_CONTROL_WIDTH); - }); - - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) - } -} - -impl ChipUsageGetter for Sha256VmChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.records.iter().fold(0, |acc, record| { - acc + record.input_records.len() * SHA256_ROWS_PER_BLOCK - }) - } - - fn trace_width(&self) -> usize { - BaseAir::::width(&self.air) - } -} - -/// This is the state information that a block will use to generate its trace -#[derive(Debug, Clone)] -struct Sha256State { - hash: [u32; SHA256_HASH_WORDS], - local_block_idx: usize, - message_len: u32, - block_input_message: [u8; SHA256_BLOCK_CELLS], - block_padded_message: [u8; SHA256_BLOCK_CELLS], - message_idx: usize, - is_last_block: bool, -} - -impl Sha256VmChip { - fn generate_state( - prev_state: &Option, - block_input_message: [u8; SHA256_BLOCK_CELLS], - message_idx: usize, - message_len: u32, - is_last_block: bool, - ) -> Sha256State { - let local_block_idx = if let Some(prev_state) = prev_state { - prev_state.local_block_idx + 1 - } else { - 0 - }; - let has_padding_occurred = local_block_idx * SHA256_BLOCK_CELLS > message_len as usize; - let message_left = if has_padding_occurred { - 0 - } else { - message_len as usize - local_block_idx * SHA256_BLOCK_CELLS - }; - - let padded_message_bytes: [u8; SHA256_BLOCK_CELLS] = array::from_fn(|j| { - if j < message_left { - block_input_message[j] - } else if j == message_left && !has_padding_occurred { - 1 << (RV32_CELL_BITS - 1) - } else if !is_last_block || j < SHA256_BLOCK_CELLS - 4 { - 0u8 - } else { - let shift_amount = (SHA256_BLOCK_CELLS - j - 1) * RV32_CELL_BITS; - ((message_len * RV32_CELL_BITS as u32) - .checked_shr(shift_amount as u32) - .unwrap_or(0) - & ((1 << RV32_CELL_BITS) - 1)) as u8 - } - }); - - if let Some(prev_state) = prev_state { - Sha256State { - hash: Sha256Air::get_block_hash(&prev_state.hash, prev_state.block_padded_message), - local_block_idx, - message_len, - block_input_message, - block_padded_message: padded_message_bytes, - message_idx, - is_last_block, - } - } else { - Sha256State { - hash: SHA256_H, - local_block_idx: 0, - message_len, - block_input_message, - block_padded_message: padded_message_bytes, - message_idx, - is_last_block, - } - } - } -} diff --git a/extensions/sha256/guest/src/lib.rs b/extensions/sha256/guest/src/lib.rs deleted file mode 100644 index 1c51a272fd..0000000000 --- a/extensions/sha256/guest/src/lib.rs +++ /dev/null @@ -1,22 +0,0 @@ -#![no_std] - -/// This is custom-0 defined in RISC-V spec document -pub const OPCODE: u8 = 0x0b; -pub const SHA256_FUNCT3: u8 = 0b100; -pub const SHA256_FUNCT7: u8 = 0x1; - -/// zkvm native implementation of sha256 -/// # Safety -/// -/// The VM accepts the preimage by pointer and length, and writes the -/// 32-byte hash. -/// - `bytes` must point to an input buffer at least `len` long. -/// - `output` must point to a buffer that is at least 32-bytes long. -/// -/// [`sha2-256`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf -#[cfg(target_os = "zkvm")] -#[inline(always)] -#[no_mangle] -pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { - openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA256_FUNCT3, funct7 = SHA256_FUNCT7, rd = In output, rs1 = In bytes, rs2 = In len); -} diff --git a/extensions/sha256/transpiler/src/lib.rs b/extensions/sha256/transpiler/src/lib.rs deleted file mode 100644 index 6b13efe055..0000000000 --- a/extensions/sha256/transpiler/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -use openvm_instructions::{riscv::RV32_MEMORY_AS, LocalOpcode}; -use openvm_instructions_derive::LocalOpcode; -use openvm_sha256_guest::{OPCODE, SHA256_FUNCT3, SHA256_FUNCT7}; -use openvm_stark_backend::p3_field::PrimeField32; -use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput}; -use rrs_lib::instruction_formats::RType; -use strum::{EnumCount, EnumIter, FromRepr}; - -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x320] -#[repr(usize)] -pub enum Rv32Sha256Opcode { - SHA256, -} - -#[derive(Default)] -pub struct Sha256TranspilerExtension; - -impl TranspilerExtension for Sha256TranspilerExtension { - fn process_custom(&self, instruction_stream: &[u32]) -> Option> { - if instruction_stream.is_empty() { - return None; - } - let instruction_u32 = instruction_stream[0]; - let opcode = (instruction_u32 & 0x7f) as u8; - let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; - - if (opcode, funct3) != (OPCODE, SHA256_FUNCT3) { - return None; - } - let dec_insn = RType::new(instruction_u32); - - if dec_insn.funct7 != SHA256_FUNCT7 as u32 { - return None; - } - let instruction = from_r_type( - Rv32Sha256Opcode::SHA256.global_opcode().as_usize(), - RV32_MEMORY_AS as usize, - &dec_insn, - true, - ); - Some(TranspilerOutput::one_to_one(instruction)) - } -} diff --git a/guest-libs/ff_derive/tests/lib.rs b/guest-libs/ff_derive/tests/lib.rs index 6df9a1d675..180cadc27a 100644 --- a/guest-libs/ff_derive/tests/lib.rs +++ b/guest-libs/ff_derive/tests/lib.rs @@ -6,7 +6,7 @@ mod tests { use num_bigint::BigUint; use openvm_algebra_circuit::Rv32ModularConfig; use openvm_algebra_transpiler::ModularTranspilerExtension; - use openvm_circuit::utils::air_test; + use openvm_circuit::utils::{air_test, test_system_config_with_continuations}; use openvm_instructions::exe::VmExe; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, @@ -20,11 +20,18 @@ mod tests { type F = BabyBear; + #[cfg(test)] + fn test_rv32modular_config(moduli: Vec) -> Rv32ModularConfig { + let mut config = Rv32ModularConfig::new(moduli); + config.system = test_system_config_with_continuations(); + config + } + #[test] fn test_full_limbs() -> Result<()> { let moduli = ["39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "full_limbs", @@ -46,7 +53,7 @@ mod tests { #[test] fn test_fermat() -> Result<()> { let moduli = ["65537"].map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "fermat", &config)?; let openvm_exe = VmExe::from_elf( @@ -65,7 +72,7 @@ mod tests { #[test] fn test_sqrt() -> Result<()> { let moduli = ["357686312646216567629137"].map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "sqrt", &config)?; let openvm_exe = VmExe::from_elf( @@ -86,7 +93,7 @@ mod tests { let moduli = ["52435875175126190479447740508185965837690552500527637822603658699938581184513"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "constants", @@ -110,7 +117,7 @@ mod tests { let moduli = ["52435875175126190479447740508185965837690552500527637822603658699938581184513"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "from_u128", @@ -134,7 +141,7 @@ mod tests { let moduli = ["52435875175126190479447740508185965837690552500527637822603658699938581184513"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path_with_features( get_programs_dir!("tests/programs"), "batch_inversion", @@ -159,7 +166,7 @@ mod tests { let moduli = ["52435875175126190479447740508185965837690552500527637822603658699938581184513"] .map(|s| BigUint::from_str(s).unwrap()); - let config = Rv32ModularConfig::new(moduli.to_vec()); + let config = test_rv32modular_config(moduli.to_vec()); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "operations", diff --git a/guest-libs/k256/Cargo.toml b/guest-libs/k256/Cargo.toml index 362df43b6f..19f15b8be4 100644 --- a/guest-libs/k256/Cargo.toml +++ b/guest-libs/k256/Cargo.toml @@ -36,8 +36,8 @@ openvm-algebra-circuit.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true @@ -45,6 +45,7 @@ openvm-toolchain-tests.workspace = true openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true +rand = { workspace = true } serde.workspace = true eyre.workspace = true derive_more = { workspace = true, features = ["from"] } @@ -84,4 +85,5 @@ ignored = [ "derive_more", "signature", "once_cell", + "rand", ] diff --git a/guest-libs/k256/tests/lib.rs b/guest-libs/k256/tests/lib.rs index e38675aa09..16605bd144 100644 --- a/guest-libs/k256/tests/lib.rs +++ b/guest-libs/k256/tests/lib.rs @@ -2,22 +2,32 @@ mod guest_tests { use ecdsa_config::EcdsaConfig; use eyre::Result; use openvm_algebra_transpiler::ModularTranspilerExtension; - use openvm_circuit::{arch::instructions::exe::VmExe, utils::air_test}; - use openvm_ecc_circuit::{Rv32WeierstrassConfig, SECP256K1_CONFIG}; + use openvm_circuit::{ + arch::instructions::exe::VmExe, + utils::{air_test, test_system_config_with_continuations}, + }; + use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, SECP256K1_CONFIG}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; type F = BabyBear; + #[cfg(test)] + fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { + let mut config = Rv32WeierstrassConfig::new(curves); + config.system = test_system_config_with_continuations(); + config + } + #[test] fn test_add() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "add", &config)?; let openvm_exe = VmExe::from_elf( @@ -35,7 +45,7 @@ mod guest_tests { #[test] fn test_mul() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "mul", &config)?; let openvm_exe = VmExe::from_elf( @@ -53,7 +63,7 @@ mod guest_tests { #[test] fn test_linear_combination() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "linear_combination", @@ -80,6 +90,7 @@ mod guest_tests { use openvm_circuit::{ arch::{InitFileGenerator, SystemConfig}, derive::VmConfig, + utils::test_system_config_with_continuations, }; use openvm_ecc_circuit::{ CurveConfig, WeierstrassExtension, WeierstrassExtensionExecutor, @@ -89,7 +100,7 @@ mod guest_tests { Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; + use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; @@ -108,7 +119,7 @@ mod guest_tests { #[extension] pub weierstrass: WeierstrassExtension, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, } impl EcdsaConfig { @@ -118,13 +129,13 @@ mod guest_tests { .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) .collect(); Self { - system: SystemConfig::default().with_continuations(), + system: test_system_config_with_continuations(), base: Default::default(), mul: Default::default(), io: Default::default(), modular: ModularExtension::new(primes), weierstrass: WeierstrassExtension::new(curves), - sha256: Default::default(), + sha2: Default::default(), } } } @@ -154,7 +165,7 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(config, openvm_exe); Ok(()) @@ -162,7 +173,7 @@ mod guest_tests { #[test] fn test_scalar_sqrt() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![SECP256K1_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![SECP256K1_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "scalar_sqrt", diff --git a/guest-libs/p256/Cargo.toml b/guest-libs/p256/Cargo.toml index e54a7d22d6..dad42296e3 100644 --- a/guest-libs/p256/Cargo.toml +++ b/guest-libs/p256/Cargo.toml @@ -33,8 +33,8 @@ openvm-algebra-circuit.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true @@ -42,6 +42,7 @@ openvm-toolchain-tests.workspace = true openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true +rand = { workspace = true } serde.workspace = true eyre.workspace = true derive_more = { workspace = true, features = ["from"] } @@ -70,4 +71,4 @@ voprf = ["elliptic-curve/voprf"] num-bigint = { workspace = true } [package.metadata.cargo-shear] -ignored = ["openvm", "serde", "num-bigint", "derive_more"] +ignored = ["openvm", "serde", "num-bigint", "derive_more", "rand"] diff --git a/guest-libs/p256/tests/lib.rs b/guest-libs/p256/tests/lib.rs index f11cb63325..a0bbeb046f 100644 --- a/guest-libs/p256/tests/lib.rs +++ b/guest-libs/p256/tests/lib.rs @@ -2,22 +2,32 @@ mod guest_tests { use ecdsa_config::EcdsaConfig; use eyre::Result; use openvm_algebra_transpiler::ModularTranspilerExtension; - use openvm_circuit::{arch::instructions::exe::VmExe, utils::air_test}; - use openvm_ecc_circuit::{Rv32WeierstrassConfig, P256_CONFIG}; + use openvm_circuit::{ + arch::instructions::exe::VmExe, + utils::{air_test, test_system_config_with_continuations}, + }; + use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, P256_CONFIG}; use openvm_ecc_transpiler::EccTranspilerExtension; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; type F = BabyBear; + #[cfg(test)] + fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { + let mut config = Rv32WeierstrassConfig::new(curves); + config.system = test_system_config_with_continuations(); + config + } + #[test] fn test_add() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "add", &config)?; let openvm_exe = VmExe::from_elf( @@ -35,7 +45,7 @@ mod guest_tests { #[test] fn test_mul() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path(get_programs_dir!("tests/programs"), "mul", &config)?; let openvm_exe = VmExe::from_elf( @@ -53,7 +63,7 @@ mod guest_tests { #[test] fn test_linear_combination() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "linear_combination", @@ -80,6 +90,7 @@ mod guest_tests { use openvm_circuit::{ arch::{InitFileGenerator, SystemConfig}, derive::VmConfig, + utils::test_system_config_with_continuations, }; use openvm_ecc_circuit::{ CurveConfig, WeierstrassExtension, WeierstrassExtensionExecutor, @@ -89,7 +100,7 @@ mod guest_tests { Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; + use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2Periphery}; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; @@ -108,7 +119,7 @@ mod guest_tests { #[extension] pub weierstrass: WeierstrassExtension, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, } impl EcdsaConfig { @@ -118,13 +129,13 @@ mod guest_tests { .flat_map(|c| [c.modulus.clone(), c.scalar.clone()]) .collect(); Self { - system: SystemConfig::default().with_continuations(), + system: test_system_config_with_continuations(), base: Default::default(), mul: Default::default(), io: Default::default(), modular: ModularExtension::new(primes), weierstrass: WeierstrassExtension::new(curves), - sha256: Default::default(), + sha2: Default::default(), } } } @@ -154,7 +165,7 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(config, openvm_exe); Ok(()) @@ -162,7 +173,7 @@ mod guest_tests { #[test] fn test_scalar_sqrt() -> Result<()> { - let config = Rv32WeierstrassConfig::new(vec![P256_CONFIG.clone()]); + let config = test_rv32weierstrass_config(vec![P256_CONFIG.clone()]); let elf = build_example_program_at_path( get_programs_dir!("tests/programs"), "scalar_sqrt", diff --git a/guest-libs/pairing/tests/lib.rs b/guest-libs/pairing/tests/lib.rs index 6e55834b77..1d738e8701 100644 --- a/guest-libs/pairing/tests/lib.rs +++ b/guest-libs/pairing/tests/lib.rs @@ -11,11 +11,10 @@ mod bn254 { }; use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; - use openvm_circuit::{ - arch::SystemConfig, - utils::{air_test, air_test_impl, air_test_with_min_segments}, + use openvm_circuit::utils::{ + air_test, air_test_impl, air_test_with_min_segments, test_system_config_with_continuations, }; - use openvm_ecc_circuit::{Rv32WeierstrassConfig, WeierstrassExtension}; + use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, WeierstrassExtension}; use openvm_ecc_guest::{ algebra::{field::FieldExtension, IntMod}, AffinePoint, @@ -48,7 +47,7 @@ mod bn254 { .zip(primes.clone()) .collect::>(); Rv32PairingConfig { - system: SystemConfig::default().with_continuations(), + system: test_system_config_with_continuations(), base: Default::default(), mul: Default::default(), io: Default::default(), @@ -59,10 +58,17 @@ mod bn254 { } } + #[cfg(test)] + fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { + let mut config = Rv32WeierstrassConfig::new(curves); + config.system = test_system_config_with_continuations(); + config + } + #[test] fn test_bn_ec() -> Result<()> { let curve = PairingCurve::Bn254.curve_config(); - let config = Rv32WeierstrassConfig::new(vec![curve]); + let config = test_rv32weierstrass_config(vec![curve]); let elf = build_example_program_at_path_with_features( get_programs_dir!("tests/programs"), "bn_ec", @@ -459,8 +465,11 @@ mod bls12_381 { use openvm_algebra_circuit::{Fp2Extension, ModularExtension}; use openvm_algebra_transpiler::{Fp2TranspilerExtension, ModularTranspilerExtension}; use openvm_circuit::{ - arch::{instructions::exe::VmExe, SystemConfig}, - utils::{air_test, air_test_impl, air_test_with_min_segments}, + arch::instructions::exe::VmExe, + utils::{ + air_test, air_test_impl, air_test_with_min_segments, + test_system_config_with_continuations, + }, }; use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, WeierstrassExtension}; use openvm_ecc_guest::{ @@ -497,7 +506,7 @@ mod bls12_381 { .zip(primes.clone()) .collect::>(); Rv32PairingConfig { - system: SystemConfig::default().with_continuations(), + system: test_system_config_with_continuations(), base: Default::default(), mul: Default::default(), io: Default::default(), @@ -508,6 +517,13 @@ mod bls12_381 { } } + #[cfg(test)] + fn test_rv32weierstrass_config(curves: Vec) -> Rv32WeierstrassConfig { + let mut config = Rv32WeierstrassConfig::new(curves); + config.system = test_system_config_with_continuations(); + config + } + #[test] fn test_bls_ec() -> Result<()> { let curve = CurveConfig { @@ -517,7 +533,7 @@ mod bls12_381 { a: BigUint::ZERO, b: BigUint::from_u8(4).unwrap(), }; - let config = Rv32WeierstrassConfig::new(vec![curve]); + let config = test_rv32weierstrass_config(vec![curve]); let elf = build_example_program_at_path_with_features( get_programs_dir!("tests/programs"), "bls_ec", diff --git a/guest-libs/ruint/tests/programs/examples/matrix_power.rs b/guest-libs/ruint/tests/programs/examples/matrix_power.rs index 95826d32de..6a874bc35e 100644 --- a/guest-libs/ruint/tests/programs/examples/matrix_power.rs +++ b/guest-libs/ruint/tests/programs/examples/matrix_power.rs @@ -123,6 +123,11 @@ pub fn main() { panic!(); } + if U256::from_limbs([u64::MAX; 4]) + one != zero { + print("FAIL: U256::MAX == 0 test failed"); + panic!(); + } + if two_to_200 != two_to_200 { print("FAIL: 2^200 clone test failed"); panic!(); diff --git a/guest-libs/sha2/Cargo.toml b/guest-libs/sha2/Cargo.toml index f8bf7b545e..9e13e85ce8 100644 --- a/guest-libs/sha2/Cargo.toml +++ b/guest-libs/sha2/Cargo.toml @@ -10,15 +10,15 @@ repository.workspace = true license.workspace = true [dependencies] -openvm-sha256-guest = { workspace = true } +openvm-sha2-guest = { workspace = true } [dev-dependencies] openvm-instructions = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils", "parallel"] } openvm-transpiler = { workspace = true } -openvm-sha256-transpiler = { workspace = true } -openvm-sha256-circuit = { workspace = true } +openvm-sha2-transpiler = { workspace = true } +openvm-sha2-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } diff --git a/guest-libs/sha2/src/lib.rs b/guest-libs/sha2/src/lib.rs index 43d90ba822..dfeddf70a1 100644 --- a/guest-libs/sha2/src/lib.rs +++ b/guest-libs/sha2/src/lib.rs @@ -8,6 +8,22 @@ pub fn sha256(input: &[u8]) -> [u8; 32] { output } +/// The sha512 cryptographic hash function. +#[inline(always)] +pub fn sha512(input: &[u8]) -> [u8; 64] { + let mut output = [0u8; 64]; + set_sha512(input, &mut output); + output +} + +/// The sha384 cryptographic hash function. +#[inline(always)] +pub fn sha384(input: &[u8]) -> [u8; 48] { + let mut output = [0u8; 48]; + set_sha384(input, &mut output); + output +} + /// Sets `output` to the sha256 hash of `input`. pub fn set_sha256(input: &[u8], output: &mut [u8; 32]) { #[cfg(not(target_os = "zkvm"))] @@ -19,10 +35,51 @@ pub fn set_sha256(input: &[u8], output: &mut [u8; 32]) { } #[cfg(target_os = "zkvm")] { - openvm_sha256_guest::zkvm_sha256_impl( + openvm_sha2_guest::zkvm_sha256_impl( input.as_ptr(), input.len(), output.as_mut_ptr() as *mut u8, ); } } + +/// Sets `output` to the sha512 hash of `input`. +pub fn set_sha512(input: &[u8], output: &mut [u8; 64]) { + #[cfg(not(target_os = "zkvm"))] + { + use sha2::{Digest, Sha512}; + let mut hasher = Sha512::new(); + hasher.update(input); + output.copy_from_slice(hasher.finalize().as_ref()); + } + #[cfg(target_os = "zkvm")] + { + openvm_sha2_guest::zkvm_sha512_impl( + input.as_ptr(), + input.len(), + output.as_mut_ptr() as *mut u8, + ); + } +} + +/// Sets the first 48 bytes of `output` to the sha384 hash of `input`. +/// Sets the last 16 bytes to zeros. +pub fn set_sha384(input: &[u8], output: &mut [u8; 48]) { + #[cfg(not(target_os = "zkvm"))] + { + use sha2::{Digest, Sha384}; + let mut hasher = Sha384::new(); + hasher.update(input); + output.copy_from_slice(hasher.finalize().as_ref()); + } + #[cfg(target_os = "zkvm")] + { + let mut output_64: [u8; 64] = [0; 64]; + openvm_sha2_guest::zkvm_sha384_impl( + input.as_ptr(), + input.len(), + output_64.as_mut_ptr() as *mut u8, + ); + output.copy_from_slice(&output_64[..48]); + } +} diff --git a/guest-libs/sha2/tests/lib.rs b/guest-libs/sha2/tests/lib.rs index 9ebab5ac02..3a8c92194c 100644 --- a/guest-libs/sha2/tests/lib.rs +++ b/guest-libs/sha2/tests/lib.rs @@ -6,8 +6,8 @@ mod tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_circuit::Sha256Rv32Config; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_circuit::Sha2Rv32Config; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -15,17 +15,17 @@ mod tests { type F = BabyBear; #[test] - fn test_sha256() -> Result<()> { - let config = Sha256Rv32Config::default(); + fn test_sha2() -> Result<()> { + let config = Sha2Rv32Config::default(); let elf = - build_example_program_at_path(get_programs_dir!("tests/programs"), "sha", &config)?; + build_example_program_at_path(get_programs_dir!("tests/programs"), "sha2", &config)?; let openvm_exe = VmExe::from_elf( elf, Transpiler::::default() .with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(config, openvm_exe); Ok(()) diff --git a/guest-libs/sha2/tests/programs/Cargo.toml b/guest-libs/sha2/tests/programs/Cargo.toml index df13f8dfc7..c197564ec0 100644 --- a/guest-libs/sha2/tests/programs/Cargo.toml +++ b/guest-libs/sha2/tests/programs/Cargo.toml @@ -8,12 +8,12 @@ edition = "2021" openvm = { path = "../../../../crates/toolchain/openvm" } openvm-platform = { path = "../../../../crates/toolchain/platform" } openvm-sha2 = { path = "../../" } - hex = { version = "0.4.3", default-features = false, features = ["alloc"] } serde = { version = "1.0", default-features = false, features = [ "alloc", "derive", ] } +hex-literal = { version = "1.0.0" } [features] default = [] diff --git a/guest-libs/sha2/tests/programs/examples/sha.rs b/guest-libs/sha2/tests/programs/examples/sha.rs deleted file mode 100644 index ebfd50cbee..0000000000 --- a/guest-libs/sha2/tests/programs/examples/sha.rs +++ /dev/null @@ -1,29 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_main)] -#![cfg_attr(not(feature = "std"), no_std)] - -extern crate alloc; - -use alloc::vec::Vec; -use core::hint::black_box; - -use hex::FromHex; -use openvm_sha2::sha256; - -openvm::entry!(main); - -pub fn main() { - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ("98c1c0bdb7d5fea9a88859f06c6c439f", "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05"), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - ("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - ]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = sha256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} diff --git a/guest-libs/sha2/tests/programs/examples/sha2.rs b/guest-libs/sha2/tests/programs/examples/sha2.rs new file mode 100644 index 0000000000..7f28152b42 --- /dev/null +++ b/guest-libs/sha2/tests/programs/examples/sha2.rs @@ -0,0 +1,85 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +extern crate alloc; + +use alloc::vec::Vec; +use core::hint::black_box; + +use hex::FromHex; +use openvm_sha2::{sha256, sha384, sha512}; + +openvm::entry!(main); + +struct ShaTestVector { + input: &'static str, + expected_output_sha256: &'static str, + expected_output_sha512: &'static str, + expected_output_sha384: &'static str, +} + +pub fn main() { + let test_vectors = [ + ShaTestVector { + input: "", + expected_output_sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expected_output_sha512: "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", + expected_output_sha384: "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", + }, + ShaTestVector { + input: "98c1c0bdb7d5fea9a88859f06c6c439f", + expected_output_sha256: "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", + expected_output_sha512: "eb576959c531f116842c0cc915a29c8f71d7a285c894c349b83469002ef093d51f9f14ce4248488bff143025e47ed27c12badb9cd43779cb147408eea062d583", + expected_output_sha384: "63e3061aab01f335ea3a4e617b9d14af9b63a5240229164ee962f6d5335ff25f0f0bf8e46723e83c41b9d17413b6a3c7", + }, + ShaTestVector { + input: "5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", + expected_output_sha256: "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7", + expected_output_sha512: "a20d5fb14814d045a7d2861e80d2b688f1cd1daaba69e6bb1cc5233f514141ea4623b3373af702e78e3ec5dc8c1b716a37a9a2f5fbc9493b9df7043f5e99a8da", + expected_output_sha384: "eac4b72b0540486bc088834860873338e31e9e4062532bf509191ef63b9298c67db5654a28fe6f07e4cc6ff466d1be24", + }, + ShaTestVector { + input: "9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", + expected_output_sha256: "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23", + expected_output_sha512: "8d215ee6dc26757c210db0dd00c1c6ed16cc34dbd4bb0fa10c1edb6b62d5ab16aea88c881001b173d270676daf2d6381b5eab8711fa2f5589c477c1d4b84774f", + expected_output_sha384: "904a90010d772a904a35572fdd4bdf1dd253742e47872c8a18e2255f66fa889e44781e65487a043f435daa53c496a53e", + } + ]; + + for ( + i, + ShaTestVector { + input, + expected_output_sha256, + expected_output_sha512, + expected_output_sha384, + }, + ) in test_vectors.iter().enumerate() + { + let input = Vec::from_hex(input).unwrap(); + let expected_output_sha256 = Vec::from_hex(expected_output_sha256).unwrap(); + let output = sha256(black_box(&input)); + if output != *expected_output_sha256 { + panic!( + "sha256 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha256, output + ); + } + let expected_output_sha512 = Vec::from_hex(expected_output_sha512).unwrap(); + let output = sha512(black_box(&input)); + if output != *expected_output_sha512 { + panic!( + "sha512 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha512, output + ); + } + let expected_output_sha384 = Vec::from_hex(expected_output_sha384).unwrap(); + let output = sha384(black_box(&input)); + if output != *expected_output_sha384 { + panic!( + "sha384 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha384, output + ); + } + } +} diff --git a/guest-libs/verify_stark/Cargo.toml b/guest-libs/verify_stark/Cargo.toml index 070083edad..66f13731d2 100644 --- a/guest-libs/verify_stark/Cargo.toml +++ b/guest-libs/verify_stark/Cargo.toml @@ -21,4 +21,4 @@ openvm-circuit = { workspace = true, features = ["parallel"] } openvm-stark-sdk = { workspace = true } openvm-native-compiler.workspace = true openvm-verify-stark.workspace = true -eyre.workspace = true \ No newline at end of file +eyre.workspace = true diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 35e9b966ed..3d10b2a450 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] channel = "1.85.1" -components = ["clippy", "rustfmt"] +components = ["clippy", "rustfmt"] \ No newline at end of file