diff --git a/src/bindings/crypto/bindings.ts b/src/bindings/crypto/bindings.ts
index 23b133cdc7..5f45a4039c 100644
--- a/src/bindings/crypto/bindings.ts
+++ b/src/bindings/crypto/bindings.ts
@@ -3,6 +3,7 @@
  * It is exposed to JSOO by populating a global variable with an object.
  * It gets imported as the first thing in ../../bindings.js so that the global variable is ready by the time JSOO code gets executed.
  */
+import type * as napiNamespace from '../compiled/node_bindings/plonk_wasm.cjs';
 import type * as wasmNamespace from '../compiled/node_bindings/plonk_wasm.cjs';
 import { prefixHashes, prefixHashesLegacy } from '../crypto/constants.js';
 import { Bigint256Bindings } from './bindings/bigint256.js';
@@ -14,12 +15,15 @@ import { verifierIndexConversion } from './bindings/conversion-verifier-index.js
 import { PallasBindings, VestaBindings } from './bindings/curve.js';
 import { jsEnvironment } from './bindings/env.js';
 import { FpBindings, FqBindings } from './bindings/field.js';
-import { srs } from './bindings/srs.js';
 import { FpVectorBindings, FqVectorBindings } from './bindings/vector.js';
+import { srs } from './bindings/srs.js';
+import { srs as napiSrs } from './napi-srs.js';
 import { napiConversionCore } from './napi-conversion-core.js';
 import { napiProofConversion } from './napi-conversion-proof.js';
+import { napiVerifierIndexConversion } from './napi-conversion-verifier-index.js';
+import { napiOraclesConversion } from './bindings/napi-conversion-oracles.js';
 
-export { RustConversion, Wasm, createNativeRustConversion, getRustConversion };
+export { Napi, RustConversion, Wasm, createNativeRustConversion, getRustConversion };
 
 /* TODO: Uncomment in phase 2 of conversion layer
 import { conversionCore as conversionCoreNative } from './native/conversion-core.js';
@@ -49,13 +53,14 @@ const tsBindings = {
     return bundle.srsFactory(wasm, bundle.conversion);
   },*/
   srs: (wasm: Wasm) => srs(wasm, getRustConversion(wasm)),
-  srsNative: (napi: Wasm) => srs(napi, createNativeRustConversion(napi) as any),
+  srsNative: (napi: Napi) => napiSrs(napi, createNativeRustConversion(napi) as any),
 };
 
 // this is put in a global variable so that mina/src/lib/crypto/kimchi_bindings/js/bindings.js finds it
 (globalThis as any).__snarkyTsBindings = tsBindings;
 
 type Wasm = typeof wasmNamespace;
+type Napi = typeof napiNamespace;
 
 type RustConversion = ReturnType;
 
@@ -86,9 +91,11 @@ function buildWasmConversion(wasm: Wasm) {
 function createNativeRustConversion(napi: any) {
   let core = napiConversionCore(napi);
   let proof = napiProofConversion(napi, core);
+  let verif = napiVerifierIndexConversion(napi, core);
+  let oracles = napiOraclesConversion(napi);
   return {
-    fp: { ...core.fp, ...proof.fp },
-    fq: { ...core.fq, ...proof.fq },
+    fp: { ...core.fp, ...proof.fp, ...verif.fp, ...oracles.fp },
+    fq: { ...core.fq, ...proof.fq, ...verif.fq, ...oracles.fq },
   };
 }
 
diff --git a/src/bindings/crypto/bindings/conversion-base.ts b/src/bindings/crypto/bindings/conversion-base.ts
index 9f67829df3..5b8cdd601d 100644
--- a/src/bindings/crypto/bindings/conversion-base.ts
+++ b/src/bindings/crypto/bindings/conversion-base.ts
@@ -1,24 +1,24 @@
-import { Field } from './field.js';
-import { bigintToBytes32, bytesToBigint32 } from '../bigint-helpers.js';
+import type { MlArray } from '../../../lib/ml/base.js';
 import type {
   WasmGPallas,
   WasmGVesta,
   WasmPallasGProjective,
   WasmVestaGProjective,
 } from '../../compiled/node_bindings/plonk_wasm.cjs';
-import type { MlArray } from '../../../lib/ml/base.js';
-import { OrInfinity, Infinity } from './curve.js';
+import { bigintToBytes32, bytesToBigint32 } from '../bigint-helpers.js';
+import { Infinity, OrInfinity } from './curve.js';
+import { Field } from './field.js';
 
 export {
-  fieldToRust,
+  WasmAffine,
+  WasmProjective,
+  affineFromRust,
+  affineToRust,
   fieldFromRust,
-  fieldsToRustFlat,
+  fieldToRust,
   fieldsFromRustFlat,
+  fieldsToRustFlat,
   maybeFieldToRust,
-  affineToRust,
-  affineFromRust,
-  WasmAffine,
-  WasmProjective,
 };
 
 // TODO: Hardcoding this is a little brittle
diff --git a/src/bindings/crypto/bindings/conversion-core.ts b/src/bindings/crypto/bindings/conversion-core.ts
index bae7c79d62..ea5002e6df 100644
--- a/src/bindings/crypto/bindings/conversion-core.ts
+++ b/src/bindings/crypto/bindings/conversion-core.ts
@@ -116,9 +116,9 @@ function conversionCorePerField(
       return new PolyComm(rustUnshifted, rustShifted);
     },
     polyCommFromRust(polyComm: WasmPolyComm): PolyComm {
-      console.log('polyComm', polyComm);
+      console.log('polyComm old', polyComm);
       let rustUnshifted = polyComm.unshifted;
-      console.log('polyCommFromRust', rustUnshifted);
+      console.log('rustUnshifted', rustUnshifted);
       let mlUnshifted = mapFromUintArray(rustUnshifted, (ptr) => {
         return affineFromRust(wrap(ptr, CommitmentCurve));
       });
diff --git a/src/bindings/crypto/bindings/conversion-proof.ts b/src/bindings/crypto/bindings/conversion-proof.ts
index ee106fb940..5fdc45e6ce 100644
--- a/src/bindings/crypto/bindings/conversion-proof.ts
+++ b/src/bindings/crypto/bindings/conversion-proof.ts
@@ -1,46 +1,46 @@
+import { MlArray, MlOption, MlTuple } from '../../../lib/ml/base.js';
+import type * as wasmNamespace from '../../compiled/node_bindings/plonk_wasm.cjs';
 import type {
   WasmFpLookupCommitments,
-  WasmPastaFpLookupTable,
   WasmFpOpeningProof,
   WasmFpProverCommitments,
   WasmFpProverProof,
   WasmFpRuntimeTable,
-  WasmPastaFpRuntimeTableCfg,
   WasmFqLookupCommitments,
   WasmFqOpeningProof,
   WasmFqProverCommitments,
-  WasmPastaFqLookupTable,
   WasmFqProverProof,
   WasmFqRuntimeTable,
+  WasmPastaFpLookupTable,
+  WasmPastaFpRuntimeTableCfg,
+  WasmPastaFqLookupTable,
   WasmPastaFqRuntimeTableCfg,
   WasmVecVecFp,
   WasmVecVecFq,
 } from '../../compiled/node_bindings/plonk_wasm.cjs';
-import type * as wasmNamespace from '../../compiled/node_bindings/plonk_wasm.cjs';
+import {
+  fieldFromRust,
+  fieldToRust,
+  fieldsFromRustFlat,
+  fieldsToRustFlat,
+} from './conversion-base.js';
+import { ConversionCore, ConversionCores, mapToUint32Array, unwrap } from './conversion-core.js';
 import type {
+  Field,
+  LookupCommitments,
+  LookupTable,
+  OpeningProof,
   OrInfinity,
   PointEvaluations,
   PolyComm,
-  ProverProof,
-  ProofWithPublic,
   ProofEvaluations,
+  ProofWithPublic,
   ProverCommitments,
-  OpeningProof,
+  ProverProof,
   RecursionChallenge,
-  LookupCommitments,
   RuntimeTable,
   RuntimeTableCfg,
-  LookupTable,
-  Field,
 } from './kimchi-types.js';
-import { MlArray, MlOption, MlTuple } from '../../../lib/ml/base.js';
-import {
-  fieldToRust,
-  fieldFromRust,
-  fieldsToRustFlat,
-  fieldsFromRustFlat,
-} from './conversion-base.js';
-import { ConversionCore, ConversionCores, mapToUint32Array, unwrap } from './conversion-core.js';
 
 export { proofConversion };
 
@@ -178,6 +178,7 @@ function proofConversionPerField(
   }
 
   function runtimeTableToRust([, id, data]: RuntimeTable): WasmRuntimeTable {
+    console.log('old runtime table to rust!');
     return new RuntimeTable(id, core.vectorToRust(data));
   }
 
diff --git a/src/bindings/crypto/bindings/conversion-verifier-index.ts b/src/bindings/crypto/bindings/conversion-verifier-index.ts
index 8aae4771f4..a5353a1316 100644
--- a/src/bindings/crypto/bindings/conversion-verifier-index.ts
+++ b/src/bindings/crypto/bindings/conversion-verifier-index.ts
@@ -117,8 +117,6 @@ function verifierIndexConversionPerField(
     );
   }
   function verificationEvalsFromRust(evals: WasmVerificationEvals): VerificationEvals {
-    console.log('evals', evals.coefficients_comm);
-
     let mlEvals: VerificationEvals = [
       0,
       core.polyCommsFromRust(evals.sigma_comm),
diff --git a/src/bindings/crypto/bindings/gate-vector-napi.unit-test.ts b/src/bindings/crypto/bindings/gate-vector-napi.unit-test.ts
index 8f6939db1f..9b85604c4e 100644
--- a/src/bindings/crypto/bindings/gate-vector-napi.unit-test.ts
+++ b/src/bindings/crypto/bindings/gate-vector-napi.unit-test.ts
@@ -6,9 +6,11 @@ import type { Field, Gate, Wire } from './kimchi-types.js';
 const require = createRequire(import.meta.url);
 
 function loadNative() {
+  const slug = `${process.platform}-${process.arch}`;
   const candidates = [
-    '../../compiled/_node_bindings/plonk_napi.node',
+    `../../../../../native/${slug}/plonk_napi.node`,
     '../../compiled/node_bindings/plonk_napi.node',
+    '../../compiled/_node_bindings/plonk_napi.node',
   ];
   for (const path of candidates) {
     try {
@@ -23,6 +25,21 @@ function loadNative() {
 const native: any = loadNative();
 
+const gateVectorCreate =
+  native.caml_pasta_fp_plonk_gate_vector_create ?? native.camlPastaFpPlonkGateVectorCreate;
+const gateVectorLen =
+  native.caml_pasta_fp_plonk_gate_vector_len ?? native.camlPastaFpPlonkGateVectorLen;
+const gateVectorAdd =
+  native.caml_pasta_fp_plonk_gate_vector_add ?? native.camlPastaFpPlonkGateVectorAdd;
+const gateVectorGet =
+  native.caml_pasta_fp_plonk_gate_vector_get ?? native.camlPastaFpPlonkGateVectorGet;
+const gateVectorWrap =
+  native.caml_pasta_fp_plonk_gate_vector_wrap ?? native.camlPastaFpPlonkGateVectorWrap;
+const gateVectorDigest =
+  native.caml_pasta_fp_plonk_gate_vector_digest ?? native.camlPastaFpPlonkGateVectorDigest;
+const circuitSerialize =
+  native.caml_pasta_fp_plonk_circuit_serialize ??
native.camlPastaFpPlonkCircuitSerialize; + const { fp } = napiConversionCore(native); const zeroField: Field = [0, 0n]; @@ -44,24 +61,24 @@ const sampleGate: Gate = [ [0, zeroField, zeroField, zeroField, zeroField, zeroField, zeroField, zeroField], ]; -const vector = native.camlPastaFpPlonkGateVectorCreate(); -expect(native.camlPastaFpPlonkGateVectorLen(vector)).toBe(0); +const vector = gateVectorCreate(); +expect(gateVectorLen(vector)).toBe(0); -native.camlPastaFpPlonkGateVectorAdd(vector, fp.gateToRust(sampleGate)); -expect(native.camlPastaFpPlonkGateVectorLen(vector)).toBe(1); +gateVectorAdd(vector, fp.gateToRust(sampleGate)); +expect(gateVectorLen(vector)).toBe(1); -const gate0 = native.camlPastaFpPlonkGateVectorGet(vector, 0); +const gate0 = gateVectorGet(vector, 0); expect(gate0.typ).toBe(sampleGate[1]); const rustTarget = fp.wireToRust(mlWire(0, 0)); const rustHead = fp.wireToRust(mlWire(1, 2)); -native.camlPastaFpPlonkGateVectorWrap(vector, rustTarget, rustHead); -const wrapped = native.camlPastaFpPlonkGateVectorGet(vector, 0); +gateVectorWrap(vector, rustTarget, rustHead); +const wrapped = gateVectorGet(vector, 0); expect(wrapped.wires.w0).toEqual({ row: 1, col: 2 }); -native.camlPastaFpPlonkGateVectorDigest(0, vector); -native.camlPastaFpPlonkCircuitSerialize(0, vector); +gateVectorDigest(0, vector); +circuitSerialize(0, vector); -console.log('{}', native.camlPastaFpPlonkGateVectorDigest(0, vector)); +console.log('{}', gateVectorDigest(0, vector)); console.log('gate vector napi bindings (fp) are working ✔️'); diff --git a/src/bindings/crypto/bindings/napi-conversion-oracles.ts b/src/bindings/crypto/bindings/napi-conversion-oracles.ts new file mode 100644 index 0000000000..13bcb42f96 --- /dev/null +++ b/src/bindings/crypto/bindings/napi-conversion-oracles.ts @@ -0,0 +1,123 @@ +import { MlOption } from '../../../lib/ml/base.js'; +import type * as napiNamespace from '../../compiled/node_bindings/plonk_wasm.cjs'; +import type { + WasmFpOracles, + WasmFpRandomOracles, + WasmFqOracles, + WasmFqRandomOracles, +} from '../../compiled/node_bindings/plonk_wasm.cjs'; +import { + fieldFromRust, + fieldToRust, + fieldsFromRustFlat, + fieldsToRustFlat, + maybeFieldToRust, +} from './conversion-base.js'; +import { Field, Oracles, RandomOracles, ScalarChallenge } from './kimchi-types.js'; + +export { napiOraclesConversion }; + +type napi = typeof napiNamespace; + +type NapiRandomOracles = WasmFpRandomOracles | WasmFqRandomOracles; +type NapiOracles = WasmFpOracles | WasmFqOracles; + +type NapiClasses = { + RandomOracles: typeof WasmFpRandomOracles | typeof WasmFqRandomOracles; + Oracles: typeof WasmFpOracles | typeof WasmFqOracles; +}; + +function napiOraclesConversion(napi: napi) { + return { + fp: oraclesConversionPerField({ + RandomOracles: napi.WasmFpRandomOracles, + Oracles: napi.WasmFpOracles, + }), + fq: oraclesConversionPerField({ + RandomOracles: napi.WasmFqRandomOracles, + Oracles: napi.WasmFqOracles, + }), + }; +} + +function oraclesConversionPerField({ RandomOracles, Oracles }: NapiClasses) { + function randomOraclesToRust(ro: RandomOracles): NapiRandomOracles { + let jointCombinerMl = MlOption.from(ro[1]); + let jointCombinerChal = maybeFieldToRust(jointCombinerMl?.[1][1]); + let jointCombiner = maybeFieldToRust(jointCombinerMl?.[2]); + let beta = fieldToRust(ro[2]); + let gamma = fieldToRust(ro[3]); + let alphaChal = fieldToRust(ro[4][1]); + let alpha = fieldToRust(ro[5]); + let zeta = fieldToRust(ro[6]); + let v = fieldToRust(ro[7]); + let u = fieldToRust(ro[8]); + let zetaChal = 
fieldToRust(ro[9][1]); + let vChal = fieldToRust(ro[10][1]); + let uChal = fieldToRust(ro[11][1]); + return new RandomOracles( + jointCombinerChal, + jointCombiner, + beta, + gamma, + alphaChal, + alpha, + zeta, + v, + u, + zetaChal, + vChal, + uChal + ); + } + function randomOraclesFromRust(ro: NapiRandomOracles): RandomOracles { + let jointCombinerChal = ro.joint_combiner_chal; + let jointCombiner = ro.joint_combiner; + let jointCombinerOption = MlOption<[0, ScalarChallenge, Field]>( + jointCombinerChal && + jointCombiner && [0, [0, fieldFromRust(jointCombinerChal)], fieldFromRust(jointCombiner)] + ); + let mlRo: RandomOracles = [ + 0, + jointCombinerOption, + fieldFromRust(ro.beta), + fieldFromRust(ro.gamma), + [0, fieldFromRust(ro.alpha_chal)], + fieldFromRust(ro.alpha), + fieldFromRust(ro.zeta), + fieldFromRust(ro.v), + fieldFromRust(ro.u), + [0, fieldFromRust(ro.zeta_chal)], + [0, fieldFromRust(ro.v_chal)], + [0, fieldFromRust(ro.u_chal)], + ]; + // TODO: do we not want to free? + // ro.free(); + return mlRo; + } + + return { + oraclesToRust(oracles: Oracles): NapiOracles { + let [, o, pEval, openingPrechallenges, digestBeforeEvaluations] = oracles; + return new Oracles( + randomOraclesToRust(o), + fieldToRust(pEval[1]), + fieldToRust(pEval[2]), + fieldsToRustFlat(openingPrechallenges), + fieldToRust(digestBeforeEvaluations) + ); + }, + oraclesFromRust(oracles: NapiOracles): Oracles { + let mlOracles: Oracles = [ + 0, + randomOraclesFromRust(oracles.o), + [0, fieldFromRust(oracles.p_eval0), fieldFromRust(oracles.p_eval1)], + fieldsFromRustFlat(oracles.opening_prechallenges), + fieldFromRust(oracles.digest_before_evaluations), + ]; + // TODO: do we not want to free? + // oracles.free(); + return mlOracles; + }, + }; +} diff --git a/src/bindings/crypto/bindings/srs.ts b/src/bindings/crypto/bindings/srs.ts index cc3d440bac..3c8e9fef4b 100644 --- a/src/bindings/crypto/bindings/srs.ts +++ b/src/bindings/crypto/bindings/srs.ts @@ -104,7 +104,7 @@ function srsPerField(f: 'fp' | 'fq', wasm: Wasm, conversion: RustConversion) { let maybeLagrangeCommitment = (srs: WasmSrs, domain_size: number, i: number) => { try { console.log(3); - console.log('srs', srs); + console.log('srs wasm', srs); let bytes = (wasm as any)[`caml_${f}_srs_to_bytes_external`](srs); console.log('bytes', bytes); let wasmSrs = undefined; diff --git a/src/bindings/crypto/napi-conversion-core.ts b/src/bindings/crypto/napi-conversion-core.ts index 7409d16939..8e5c1e3a55 100644 --- a/src/bindings/crypto/napi-conversion-core.ts +++ b/src/bindings/crypto/napi-conversion-core.ts @@ -6,7 +6,6 @@ import { fieldsFromRustFlat, fieldsToRustFlat, } from './bindings/conversion-base.js'; -import { mapFromUintArray } from './bindings/conversion-core.js'; import { Field, Gate, LookupTable, OrInfinity, PolyComm, Wire } from './bindings/kimchi-types.js'; import { mapTuple } from './bindings/util.js'; @@ -70,13 +69,11 @@ function napiConversionCore(napi: any) { }; return { - fp: { ...fpCore }, + fp: { + ...fpCore, + }, fq: { ...fqCore, - shiftsFromRust: (s: any) => { - let shifts = [s.s0, s.s1, s.s2, s.s3, s.s4, s.s5, s.s6]; - return [0, ...shifts.map(fieldFromRust)]; - }, }, ...shared, }; @@ -156,9 +153,9 @@ function conversionCorePerField({ makeAffine, PolyComm }: NapiClasses) { }; const affineFromRust = (pt: NapiAffine): OrInfinity => { if (pt.infinity) return 0; - console.log('pt', pt); - console.log('pt.x', pt.x); - console.log('pt.y', pt.y); + // console.log('pt', pt); + // console.log('pt.x', pt.x); + // console.log('pt.y', pt.y); const 
xField = fieldFromRust(pt.x); const yField = fieldFromRust(pt.y); @@ -182,22 +179,13 @@ function conversionCorePerField({ makeAffine, PolyComm }: NapiClasses) { return new PolyCommClass(unshifted as unknown, undefined); }; - /* const polyCommFromRust = (polyComm: NapiPolyComm): PolyComm => { - console.log('polyComm', polyComm); + const polyCommFromRust = (polyComm: any): any => { + if (polyComm == null) return undefined; + // console.log('polyComm', polyComm); const rustUnshifted = asArrayLike(polyComm.unshifted, 'polyComm.unshifted'); - console.log('rustUnshifted', rustUnshifted); + // console.log('rustUnshifted', rustUnshifted); const mlUnshifted = rustUnshifted.map(affineFromRust); return [0, [0, ...mlUnshifted]]; - }; */ - const polyCommFromRust = (polyComm: any): any => { - let rustUnshifted = polyComm.unshifted; - console.log('rustUnshifted', rustUnshifted); - let mlUnshifted = mapFromUintArray(rustUnshifted, (ptr) => { - console.log('ptr', ptr); - /* return affineFromRust(wrap(ptr, CommitmentCurve)); - */ - }); - return [0, [0, ...mlUnshifted]]; }; const polyCommsToRust = ([, ...comms]: MlArray): NapiPolyComm[] => diff --git a/src/bindings/crypto/napi-conversion-proof.ts b/src/bindings/crypto/napi-conversion-proof.ts index 830ad61f8b..789aafb6ed 100644 --- a/src/bindings/crypto/napi-conversion-proof.ts +++ b/src/bindings/crypto/napi-conversion-proof.ts @@ -1,35 +1,77 @@ +import { MlArray, MlOption, MlTuple } from '../../lib/ml/base.js'; +import type * as napiNamespace from '../compiled/node_bindings/plonk_wasm.cjs'; import type { - WasmPastaFpLookupTable, + WasmFpLookupCommitments, + WasmFpOpeningProof, + WasmFpProverCommitments, + WasmFpProverProof, WasmFpRuntimeTable, + WasmFqLookupCommitments, + WasmFqOpeningProof, + WasmFqProverCommitments, + WasmFqProverProof, + WasmFqRuntimeTable, + WasmPastaFpLookupTable, WasmPastaFpRuntimeTableCfg, WasmPastaFqLookupTable, - WasmFqRuntimeTable, WasmPastaFqRuntimeTableCfg, WasmVecVecFp, WasmVecVecFq, } from '../compiled/node_bindings/plonk_wasm.cjs'; -import type * as napiNamespace from '../compiled/node_bindings/plonk_wasm.cjs'; +import { + fieldFromRust, + fieldToRust, + fieldsFromRustFlat, + fieldsToRustFlat, +} from './bindings/conversion-base.js'; +import type { Field } from './bindings/field.js'; import type { + LookupCommitments, + LookupTable, + OpeningProof, + OrInfinity, + PointEvaluations, + PolyComm, + ProofEvaluations, + ProofWithPublic, + ProverCommitments, + ProverProof, + RecursionChallenge, RuntimeTable, RuntimeTableCfg, - LookupTable, } from './bindings/kimchi-types.js'; -import { MlArray} from '../../lib/ml/base.js'; -import { - fieldsToRustFlat, -} from './bindings/conversion-base.js'; import { ConversionCore, ConversionCores } from './napi-conversion-core.js'; export { napiProofConversion }; +const fieldToRust_ = (x: Field) => fieldToRust(x); +const proofEvaluationsToRust = mapProofEvaluations(fieldToRust_); +const proofEvaluationsFromRust = mapProofEvaluations(fieldFromRust); +const pointEvalsOptionToRust = mapPointEvalsOption(fieldToRust_); +const pointEvalsOptionFromRust = mapPointEvalsOption(fieldFromRust); + +type NapiProofEvaluations = [ + 0, + MlOption>, + ...RemoveLeadingZero>, +]; + type napi = typeof napiNamespace; +type NapiProverCommitments = WasmFpProverCommitments | WasmFqProverCommitments; +type NapiOpeningProof = WasmFpOpeningProof | WasmFqOpeningProof; +type NapiProverProof = WasmFpProverProof | WasmFqProverProof; +type NapiLookupCommitments = WasmFpLookupCommitments | WasmFqLookupCommitments; type 
NapiRuntimeTable = WasmFpRuntimeTable | WasmFqRuntimeTable; type NapiRuntimeTableCfg = WasmPastaFpRuntimeTableCfg | WasmPastaFqRuntimeTableCfg; type NapiLookupTable = WasmPastaFpLookupTable | WasmPastaFqLookupTable; type NapiClasses = { + ProverCommitments: typeof WasmFpProverCommitments | typeof WasmFqProverCommitments; + OpeningProof: typeof WasmFpOpeningProof | typeof WasmFqOpeningProof; VecVec: typeof WasmVecVecFp | typeof WasmVecVecFq; + ProverProof: typeof WasmFpProverProof | typeof WasmFqProverProof; + LookupCommitments: typeof WasmFpLookupCommitments | typeof WasmFqLookupCommitments; RuntimeTable: typeof WasmFpRuntimeTable | typeof WasmFqRuntimeTable; RuntimeTableCfg: typeof WasmPastaFpRuntimeTableCfg | typeof WasmPastaFqRuntimeTableCfg; LookupTable: typeof WasmPastaFpLookupTable | typeof WasmPastaFqLookupTable; @@ -38,13 +80,21 @@ type NapiClasses = { function napiProofConversion(napi: napi, core: ConversionCores) { return { fp: proofConversionPerField(core.fp, { + ProverCommitments: napi.WasmFpProverCommitments, + OpeningProof: napi.WasmFpOpeningProof, VecVec: napi.WasmVecVecFp, + ProverProof: napi.WasmFpProverProof, + LookupCommitments: napi.WasmFpLookupCommitments, RuntimeTable: napi.WasmFpRuntimeTable, RuntimeTableCfg: napi.WasmPastaFpRuntimeTableCfg, LookupTable: napi.WasmPastaFpLookupTable, }), fq: proofConversionPerField(core.fq, { + ProverCommitments: napi.WasmFqProverCommitments, + OpeningProof: napi.WasmFqOpeningProof, VecVec: napi.WasmVecVecFq, + ProverProof: napi.WasmFqProverProof, + LookupCommitments: napi.WasmFqLookupCommitments, RuntimeTable: napi.WasmFqRuntimeTable, RuntimeTableCfg: napi.WasmPastaFqRuntimeTableCfg, LookupTable: napi.WasmPastaFqLookupTable, @@ -55,14 +105,80 @@ function napiProofConversion(napi: napi, core: ConversionCores) { function proofConversionPerField( core: ConversionCore, { + ProverCommitments, + OpeningProof, VecVec, + ProverProof, + LookupCommitments, RuntimeTable, RuntimeTableCfg, LookupTable, }: NapiClasses ) { + function commitmentsToRust(commitments: ProverCommitments): NapiProverCommitments { + let wComm = core.polyCommsToRust(commitments[1]); + let zComm = core.polyCommToRust(commitments[2]); + let tComm = core.polyCommToRust(commitments[3]); + let lookup = MlOption.mapFrom(commitments[4], lookupCommitmentsToRust); + return new ProverCommitments(wComm as any, zComm as any, tComm as any, lookup as any); + } + function commitmentsFromRust(commitments: NapiProverCommitments): ProverCommitments { + let wComm = core.polyCommsFromRust(commitments.w_comm); + let zComm = core.polyCommFromRust(commitments.z_comm); + let tComm = core.polyCommFromRust(commitments.t_comm); + let lookup = MlOption.mapTo(commitments.lookup, lookupCommitmentsFromRust); + commitments.free(); + return [0, wComm as MlTuple, zComm, tComm, lookup]; + } + + function lookupCommitmentsToRust(lookup: LookupCommitments): NapiLookupCommitments { + let sorted = core.polyCommsToRust(lookup[1]); + let aggreg = core.polyCommToRust(lookup[2]); + let runtime = MlOption.mapFrom(lookup[3], core.polyCommToRust); + return new LookupCommitments(sorted as any, aggreg as any, runtime as any); + } + function lookupCommitmentsFromRust(lookup: NapiLookupCommitments): LookupCommitments { + let sorted = core.polyCommsFromRust(lookup.sorted); + let aggreg = core.polyCommFromRust(lookup.aggreg); + let runtime = MlOption.mapTo(lookup.runtime, core.polyCommFromRust); + lookup.free(); + return [0, sorted, aggreg, runtime]; + } + + function openingProofToRust(proof: OpeningProof): NapiOpeningProof 
{ + let [_, [, ...lr], delta, z1, z2, sg] = proof; + // We pass l and r as separate vectors over the FFI + let l: MlArray = [0]; + let r: MlArray = [0]; + for (let [, li, ri] of lr) { + l.push(li); + r.push(ri); + } + return new OpeningProof( + core.pointsToRust(l) as any, + core.pointsToRust(r) as any, + core.pointToRust(delta), + fieldToRust(z1), + fieldToRust(z2), + core.pointToRust(sg) + ); + } + function openingProofFromRust(proof: any): OpeningProof { + let [, ...l] = core.pointsFromRust(proof.lr_0); + let [, ...r] = core.pointsFromRust(proof.lr_1); + let n = l.length; + if (n !== r.length) throw Error('openingProofFromRust: l and r length mismatch.'); + let lr = l.map<[0, OrInfinity, OrInfinity]>((li, i) => [0, li, r[i]]); + let delta = core.pointFromRust(proof.delta); + let z1 = fieldFromRust(proof.z1); + let z2 = fieldFromRust(proof.z2); + let sg = core.pointFromRust(proof.sg); + proof.free(); + return [0, [0, ...lr], delta, z1, z2, sg]; + } function runtimeTableToRust([, id, data]: RuntimeTable): NapiRuntimeTable { + console.log('runtime table'); return new RuntimeTable(id, core.vectorToRust(data)); } @@ -80,6 +196,63 @@ function proofConversionPerField( } return { + proofToRust([, public_evals, proof]: ProofWithPublic): NapiProverProof { + let commitments = commitmentsToRust(proof[1]); + let openingProof = openingProofToRust(proof[2]); + let [, ...evals] = proofEvaluationsToRust(proof[3]); + let publicEvals = pointEvalsOptionToRust(public_evals); + // TODO typed as `any` in wasm-bindgen, this has the correct type + let evalsActual: NapiProofEvaluations = [0, publicEvals, ...evals]; + + let ftEval1 = fieldToRust(proof[4]); + let public_ = fieldsToRustFlat(proof[5]); + let [, ...prevChallenges] = proof[6]; + let n = prevChallenges.length; + let prevChallengeScalars = new VecVec(n); + let prevChallengeCommsMl: MlArray = [0]; + for (let [, scalars, comms] of prevChallenges) { + prevChallengeScalars.push(fieldsToRustFlat(scalars)); + prevChallengeCommsMl.push(comms); + } + let prevChallengeComms = core.polyCommsToRust(prevChallengeCommsMl); + return new ProverProof( + commitments, + openingProof, + evalsActual, + ftEval1, + public_, + prevChallengeScalars, + prevChallengeComms as any + ); + }, + proofFromRust(wasmProof: NapiProverProof): ProofWithPublic { + let commitments = commitmentsFromRust(wasmProof.commitments); + let openingProof = openingProofFromRust(wasmProof.proof); + // TODO typed as `any` in wasm-bindgen, this is the correct type + let [, wasmPublicEvals, ...wasmEvals]: NapiProofEvaluations = wasmProof.evals; + let publicEvals = pointEvalsOptionFromRust(wasmPublicEvals); + let evals = proofEvaluationsFromRust([0, ...wasmEvals]); + + let ftEval1 = fieldFromRust(wasmProof.ft_eval1); + let public_ = fieldsFromRustFlat(wasmProof.public_); + let prevChallengeScalars = wasmProof.prev_challenges_scalars; + let [, ...prevChallengeComms] = core.polyCommsFromRust(wasmProof.prev_challenges_comms); + let prevChallenges = prevChallengeComms.map((comms, i) => { + let scalars = fieldsFromRustFlat(prevChallengeScalars.get(i)); + return [0, scalars, comms]; + }); + let proof: ProverProof = [ + 0, + commitments, + openingProof, + evals, + ftEval1, + public_, + [0, ...prevChallenges], + ]; + return [0, publicEvals, proof]; + }, + runtimeTablesToRust([, ...tables]: MlArray): NapiRuntimeTable[] { return tables.map(runtimeTableToRust); }, @@ -92,4 +265,87 @@ function proofConversionPerField( return tables.map(lookupTableToRust); }, }; -} \ No newline at end of file +} + +function 
createMapPointEvals(map: (x: Field1) => Field2) { + return (evals: PointEvaluations): PointEvaluations => { + let [, zeta, zeta_omega] = evals; + return [0, MlArray.map(zeta, map), MlArray.map(zeta_omega, map)]; + }; +} + +function mapPointEvalsOption(map: (x: Field1) => Field2) { + return (evals: MlOption>) => + MlOption.map(evals, createMapPointEvals(map)); +} + +function mapProofEvaluations(map: (x: Field1) => Field2) { + const mapPointEvals = createMapPointEvals(map); + + const mapPointEvalsOption = ( + evals: MlOption> + ): MlOption> => MlOption.map(evals, mapPointEvals); + + return function mapProofEvaluations(evals: ProofEvaluations): ProofEvaluations { + let [ + , + w, + z, + s, + coeffs, + genericSelector, + poseidonSelector, + completeAddSelector, + mulSelector, + emulSelector, + endomulScalarSelector, + rangeCheck0Selector, + rangeCheck1Selector, + foreignFieldAddSelector, + foreignFieldMulSelector, + xorSelector, + rotSelector, + lookupAggregation, + lookupTable, + lookupSorted, + runtimeLookupTable, + runtimeLookupTableSelector, + xorLookupSelector, + lookupGateLookupSelector, + rangeCheckLookupSelector, + foreignFieldMulLookupSelector, + ] = evals; + return [ + 0, + MlTuple.map(w, mapPointEvals), + mapPointEvals(z), + MlTuple.map(s, mapPointEvals), + MlTuple.map(coeffs, mapPointEvals), + mapPointEvals(genericSelector), + mapPointEvals(poseidonSelector), + mapPointEvals(completeAddSelector), + mapPointEvals(mulSelector), + mapPointEvals(emulSelector), + mapPointEvals(endomulScalarSelector), + mapPointEvalsOption(rangeCheck0Selector), + mapPointEvalsOption(rangeCheck1Selector), + mapPointEvalsOption(foreignFieldAddSelector), + mapPointEvalsOption(foreignFieldMulSelector), + mapPointEvalsOption(xorSelector), + mapPointEvalsOption(rotSelector), + mapPointEvalsOption(lookupAggregation), + mapPointEvalsOption(lookupTable), + MlArray.map(lookupSorted, mapPointEvalsOption), + mapPointEvalsOption(runtimeLookupTable), + mapPointEvalsOption(runtimeLookupTableSelector), + mapPointEvalsOption(xorLookupSelector), + mapPointEvalsOption(lookupGateLookupSelector), + mapPointEvalsOption(rangeCheckLookupSelector), + mapPointEvalsOption(foreignFieldMulLookupSelector), + ]; + }; +} + +// helper + +type RemoveLeadingZero = T extends [0, ...infer U] ? 
U : never; diff --git a/src/bindings/crypto/napi-conversion-verifier-index.ts b/src/bindings/crypto/napi-conversion-verifier-index.ts new file mode 100644 index 0000000000..443097305d --- /dev/null +++ b/src/bindings/crypto/napi-conversion-verifier-index.ts @@ -0,0 +1,287 @@ +import { MlArray, MlBool, MlOption } from '../../lib/ml/base.js'; +import type * as napiNamespace from '../compiled/node_bindings/plonk_wasm.cjs'; +import type { + WasmFpDomain, + WasmFpLookupSelectors, + WasmFpLookupVerifierIndex, + WasmFpPlonkVerificationEvals, + WasmFpPlonkVerifierIndex, + WasmFpShifts, + WasmFqDomain, + WasmFqLookupSelectors, + WasmFqLookupVerifierIndex, + WasmFqPlonkVerificationEvals, + WasmFqPlonkVerifierIndex, + WasmFqShifts, + LookupInfo as WasmLookupInfo, +} from '../compiled/node_bindings/plonk_wasm.cjs'; +import { fieldFromRust, fieldToRust } from './bindings/conversion-base.js'; +import { + Domain, + Field, + PolyComm, + VerificationEvals, + VerifierIndex, +} from './bindings/kimchi-types.js'; +import { ConversionCore, ConversionCores } from './napi-conversion-core.js'; +import { Lookup, LookupInfo, LookupSelectors } from './bindings/lookup.js'; + +export { napiVerifierIndexConversion }; + +type napi = typeof napiNamespace; + +type NapiDomain = WasmFpDomain | WasmFqDomain; +type NapiVerificationEvals = WasmFpPlonkVerificationEvals | WasmFqPlonkVerificationEvals; +type NapiShifts = WasmFpShifts | WasmFqShifts; +type NapiVerifierIndex = WasmFpPlonkVerifierIndex | WasmFqPlonkVerifierIndex; +type NapiLookupVerifierIndex = WasmFpLookupVerifierIndex | WasmFqLookupVerifierIndex; +type NapiLookupSelector = WasmFpLookupSelectors | WasmFqLookupSelectors; + +type NapiClasses = { + Domain: typeof WasmFpDomain | typeof WasmFqDomain; + VerificationEvals: typeof WasmFpPlonkVerificationEvals | typeof WasmFqPlonkVerificationEvals; + Shifts: typeof WasmFpShifts | typeof WasmFqShifts; + VerifierIndex: typeof WasmFpPlonkVerifierIndex | typeof WasmFqPlonkVerifierIndex; + LookupVerifierIndex: typeof WasmFpLookupVerifierIndex | typeof WasmFqLookupVerifierIndex; + LookupSelector: typeof WasmFpLookupSelectors | typeof WasmFqLookupSelectors; +}; + +function napiVerifierIndexConversion(napi: any, core: ConversionCores) { + return { + fp: verifierIndexConversionPerField(napi, core.fp, { + Domain: napi.WasmFpDomain, + VerificationEvals: napi.WasmFpPlonkVerificationEvals, + Shifts: napi.WasmFpShifts, + VerifierIndex: napi.WasmFpPlonkVerifierIndex, + LookupVerifierIndex: napi.WasmFpLookupVerifierIndex, + LookupSelector: napi.WasmFpLookupSelectors, + }), + fq: verifierIndexConversionPerField(napi, core.fq, { + Domain: napi.WasmFqDomain, + VerificationEvals: napi.WasmFqPlonkVerificationEvals, + Shifts: napi.WasmFqShifts, + VerifierIndex: napi.WasmFqPlonkVerifierIndex, + LookupVerifierIndex: napi.WasmFqLookupVerifierIndex, + LookupSelector: napi.WasmFqLookupSelectors, + }), + }; +} + +function verifierIndexConversionPerField( + napi: any, + core: ConversionCore, + { + Domain, + VerificationEvals, + Shifts, + VerifierIndex, + LookupVerifierIndex, + LookupSelector, + }: NapiClasses +) { + function domainToRust([, logSizeOfGroup, groupGen]: Domain): NapiDomain { + return new Domain(logSizeOfGroup, fieldToRust(groupGen)); + } + function domainFromRust(domain: NapiDomain): Domain { + return [0, domain.log_size_of_group, fieldFromRust(domain.group_gen)]; + } + + function verificationEvalsToRust(evals: VerificationEvals): NapiVerificationEvals { + let sigmaComm = core.polyCommsToRust(evals[1]); + let coefficientsComm = 
core.polyCommsToRust(evals[2]); + let genericComm = core.polyCommToRust(evals[3]); + let psmComm = core.polyCommToRust(evals[4]); + let completeAddComm = core.polyCommToRust(evals[5]); + let mulComm = core.polyCommToRust(evals[6]); + let emulComm = core.polyCommToRust(evals[7]); + let endomulScalarComm = core.polyCommToRust(evals[8]); + let xorComm = MlOption.mapFrom(evals[9], core.polyCommToRust); + let rangeCheck0Comm = MlOption.mapFrom(evals[10], core.polyCommToRust); + let rangeCheck1Comm = MlOption.mapFrom(evals[11], core.polyCommToRust); + let foreignFieldAddComm = MlOption.mapFrom(evals[12], core.polyCommToRust); + let foreignFieldMulComm = MlOption.mapFrom(evals[13], core.polyCommToRust); + let rotComm = MlOption.mapFrom(evals[14], core.polyCommToRust); + return new VerificationEvals( + sigmaComm as any, + coefficientsComm as any, + genericComm as any, + psmComm as any, + completeAddComm as any, + mulComm as any, + emulComm as any, + endomulScalarComm as any, + xorComm as any, + rangeCheck0Comm as any, + rangeCheck1Comm as any, + foreignFieldAddComm as any, + foreignFieldMulComm as any, + rotComm as any + ); + } + function verificationEvalsFromRust(evals: NapiVerificationEvals): VerificationEvals { + let mlEvals: VerificationEvals = [ + 0, + core.polyCommsFromRust(evals.sigma_comm), + core.polyCommsFromRust(evals.coefficients_comm), + core.polyCommFromRust(evals.generic_comm), + core.polyCommFromRust(evals.psm_comm), + core.polyCommFromRust(evals.complete_add_comm), + core.polyCommFromRust(evals.mul_comm), + core.polyCommFromRust(evals.emul_comm), + core.polyCommFromRust(evals.endomul_scalar_comm), + MlOption.mapTo(evals.xor_comm, core.polyCommFromRust), + MlOption.mapTo(evals.range_check0_comm, core.polyCommFromRust), + MlOption.mapTo(evals.range_check1_comm, core.polyCommFromRust), + MlOption.mapTo(evals.foreign_field_add_comm, core.polyCommFromRust), + MlOption.mapTo(evals.foreign_field_mul_comm, core.polyCommFromRust), + MlOption.mapTo(evals.rot_comm, core.polyCommFromRust), + ]; + return mlEvals; + } + + function lookupVerifierIndexToRust(lookup: Lookup): NapiLookupVerifierIndex { + let [ + , + joint_lookup_used, + lookup_table, + selectors, + table_ids, + lookup_info, + runtime_tables_selector, + ] = lookup; + return new LookupVerifierIndex( + MlBool.from(joint_lookup_used), + core.polyCommsToRust(lookup_table) as any, + lookupSelectorsToRust(selectors), + MlOption.mapFrom(table_ids, core.polyCommToRust) as any, + lookupInfoToRust(lookup_info), + MlOption.mapFrom(runtime_tables_selector, core.polyCommToRust) as any + ); + } + function lookupVerifierIndexFromRust(lookup: NapiLookupVerifierIndex): Lookup { + console.log('lookup: ', lookup); + let mlLookup: Lookup = [ + 0, + MlBool(lookup.joint_lookup_used), + core.polyCommsFromRust(lookup.lookup_table), + lookupSelectorsFromRust(lookup.lookup_selectors), + MlOption.mapTo(lookup.table_ids, core.polyCommFromRust), + lookupInfoFromRust(lookup.lookup_info), + MlOption.mapTo(lookup.runtime_tables_selector, core.polyCommFromRust), + ]; + return mlLookup; + } + + function lookupSelectorsToRust([ + , + lookup, + xor, + range_check, + ffmul, + ]: LookupSelectors): NapiLookupSelector { + return new LookupSelector( + MlOption.mapFrom(xor, core.polyCommToRust) as any, + MlOption.mapFrom(lookup, core.polyCommToRust) as any, + MlOption.mapFrom(range_check, core.polyCommToRust) as any, + MlOption.mapFrom(ffmul, core.polyCommToRust) as any + ); + } + function lookupSelectorsFromRust(selector: NapiLookupSelector): LookupSelectors { + let lookup 
= MlOption.mapTo(selector.lookup, core.polyCommFromRust); + let xor = MlOption.mapTo(selector.xor, core.polyCommFromRust); + let range_check = MlOption.mapTo(selector.range_check, core.polyCommFromRust); + let ffmul = MlOption.mapTo(selector.ffmul, core.polyCommFromRust); + return [0, lookup, xor, range_check, ffmul]; + } + + function lookupInfoToRust([, maxPerRow, maxJointSize, features]: LookupInfo): WasmLookupInfo { + let [, patterns, joint_lookup_used, uses_runtime_tables] = features; + let [, xor, lookup, range_check, foreign_field_mul] = patterns; + let napiPatterns = new napi.LookupPatterns( + MlBool.from(xor), + MlBool.from(lookup), + MlBool.from(range_check), + MlBool.from(foreign_field_mul) + ); + let napiFeatures = new napi.LookupFeatures( + napiPatterns, + MlBool.from(joint_lookup_used), + MlBool.from(uses_runtime_tables) + ); + return new napi.LookupInfo(maxPerRow, maxJointSize, napiFeatures); + } + function lookupInfoFromRust(info: WasmLookupInfo): LookupInfo { + let features = info.features; + let patterns = features.patterns; + let mlInfo: LookupInfo = [ + 0, + info.max_per_row, + info.max_joint_size, + [ + 0, + [ + 0, + MlBool(patterns.xor), + MlBool(patterns.lookup), + MlBool(patterns.range_check), + MlBool(patterns.foreign_field_mul), + ], + MlBool(features.joint_lookup_used), + MlBool(features.uses_runtime_tables), + ], + ]; + return mlInfo; + } + + let self = { + shiftsToRust([, ...shifts]: MlArray): NapiShifts { + let s = shifts.map((s) => fieldToRust(s)); + return new Shifts(s[0], s[1], s[2], s[3], s[4], s[5], s[6]); + }, + shiftsFromRust(s: NapiShifts): MlArray { + let shifts = [s.s0, s.s1, s.s2, s.s3, s.s4, s.s5, s.s6]; + return [0, ...shifts.map(fieldFromRust)]; + }, + + verifierIndexToRust(vk: VerifierIndex): NapiVerifierIndex { + let domain = domainToRust(vk[1]); + let maxPolySize = vk[2]; + let nPublic = vk[3]; + let prevChallenges = vk[4]; + let srs = vk[5]; + let evals = verificationEvalsToRust(vk[6]); + let shifts = self.shiftsToRust(vk[7]); + let lookupIndex = MlOption.mapFrom(vk[8], lookupVerifierIndexToRust); + let zkRows = vk[9]; + return new VerifierIndex( + domain, + maxPolySize, + nPublic, + prevChallenges, + srs, + evals, + shifts, + lookupIndex, + zkRows + ); + }, + verifierIndexFromRust(vk: NapiVerifierIndex): VerifierIndex { + console.log('vk lookup index from rust', vk.lookup_index); + let mlVk: VerifierIndex = [ + 0, + domainFromRust(vk.domain), + vk.max_poly_size, + vk.public_, + vk.prev_challenges, + vk.srs, + verificationEvalsFromRust(vk.evals), + self.shiftsFromRust(vk.shifts), + MlOption.mapTo(vk.lookup_index, lookupVerifierIndexFromRust), + vk.zk_rows, + ]; + return mlVk; + }, + }; + + return self; +} diff --git a/src/bindings/crypto/napi-srs.ts b/src/bindings/crypto/napi-srs.ts new file mode 100644 index 0000000000..cbbae328aa --- /dev/null +++ b/src/bindings/crypto/napi-srs.ts @@ -0,0 +1,358 @@ +import { MlArray } from '../../lib/ml/base.js'; +import { + readCache, + withVersion, + writeCache, + type Cache, + type CacheHeader, +} from '../../lib/proof-system/cache.js'; +import { assert } from '../../lib/util/errors.js'; +import { type WasmFpSrs, type WasmFqSrs } from '../compiled/node_bindings/plonk_wasm.cjs'; +import type { Napi, RustConversion } from './bindings.js'; +import { OrInfinity, OrInfinityJson } from './bindings/curve.js'; +import { PolyComm } from './bindings/kimchi-types.js'; + +export { setSrsCache, srs, unsetSrsCache }; + +type NapiSrs = WasmFpSrs | WasmFqSrs; + +type SrsStore = Record; + +function empty(): SrsStore { + 
return {}; +} + +const srsStore = { fp: empty(), fq: empty() }; + +const CacheReadRegister = new Map(); + +let cache: Cache | undefined; + +function setSrsCache(c: Cache) { + cache = c; +} +function unsetSrsCache() { + cache = undefined; +} + +const srsVersion = 1; + +function cacheHeaderLagrange(f: 'fp' | 'fq', domainSize: number): CacheHeader { + let id = `lagrange-basis-${f}-${domainSize}`; + return withVersion( + { + kind: 'lagrange-basis', + persistentId: id, + uniqueId: id, + dataType: 'string', + }, + srsVersion + ); +} +function cacheHeaderSrs(f: 'fp' | 'fq', domainSize: number): CacheHeader { + let id = `srs-${f}-${domainSize}`; + return withVersion( + { + kind: 'srs', + persistentId: id, + uniqueId: id, + dataType: 'string', + }, + srsVersion + ); +} + +function srs(napi: Napi, conversion: RustConversion) { + return { + fp: srsPerField('fp', napi, conversion), + fq: srsPerField('fq', napi, conversion), + }; +} + +function srsPerField(f: 'fp' | 'fq', napi: Napi, conversion: RustConversion) { + // note: these functions are properly typed, thanks to TS template literal types + let createSrs = (size: number) => { + try { + console.log(0); + return napi[`caml_${f}_srs_create_parallel`](size); + } catch (error) { + console.error(`Error in SRS get for field ${f}`); + throw error; + } + }; + let getSrs = (srs: NapiSrs) => { + try { + console.log(1); + let v = napi[`caml_${f}_srs_get`](srs); + console.log(2); + return v; + } catch (error) { + console.error(`Error in SRS get for field ${f}`); + throw error; + } + }; + let setSrs = (bytes: any) => { + try { + console.log(2); + return napi[`caml_${f}_srs_set`](bytes); + } catch (error) { + console.error(`Error in SRS set for field ${f} args ${bytes}`); + throw error; + } + }; + + let maybeLagrangeCommitment = (srs: NapiSrs, domain_size: number, i: number) => { + try { + console.log(3); + console.log('srs napi', srs); + /*let bytes = (napi as any)[`caml_${f}_srs_to_bytes`](srs); + console.log('bytes', bytes); + let wasmSrs = undefined; + if (f === 'fp') wasmSrs = (napi as any)[`caml_${f}_srs_from_bytes`](bytes); + else wasmSrs = (napi as any)[`caml_fq_srs_from_bytes`](bytes); + */ + let s = napi[`caml_${f}_srs_maybe_lagrange_commitment`](srs, domain_size, i); + console.log('S', s); + return s; + } catch (error) { + console.error(`Error in SRS maybe lagrange commitment for field ${f}`); + throw error; + } + }; + let lagrangeCommitment = (srs: NapiSrs, domain_size: number, i: number) => { + try { + return napi[`caml_${f}_srs_lagrange_commitment`](srs, domain_size, i); + } catch (error) { + console.error(`Error in SRS lagrange commitment for field ${f}`); + throw error; + } + }; + let setLagrangeBasis = (srs: NapiSrs, domain_size: number, input: any) => { + try { + console.log(6); + return napi[`caml_${f}_srs_set_lagrange_basis`](srs, domain_size, input); + } catch (error) { + console.error(`Error in SRS set lagrange basis for field ${f}`); + throw error; + } + }; + let getLagrangeBasis = (srs: NapiSrs, n: number) => { + try { + return napi[`caml_${f}_srs_get_lagrange_basis`](srs, n); + } catch (error) { + console.error(`Error in SRS get lagrange basis for field ${f}`); + throw error; + } + }; + return { + /** + * returns existing stored SRS or falls back to creating a new one + */ + create(size: number): NapiSrs { + let srs = srsStore[f][size] satisfies NapiSrs as NapiSrs | undefined; + + if (srs === undefined) { + if (cache === undefined) { + // if there is no cache, create SRS in memory + console.log('Creating SRS without cache'); + srs = 
createSrs(size); + console.log('SRS created without cache:', srs); + } else { + let header = cacheHeaderSrs(f, size); + + // try to read SRS from cache / recompute and write if not found + console.log('Reading SRS from cache'); + srs = readCache(cache, header, (bytes) => { + // TODO: this takes a bit too long, about 300ms for 2^16 + // `pointsToRust` is the clear bottleneck + let jsonSrs: OrInfinityJson[] = JSON.parse(new TextDecoder().decode(bytes)); + let mlSrs = MlArray.mapTo(jsonSrs, OrInfinity.fromJSON); + let wasmSrs = conversion[f].pointsToRust(mlSrs); + return setSrs(wasmSrs); + }); + console.log('SRS read from cache:', srs); + if (srs === undefined) { + // not in cache + console.log(1); + srs = createSrs(size); + console.log('Writing SRS to cache', srs); + + if (cache.canWrite) { + console.log(2); + let wasmSrs = getSrs(srs); + console.log(3); + let mlSrs = conversion[f].pointsFromRust(wasmSrs); + let jsonSrs = MlArray.mapFrom(mlSrs, OrInfinity.toJSON); + let bytes = new TextEncoder().encode(JSON.stringify(jsonSrs)); + + writeCache(cache, header, bytes); + } + } + } + console.log('Storing SRS in memory'); + srsStore[f][size] = srs; + console.log('SRS stored in memory:', srs); + } + + // TODO should we call freeOnFinalize() and expose a function to clean the SRS cache? + console.trace('Returning SRS:', srs); + return srsStore[f][size]; + }, + + /** + * returns ith Lagrange basis commitment for a given domain size + */ + lagrangeCommitment(srs: NapiSrs, domainSize: number, i: number): PolyComm { + console.log('lagrangeCommitment'); + // happy, fast case: if basis is already stored on the srs, return the ith commitment + let commitment = maybeLagrangeCommitment(srs, domainSize, i); + + if (commitment === undefined || commitment === null) { + console.log('comm was undefined'); + if (cache === undefined) { + // if there is no cache, recompute and store basis in memory + commitment = lagrangeCommitment(srs, domainSize, i); + } else { + // try to read lagrange basis from cache / recompute and write if not found + let header = cacheHeaderLagrange(f, domainSize); + let didRead = readCacheLazy( + cache, + header, + conversion, + f, + srs, + domainSize, + setLagrangeBasis + ); + if (didRead !== true) { + // not in cache + if (cache.canWrite) { + // TODO: this code path will throw on the web since `caml_${f}_srs_get_lagrange_basis` is not properly implemented + // using a writable cache in the browser seems to be fairly uncommon though, so it's at least an 80/20 solution + let napiComms = getLagrangeBasis(srs, domainSize); + console.log('napiComms', napiComms); + let mlComms = conversion[f].polyCommsFromRust(napiComms); + console.log('mlComms', mlComms); + let comms = polyCommsToJSON(mlComms); + let bytes = new TextEncoder().encode(JSON.stringify(comms)); + writeCache(cache, header, bytes); + } else { + lagrangeCommitment(srs, domainSize, i); + } + } + // here, basis is definitely stored on the srs + let c = maybeLagrangeCommitment(srs, domainSize, i); + assert(c !== undefined, 'commitment exists after setting'); + commitment = c; + } + } + console.log('commitment was not undefined'); + + // edge case for when we have a writeable cache and the basis was already stored on the srs + // but we didn't store it in the cache separately yet + if (commitment && cache && cache.canWrite) { + let header = cacheHeaderLagrange(f, domainSize); + let didRead = readCacheLazy( + cache, + header, + conversion, + f, + srs, + domainSize, + setLagrangeBasis + ); + // only proceed for entries we haven't written to 
the cache yet + if (didRead !== true) { + // same code as above - write the lagrange basis to the cache if it wasn't there already + // currently we re-generate the basis via `getLagrangeBasis` - we could derive this from the + // already existing `commitment` instead, but this is simpler and the performance impact is negligible + let napiComms = getLagrangeBasis(srs, domainSize); + let mlComms = conversion[f].polyCommsFromRust(napiComms); + let comms = polyCommsToJSON(mlComms); + let bytes = new TextEncoder().encode(JSON.stringify(comms)); + + writeCache(cache, header, bytes); + } + } + return conversion[f].polyCommFromRust(commitment); + }, + + /** + * Returns the Lagrange basis commitments for the whole domain + */ + lagrangeCommitmentsWholeDomain(srs: NapiSrs, domainSize: number) { + console.log('lagrangeCommitmentsWholeDomain'); + try { + let napiComms = napi[`caml_${f}_srs_lagrange_commitments_whole_domain_ptr`]( + srs, + domainSize + ); + let mlComms = conversion[f].polyCommsFromRust(napiComms as any); + return mlComms; + } catch (error) { + console.error(`Error in SRS lagrange commitments whole domain ptr for field ${f}`); + throw error; + } + }, + + /** + * adds Lagrange basis for a given domain size + */ + addLagrangeBasis(srs: NapiSrs, logSize: number) { + console.log('addLagrangeBasis'); + // this ensures that basis is stored on the srs, no need to duplicate caching logic + this.lagrangeCommitment(srs, 1 << logSize, 0); + }, + }; +} + +type PolyCommJson = { + shifted: OrInfinityJson[]; + unshifted: OrInfinityJson | undefined; +}; + +function polyCommsToJSON(comms: MlArray): PolyCommJson[] { + return MlArray.mapFrom(comms, ([, elems]) => { + return { + shifted: MlArray.mapFrom(elems, OrInfinity.toJSON), + unshifted: undefined, + }; + }); +} + +function polyCommsFromJSON(json: PolyCommJson[]): MlArray { + return MlArray.mapTo(json, ({ shifted, unshifted }) => { + return [0, MlArray.mapTo(shifted, OrInfinity.fromJSON)]; + }); +} + +function readCacheLazy( + cache: Cache, + header: CacheHeader, + conversion: RustConversion, + f: 'fp' | 'fq', + srs: NapiSrs, + domainSize: number, + setLagrangeBasis: (srs: NapiSrs, domainSize: number, comms: Uint32Array) => void +) { + if (CacheReadRegister.get(header.uniqueId) === true) return true; + return readCache(cache, header, (bytes) => { + let comms: PolyCommJson[] = JSON.parse(new TextDecoder().decode(bytes)); + let mlComms = polyCommsFromJSON(comms); + let napiComms = conversion[f].polyCommsToRust(mlComms); + + setLagrangeBasis(srs, domainSize, napiComms); + CacheReadRegister.set(header.uniqueId, true); + return true; + }); +} +function runInTryCatch any>(fn: T): T { + return function (...args: Parameters): ReturnType { + try { + return fn(...args); + } catch (e) { + console.error(`Error in SRS function ${fn.name} with args:`, args); + throw e; + } + } as T; +} diff --git a/src/lib/proof-system/prover-keys.ts b/src/lib/proof-system/prover-keys.ts index e9fd8de7d2..122d8bf27b 100644 --- a/src/lib/proof-system/prover-keys.ts +++ b/src/lib/proof-system/prover-keys.ts @@ -10,7 +10,8 @@ import { WasmPastaFpPlonkIndex, WasmPastaFqPlonkIndex, } from '../../bindings/compiled/node_bindings/plonk_wasm.cjs'; -import { getRustConversion } from '../../bindings/crypto/bindings.js'; +// TODO: include conversion bundle to decide between wasm and napi conversion +import { createNativeRustConversion } from '../../bindings/crypto/bindings.js'; import { VerifierIndex } from '../../bindings/crypto/bindings/kimchi-types.js'; import { MlString } from 
'../ml/base.js';
 import { CacheHeader, cacheHeaderVersion } from './cache.js';
@@ -95,13 +96,13 @@ function encodeProverKey(value: SnarkKey): Uint8Array {
     case KeyType.StepProvingKey: {
       let index = value[1][1];
       let encoded = wasm.caml_pasta_fp_plonk_index_encode(
-        (wasm as any).prover_index_fp_from_bytes(index.serialize())
+        (wasm as any).prover_index_fp_deserialize((wasm as any).prover_index_fp_serialize(index))
       );
       return encoded;
     }
     case KeyType.StepVerificationKey: {
       let vkMl = value[1];
-      const rustConversion = getRustConversion(wasm);
+      const rustConversion = createNativeRustConversion(wasm);
       let vkWasm = rustConversion.fp.verifierIndexToRust(vkMl);
       let string = wasm.caml_pasta_fp_plonk_verifier_index_serialize(vkWasm);
       return new TextEncoder().encode(string);
@@ -109,7 +110,7 @@ function encodeProverKey(value: SnarkKey): Uint8Array {
     case KeyType.WrapProvingKey: {
       let index = value[1][1];
       let encoded = wasm.caml_pasta_fq_plonk_index_encode(
-        (wasm as any).prover_index_fq_from_bytes(index.serialize())
+        (wasm as any).prover_index_fq_deserialize((wasm as any).prover_index_fq_serialize(index))
      );
       return encoded;
     }
@@ -139,7 +140,7 @@ function decodeProverKey(header: SnarkKeyHeader, bytes: Uint8Array): SnarkKey {
       let srs = Pickles.loadSrsFp();
       let string = new TextDecoder().decode(bytes);
       let vkWasm = wasm.caml_pasta_fp_plonk_verifier_index_deserialize(srs, string);
-      const rustConversion = getRustConversion(wasm);
+      const rustConversion = createNativeRustConversion(wasm);
       let vkMl = rustConversion.fp.verifierIndexFromRust(vkWasm);
       return [KeyType.StepVerificationKey, vkMl];
     }
diff --git a/src/mina b/src/mina
index 857e2a8402..ebfe5d874f 160000
--- a/src/mina
+++ b/src/mina
@@ -1 +1 @@
-Subproject commit 857e2a8402e67d4a3931bdd60043123890046df7
+Subproject commit ebfe5d874f09b5562d40e9eed6c647310cf9d2f1
diff --git a/tests/native/native.ts b/tests/native/native.ts
index 9216ddbb00..1b6da18c17 100644
--- a/tests/native/native.ts
+++ b/tests/native/native.ts
@@ -1,6 +1,8 @@
 import assert from 'node:assert';
 import native from '../../src/native/native';
 
+// run with `./run tests/native/native.ts --bundle`
+
 console.log(native);
 
 assert(native.getNativeCalls() == 0n, 'native module starts with no calls');