diff --git a/src/bindings/crypto/native/bigint256.ts b/src/bindings/crypto/native/bigint256.ts new file mode 100644 index 000000000..e266a06b9 --- /dev/null +++ b/src/bindings/crypto/native/bigint256.ts @@ -0,0 +1,177 @@ +import { MlBool } from '../../../lib/ml/base.js'; +import { withPrefix } from './util.js'; + +/** + * TS implementation of Pasta_bindings.BigInt256 + */ +export { + Bigint256Bindings, + Bigint256, + toMlStringAscii, + fromMlString, + MlBytes, + mlBytesFromUint8Array, + mlBytesToUint8Array, +}; + +type Bigint256 = [0, bigint]; + +const Bigint256Bindings = withPrefix('caml_bigint_256', { + // TODO + of_numeral(s: MlBytes, i: number, j: number): Bigint256 { + throw Error('caml_bigint_256_of_numeral not implemented'); + }, + of_decimal_string(s: MlBytes): Bigint256 { + return [0, BigInt(fromMlString(s))]; + }, + num_limbs(): number { + return 4; + }, + bytes_per_limb(): number { + return 8; + }, + div([, x]: Bigint256, [, y]: Bigint256): Bigint256 { + return [0, x / y]; + }, + compare([, x]: Bigint256, [, y]: Bigint256): number { + if (x < y) return -1; + if (x === y) return 0; + return 1; + }, + print([, x]: Bigint256): void { + console.log(x.toString()); + }, + to_string(x: Bigint256) { + return toMlStringAscii(x[1].toString()); + }, + // TODO performance critical + test_bit(b: Bigint256, i: number): MlBool { + return MlBool(!!(b[1] & (1n << BigInt(i)))); + }, + to_bytes([, x]: Bigint256) { + let ocamlBytes = caml_create_bytes(32); + for (let i = 0; i < 32; i++) { + let byte = Number(x & 0xffn); + caml_bytes_unsafe_set(ocamlBytes, i, byte); + x >>= 8n; + } + if (x !== 0n) throw Error("bigint256 doesn't fit into 32 bytes."); + return ocamlBytes; + }, + of_bytes(ocamlBytes: MlBytes): Bigint256 { + let length = ocamlBytes.l; + if (length > 32) throw Error(length + " bytes don't fit into bigint256"); + let x = 0n; + let bitPosition = 0n; + for (let i = 0; i < length; i++) { + let byte = caml_bytes_unsafe_get(ocamlBytes, i); + x |= BigInt(byte) << bitPosition; + bitPosition += 8n; + } + return [0, x]; + }, + deep_copy([, x]: Bigint256): Bigint256 { + return [0, x]; + }, +}); + +// TODO clean up all this / make type-safe and match JSOO in all relevant cases + +function fromMlString(s: MlBytes) { + // TODO doesn't handle all cases + return s.c; +} +function toMlStringAscii(s: string) { + return new MlBytes(9, s, s.length); +} + +function caml_bytes_unsafe_get(s: MlBytes, i: number): number { + switch (s.t & 6) { + default: /* PARTIAL */ + if (i >= s.c.length) return 0; + case 0 /* BYTES */: + return s.c.charCodeAt(i); + case 4 /* ARRAY */: + return s.c[i] as any as number; + } +} + +function caml_bytes_unsafe_set(s: MlBytes, i: number, c: number) { + // The OCaml compiler uses Char.unsafe_chr on integers larger than 255! 
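+ // so the value is reduced to a single byte here, mirroring the jsoo runtime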
+ c &= 0xff; + if (s.t != 4 /* ARRAY */) { + if (i == s.c.length) { + s.c += String.fromCharCode(c); + if (i + 1 == s.l) s.t = 0; /*BYTES | UNKNOWN*/ + return 0; + } + caml_convert_bytes_to_array(s); + } + // TODO + (s.c as any)[i] = c; + return 0; +} + +function caml_create_bytes(len: number) { + return new MlBytes(2, '', len); +} + +function caml_convert_bytes_to_array(s: MlBytes) { + /* Assumes not ARRAY */ + let a = new Uint8Array(s.l); + let b = s.c, + l = b.length, + i = 0; + for (; i < l; i++) a[i] = b.charCodeAt(i); + for (l = s.l; i < l; i++) a[i] = 0; + (s as any).c = a; + // TODO + s.t = 4; /* ARRAY */ + return a; +} + +function mlBytesFromUint8Array(uint8array: Uint8Array | number[]) { + let length = uint8array.length; + let ocaml_bytes = caml_create_bytes(length); + for (let i = 0; i < length; i++) { + // No need to convert here: OCaml Char.t is just an int under the hood. + caml_bytes_unsafe_set(ocaml_bytes, i, uint8array[i]); + } + return ocaml_bytes; +} + +function mlBytesToUint8Array(ocaml_bytes: MlBytes) { + let length = ocaml_bytes.l; + let bytes = new Uint8Array(length); + for (let i = 0; i < length; i++) { + // No need to convert here: OCaml Char.t is just an int under the hood. + bytes[i] = caml_bytes_unsafe_get(ocaml_bytes, i); + } + return bytes; +} + +class MlBytes { + t: number; + c: string; + l: number; + + constructor(tag: number, content: string, length: number) { + this.t = tag; + this.c = content; + this.l = length; + } + + toString() { + if (this.t === 9) return this.c; + throw Error('todo'); + } + + toUtf16() { + return this.toString(); + } + + slice() { + let content = this.t == 4 ? this.c.slice() : this.c; + return new MlBytes(this.t, content, this.l); + } +} diff --git a/src/bindings/crypto/native/conversion-base.ts b/src/bindings/crypto/native/conversion-base.ts new file mode 100644 index 000000000..9f67829df --- /dev/null +++ b/src/bindings/crypto/native/conversion-base.ts @@ -0,0 +1,97 @@ +import { Field } from './field.js'; +import { bigintToBytes32, bytesToBigint32 } from '../bigint-helpers.js'; +import type { + WasmGPallas, + WasmGVesta, + WasmPallasGProjective, + WasmVestaGProjective, +} from '../../compiled/node_bindings/plonk_wasm.cjs'; +import type { MlArray } from '../../../lib/ml/base.js'; +import { OrInfinity, Infinity } from './curve.js'; + +export { + fieldToRust, + fieldFromRust, + fieldsToRustFlat, + fieldsFromRustFlat, + maybeFieldToRust, + affineToRust, + affineFromRust, + WasmAffine, + WasmProjective, +}; + +// TODO: Hardcoding this is a little brittle +// TODO read from field +const fieldSizeBytes = 32; + +// field, field vectors + +function fieldToRust([, x]: Field, dest = new Uint8Array(32)): Uint8Array { + return bigintToBytes32(x, dest); +} +function fieldFromRust(x: Uint8Array): Field { + return [0, bytesToBigint32(x)]; +} + +function fieldsToRustFlat([, ...fields]: MlArray): Uint8Array { + let n = fields.length; + let flatBytes = new Uint8Array(n * fieldSizeBytes); + for (let i = 0, offset = 0; i < n; i++, offset += fieldSizeBytes) { + fieldToRust(fields[i], flatBytes.subarray(offset, offset + fieldSizeBytes)); + } + return flatBytes; +} + +function fieldsFromRustFlat(fieldBytes: Uint8Array): MlArray { + let n = fieldBytes.length / fieldSizeBytes; + if (!Number.isInteger(n)) { + throw Error('fieldsFromRustFlat: invalid bytes'); + } + let fields: Field[] = Array(n); + for (let i = 0, offset = 0; i < n; i++, offset += fieldSizeBytes) { + let fieldView = new Uint8Array(fieldBytes.buffer, offset, fieldSizeBytes); + fields[i] = 
fieldFromRust(fieldView); + } + return [0, ...fields]; +} + +function maybeFieldToRust(x?: Field): Uint8Array | undefined { + return x && fieldToRust(x); +} + +// affine + +type WasmAffine = WasmGVesta | WasmGPallas; + +function affineFromRust(pt: A): OrInfinity { + if (pt.infinity) { + pt.free(); + return 0; + } else { + let x = fieldFromRust(pt.x); + let y = fieldFromRust(pt.y); + pt.free(); + return [0, [0, x, y]]; + } +} + +const tmpBytes = new Uint8Array(32); + +function affineToRust(pt: OrInfinity, makeAffine: () => A) { + let res = makeAffine(); + if (pt === Infinity) { + res.infinity = true; + } else { + let [, [, x, y]] = pt; + // we can use the same bytes here every time, + // because x and y setters copy the bytes into wasm memory + res.x = fieldToRust(x, tmpBytes); + res.y = fieldToRust(y, tmpBytes); + } + return res; +} + +// projective + +type WasmProjective = WasmVestaGProjective | WasmPallasGProjective; diff --git a/src/bindings/crypto/native/conversion-oracles.ts b/src/bindings/crypto/native/conversion-oracles.ts new file mode 100644 index 000000000..1094c96c2 --- /dev/null +++ b/src/bindings/crypto/native/conversion-oracles.ts @@ -0,0 +1,123 @@ +import type { + WasmFpOracles, + WasmFpRandomOracles, + WasmFqOracles, + WasmFqRandomOracles, +} from '../../compiled/node_bindings/plonk_wasm.cjs'; +import type * as wasmNamespace from '../../compiled/node_bindings/plonk_wasm.cjs'; +import { MlOption } from '../../../lib/ml/base.js'; +import { Field, Oracles, RandomOracles, ScalarChallenge } from './kimchi-types.js'; +import { + fieldFromRust, + fieldToRust, + fieldsFromRustFlat, + fieldsToRustFlat, + maybeFieldToRust, +} from './conversion-base.js'; + +export { oraclesConversion }; + +type wasm = typeof wasmNamespace; + +type WasmRandomOracles = WasmFpRandomOracles | WasmFqRandomOracles; +type WasmOracles = WasmFpOracles | WasmFqOracles; + +type WasmClasses = { + RandomOracles: typeof WasmFpRandomOracles | typeof WasmFqRandomOracles; + Oracles: typeof WasmFpOracles | typeof WasmFqOracles; +}; + +function oraclesConversion(wasm: wasm) { + return { + fp: oraclesConversionPerField({ + RandomOracles: wasm.WasmFpRandomOracles, + Oracles: wasm.WasmFpOracles, + }), + fq: oraclesConversionPerField({ + RandomOracles: wasm.WasmFqRandomOracles, + Oracles: wasm.WasmFqOracles, + }), + }; +} + +function oraclesConversionPerField({ RandomOracles, Oracles }: WasmClasses) { + function randomOraclesToRust(ro: RandomOracles): WasmRandomOracles { + let jointCombinerMl = MlOption.from(ro[1]); + let jointCombinerChal = maybeFieldToRust(jointCombinerMl?.[1][1]); + let jointCombiner = maybeFieldToRust(jointCombinerMl?.[2]); + let beta = fieldToRust(ro[2]); + let gamma = fieldToRust(ro[3]); + let alphaChal = fieldToRust(ro[4][1]); + let alpha = fieldToRust(ro[5]); + let zeta = fieldToRust(ro[6]); + let v = fieldToRust(ro[7]); + let u = fieldToRust(ro[8]); + let zetaChal = fieldToRust(ro[9][1]); + let vChal = fieldToRust(ro[10][1]); + let uChal = fieldToRust(ro[11][1]); + return new RandomOracles( + jointCombinerChal, + jointCombiner, + beta, + gamma, + alphaChal, + alpha, + zeta, + v, + u, + zetaChal, + vChal, + uChal + ); + } + function randomOraclesFromRust(ro: WasmRandomOracles): RandomOracles { + let jointCombinerChal = ro.joint_combiner_chal; + let jointCombiner = ro.joint_combiner; + let jointCombinerOption = MlOption<[0, ScalarChallenge, Field]>( + jointCombinerChal && + jointCombiner && [0, [0, fieldFromRust(jointCombinerChal)], fieldFromRust(jointCombiner)] + ); + let mlRo: RandomOracles = [ + 
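+ // the leading 0 is the OCaml block tag; the remaining entries follow the RandomOracles layout in kimchi-types.ts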
0, + jointCombinerOption, + fieldFromRust(ro.beta), + fieldFromRust(ro.gamma), + [0, fieldFromRust(ro.alpha_chal)], + fieldFromRust(ro.alpha), + fieldFromRust(ro.zeta), + fieldFromRust(ro.v), + fieldFromRust(ro.u), + [0, fieldFromRust(ro.zeta_chal)], + [0, fieldFromRust(ro.v_chal)], + [0, fieldFromRust(ro.u_chal)], + ]; + // TODO: do we not want to free? + // ro.free(); + return mlRo; + } + + return { + oraclesToRust(oracles: Oracles): WasmOracles { + let [, o, pEval, openingPrechallenges, digestBeforeEvaluations] = oracles; + return new Oracles( + randomOraclesToRust(o), + fieldToRust(pEval[1]), + fieldToRust(pEval[2]), + fieldsToRustFlat(openingPrechallenges), + fieldToRust(digestBeforeEvaluations) + ); + }, + oraclesFromRust(oracles: WasmOracles): Oracles { + let mlOracles: Oracles = [ + 0, + randomOraclesFromRust(oracles.o), + [0, fieldFromRust(oracles.p_eval0), fieldFromRust(oracles.p_eval1)], + fieldsFromRustFlat(oracles.opening_prechallenges), + fieldFromRust(oracles.digest_before_evaluations), + ]; + // TODO: do we not want to free? + // oracles.free(); + return mlOracles; + }, + }; +} diff --git a/src/bindings/crypto/native/conversion-proof.ts b/src/bindings/crypto/native/conversion-proof.ts new file mode 100644 index 000000000..ee106fb94 --- /dev/null +++ b/src/bindings/crypto/native/conversion-proof.ts @@ -0,0 +1,351 @@ +import type { + WasmFpLookupCommitments, + WasmPastaFpLookupTable, + WasmFpOpeningProof, + WasmFpProverCommitments, + WasmFpProverProof, + WasmFpRuntimeTable, + WasmPastaFpRuntimeTableCfg, + WasmFqLookupCommitments, + WasmFqOpeningProof, + WasmFqProverCommitments, + WasmPastaFqLookupTable, + WasmFqProverProof, + WasmFqRuntimeTable, + WasmPastaFqRuntimeTableCfg, + WasmVecVecFp, + WasmVecVecFq, +} from '../../compiled/node_bindings/plonk_wasm.cjs'; +import type * as wasmNamespace from '../../compiled/node_bindings/plonk_wasm.cjs'; +import type { + OrInfinity, + PointEvaluations, + PolyComm, + ProverProof, + ProofWithPublic, + ProofEvaluations, + ProverCommitments, + OpeningProof, + RecursionChallenge, + LookupCommitments, + RuntimeTable, + RuntimeTableCfg, + LookupTable, + Field, +} from './kimchi-types.js'; +import { MlArray, MlOption, MlTuple } from '../../../lib/ml/base.js'; +import { + fieldToRust, + fieldFromRust, + fieldsToRustFlat, + fieldsFromRustFlat, +} from './conversion-base.js'; +import { ConversionCore, ConversionCores, mapToUint32Array, unwrap } from './conversion-core.js'; + +export { proofConversion }; + +const fieldToRust_ = (x: Field) => fieldToRust(x); +const proofEvaluationsToRust = mapProofEvaluations(fieldToRust_); +const proofEvaluationsFromRust = mapProofEvaluations(fieldFromRust); +const pointEvalsOptionToRust = mapPointEvalsOption(fieldToRust_); +const pointEvalsOptionFromRust = mapPointEvalsOption(fieldFromRust); + +type WasmProofEvaluations = [ + 0, + MlOption>, + ...RemoveLeadingZero>, +]; + +type wasm = typeof wasmNamespace; + +type WasmProverCommitments = WasmFpProverCommitments | WasmFqProverCommitments; +type WasmOpeningProof = WasmFpOpeningProof | WasmFqOpeningProof; +type WasmProverProof = WasmFpProverProof | WasmFqProverProof; +type WasmLookupCommitments = WasmFpLookupCommitments | WasmFqLookupCommitments; +type WasmRuntimeTable = WasmFpRuntimeTable | WasmFqRuntimeTable; +type WasmRuntimeTableCfg = WasmPastaFpRuntimeTableCfg | WasmPastaFqRuntimeTableCfg; +type WasmLookupTable = WasmPastaFpLookupTable | WasmPastaFqLookupTable; + +type WasmClasses = { + ProverCommitments: typeof WasmFpProverCommitments | typeof 
WasmFqProverCommitments; + OpeningProof: typeof WasmFpOpeningProof | typeof WasmFqOpeningProof; + VecVec: typeof WasmVecVecFp | typeof WasmVecVecFq; + ProverProof: typeof WasmFpProverProof | typeof WasmFqProverProof; + LookupCommitments: typeof WasmFpLookupCommitments | typeof WasmFqLookupCommitments; + RuntimeTable: typeof WasmFpRuntimeTable | typeof WasmFqRuntimeTable; + RuntimeTableCfg: typeof WasmPastaFpRuntimeTableCfg | typeof WasmPastaFqRuntimeTableCfg; + LookupTable: typeof WasmPastaFpLookupTable | typeof WasmPastaFqLookupTable; +}; + +function proofConversion(wasm: wasm, core: ConversionCores) { + return { + fp: proofConversionPerField(core.fp, { + ProverCommitments: wasm.WasmFpProverCommitments, + OpeningProof: wasm.WasmFpOpeningProof, + VecVec: wasm.WasmVecVecFp, + ProverProof: wasm.WasmFpProverProof, + LookupCommitments: wasm.WasmFpLookupCommitments, + RuntimeTable: wasm.WasmFpRuntimeTable, + RuntimeTableCfg: wasm.WasmPastaFpRuntimeTableCfg, + LookupTable: wasm.WasmPastaFpLookupTable, + }), + fq: proofConversionPerField(core.fq, { + ProverCommitments: wasm.WasmFqProverCommitments, + OpeningProof: wasm.WasmFqOpeningProof, + VecVec: wasm.WasmVecVecFq, + ProverProof: wasm.WasmFqProverProof, + LookupCommitments: wasm.WasmFqLookupCommitments, + RuntimeTable: wasm.WasmFqRuntimeTable, + RuntimeTableCfg: wasm.WasmPastaFqRuntimeTableCfg, + LookupTable: wasm.WasmPastaFqLookupTable, + }), + }; +} + +function proofConversionPerField( + core: ConversionCore, + { + ProverCommitments, + OpeningProof, + VecVec, + ProverProof, + LookupCommitments, + RuntimeTable, + RuntimeTableCfg, + LookupTable, + }: WasmClasses +) { + function commitmentsToRust(commitments: ProverCommitments): WasmProverCommitments { + let wComm = core.polyCommsToRust(commitments[1]); + let zComm = core.polyCommToRust(commitments[2]); + let tComm = core.polyCommToRust(commitments[3]); + let lookup = MlOption.mapFrom(commitments[4], lookupCommitmentsToRust); + return new ProverCommitments(wComm, zComm, tComm, lookup); + } + function commitmentsFromRust(commitments: WasmProverCommitments): ProverCommitments { + let wComm = core.polyCommsFromRust(commitments.w_comm); + let zComm = core.polyCommFromRust(commitments.z_comm); + let tComm = core.polyCommFromRust(commitments.t_comm); + let lookup = MlOption.mapTo(commitments.lookup, lookupCommitmentsFromRust); + commitments.free(); + return [0, wComm as MlTuple, zComm, tComm, lookup]; + } + + function lookupCommitmentsToRust(lookup: LookupCommitments): WasmLookupCommitments { + let sorted = core.polyCommsToRust(lookup[1]); + let aggreg = core.polyCommToRust(lookup[2]); + let runtime = MlOption.mapFrom(lookup[3], core.polyCommToRust); + return new LookupCommitments(sorted, aggreg, runtime); + } + function lookupCommitmentsFromRust(lookup: WasmLookupCommitments): LookupCommitments { + let sorted = core.polyCommsFromRust(lookup.sorted); + let aggreg = core.polyCommFromRust(lookup.aggreg); + let runtime = MlOption.mapTo(lookup.runtime, core.polyCommFromRust); + lookup.free(); + return [0, sorted, aggreg, runtime]; + } + + function openingProofToRust(proof: OpeningProof): WasmOpeningProof { + let [_, [, ...lr], delta, z1, z2, sg] = proof; + // We pass l and r as separate vectors over the FFI + let l: MlArray = [0]; + let r: MlArray = [0]; + for (let [, li, ri] of lr) { + l.push(li); + r.push(ri); + } + return new OpeningProof( + core.pointsToRust(l), + core.pointsToRust(r), + core.pointToRust(delta), + fieldToRust(z1), + fieldToRust(z2), + core.pointToRust(sg) + ); + } + function 
openingProofFromRust(proof: WasmOpeningProof): OpeningProof { + let [, ...l] = core.pointsFromRust(proof.lr_0); + let [, ...r] = core.pointsFromRust(proof.lr_1); + let n = l.length; + if (n !== r.length) throw Error('openingProofFromRust: l and r length mismatch.'); + let lr = l.map<[0, OrInfinity, OrInfinity]>((li, i) => [0, li, r[i]]); + let delta = core.pointFromRust(proof.delta); + let z1 = fieldFromRust(proof.z1); + let z2 = fieldFromRust(proof.z2); + let sg = core.pointFromRust(proof.sg); + proof.free(); + return [0, [0, ...lr], delta, z1, z2, sg]; + } + + function runtimeTableToRust([, id, data]: RuntimeTable): WasmRuntimeTable { + return new RuntimeTable(id, core.vectorToRust(data)); + } + + function runtimeTableCfgToRust([, id, firstColumn]: RuntimeTableCfg): WasmRuntimeTableCfg { + return new RuntimeTableCfg(id, core.vectorToRust(firstColumn)); + } + + function lookupTableToRust([, id, [, ...data]]: LookupTable): WasmLookupTable { + let n = data.length; + let wasmData = new VecVec(n); + for (let i = 0; i < n; i++) { + wasmData.push(fieldsToRustFlat(data[i])); + } + return new LookupTable(id, wasmData); + } + + return { + proofToRust([, public_evals, proof]: ProofWithPublic): WasmProverProof { + let commitments = commitmentsToRust(proof[1]); + let openingProof = openingProofToRust(proof[2]); + let [, ...evals] = proofEvaluationsToRust(proof[3]); + let publicEvals = pointEvalsOptionToRust(public_evals); + // TODO typed as `any` in wasm-bindgen, this has the correct type + let evalsActual: WasmProofEvaluations = [0, publicEvals, ...evals]; + + let ftEval1 = fieldToRust(proof[4]); + let public_ = fieldsToRustFlat(proof[5]); + let [, ...prevChallenges] = proof[6]; + let n = prevChallenges.length; + let prevChallengeScalars = new VecVec(n); + let prevChallengeCommsMl: MlArray = [0]; + for (let [, scalars, comms] of prevChallenges) { + prevChallengeScalars.push(fieldsToRustFlat(scalars)); + prevChallengeCommsMl.push(comms); + } + let prevChallengeComms = core.polyCommsToRust(prevChallengeCommsMl); + return new ProverProof( + commitments, + openingProof, + evalsActual, + ftEval1, + public_, + prevChallengeScalars, + prevChallengeComms + ); + }, + proofFromRust(wasmProof: WasmProverProof): ProofWithPublic { + let commitments = commitmentsFromRust(wasmProof.commitments); + let openingProof = openingProofFromRust(wasmProof.proof); + // TODO typed as `any` in wasm-bindgen, this is the correct type + let [, wasmPublicEvals, ...wasmEvals]: WasmProofEvaluations = wasmProof.evals; + let publicEvals = pointEvalsOptionFromRust(wasmPublicEvals); + let evals = proofEvaluationsFromRust([0, ...wasmEvals]); + + let ftEval1 = fieldFromRust(wasmProof.ft_eval1); + let public_ = fieldsFromRustFlat(wasmProof.public_); + let prevChallengeScalars = wasmProof.prev_challenges_scalars; + let [, ...prevChallengeComms] = core.polyCommsFromRust(wasmProof.prev_challenges_comms); + let prevChallenges = prevChallengeComms.map((comms, i) => { + let scalars = fieldsFromRustFlat(prevChallengeScalars.get(i)); + return [0, scalars, comms]; + }); + wasmProof.free(); + let proof: ProverProof = [ + 0, + commitments, + openingProof, + evals, + ftEval1, + public_, + [0, ...prevChallenges], + ]; + return [0, publicEvals, proof]; + }, + + runtimeTablesToRust([, ...tables]: MlArray): Uint32Array { + return mapToUint32Array(tables, (table) => unwrap(runtimeTableToRust(table))); + }, + + runtimeTableCfgsToRust([, ...tableCfgs]: MlArray): Uint32Array { + return mapToUint32Array(tableCfgs, (tableCfg) => 
unwrap(runtimeTableCfgToRust(tableCfg))); + }, + + lookupTablesToRust([, ...tables]: MlArray) { + return mapToUint32Array(tables, (table) => unwrap(lookupTableToRust(table))); + }, + }; +} + +function createMapPointEvals(map: (x: Field1) => Field2) { + return (evals: PointEvaluations): PointEvaluations => { + let [, zeta, zeta_omega] = evals; + return [0, MlArray.map(zeta, map), MlArray.map(zeta_omega, map)]; + }; +} + +function mapPointEvalsOption(map: (x: Field1) => Field2) { + return (evals: MlOption>) => + MlOption.map(evals, createMapPointEvals(map)); +} + +function mapProofEvaluations(map: (x: Field1) => Field2) { + const mapPointEvals = createMapPointEvals(map); + + const mapPointEvalsOption = ( + evals: MlOption> + ): MlOption> => MlOption.map(evals, mapPointEvals); + + return function mapProofEvaluations(evals: ProofEvaluations): ProofEvaluations { + let [ + , + w, + z, + s, + coeffs, + genericSelector, + poseidonSelector, + completeAddSelector, + mulSelector, + emulSelector, + endomulScalarSelector, + rangeCheck0Selector, + rangeCheck1Selector, + foreignFieldAddSelector, + foreignFieldMulSelector, + xorSelector, + rotSelector, + lookupAggregation, + lookupTable, + lookupSorted, + runtimeLookupTable, + runtimeLookupTableSelector, + xorLookupSelector, + lookupGateLookupSelector, + rangeCheckLookupSelector, + foreignFieldMulLookupSelector, + ] = evals; + return [ + 0, + MlTuple.map(w, mapPointEvals), + mapPointEvals(z), + MlTuple.map(s, mapPointEvals), + MlTuple.map(coeffs, mapPointEvals), + mapPointEvals(genericSelector), + mapPointEvals(poseidonSelector), + mapPointEvals(completeAddSelector), + mapPointEvals(mulSelector), + mapPointEvals(emulSelector), + mapPointEvals(endomulScalarSelector), + mapPointEvalsOption(rangeCheck0Selector), + mapPointEvalsOption(rangeCheck1Selector), + mapPointEvalsOption(foreignFieldAddSelector), + mapPointEvalsOption(foreignFieldMulSelector), + mapPointEvalsOption(xorSelector), + mapPointEvalsOption(rotSelector), + mapPointEvalsOption(lookupAggregation), + mapPointEvalsOption(lookupTable), + MlArray.map(lookupSorted, mapPointEvalsOption), + mapPointEvalsOption(runtimeLookupTable), + mapPointEvalsOption(runtimeLookupTableSelector), + mapPointEvalsOption(xorLookupSelector), + mapPointEvalsOption(lookupGateLookupSelector), + mapPointEvalsOption(rangeCheckLookupSelector), + mapPointEvalsOption(foreignFieldMulLookupSelector), + ]; + }; +} + +// helper + +type RemoveLeadingZero = T extends [0, ...infer U] ? 
U : never; diff --git a/src/bindings/crypto/native/conversion-verifier-index.ts b/src/bindings/crypto/native/conversion-verifier-index.ts new file mode 100644 index 000000000..ac4cb7182 --- /dev/null +++ b/src/bindings/crypto/native/conversion-verifier-index.ts @@ -0,0 +1,289 @@ +import type { + WasmFpDomain, + WasmFpLookupSelectors, + WasmFpLookupVerifierIndex, + WasmFpPlonkVerificationEvals, + WasmFpPlonkVerifierIndex, + WasmFpShifts, + WasmFqDomain, + WasmFqLookupSelectors, + WasmFqLookupVerifierIndex, + WasmFqPlonkVerificationEvals, + WasmFqPlonkVerifierIndex, + WasmFqShifts, + LookupInfo as WasmLookupInfo, +} from '../../compiled/node_bindings/plonk_wasm.cjs'; +import type * as wasmNamespace from '../../compiled/node_bindings/plonk_wasm.cjs'; +import { MlBool, MlArray, MlOption } from '../../../lib/ml/base.js'; +import { Field, VerifierIndex, Domain, VerificationEvals, PolyComm } from './kimchi-types.js'; +import { fieldFromRust, fieldToRust } from './conversion-base.js'; +import { ConversionCore, ConversionCores, freeOnFinalize } from './conversion-core.js'; +import { Lookup, LookupInfo, LookupSelectors } from './lookup.js'; + +export { verifierIndexConversion }; + +type wasm = typeof wasmNamespace; + +type WasmDomain = WasmFpDomain | WasmFqDomain; +type WasmVerificationEvals = WasmFpPlonkVerificationEvals | WasmFqPlonkVerificationEvals; +type WasmShifts = WasmFpShifts | WasmFqShifts; +type WasmVerifierIndex = WasmFpPlonkVerifierIndex | WasmFqPlonkVerifierIndex; + +type WasmLookupVerifierIndex = WasmFpLookupVerifierIndex | WasmFqLookupVerifierIndex; +type WasmLookupSelector = WasmFpLookupSelectors | WasmFqLookupSelectors; + +type WasmClasses = { + Domain: typeof WasmFpDomain | typeof WasmFqDomain; + VerificationEvals: typeof WasmFpPlonkVerificationEvals | typeof WasmFqPlonkVerificationEvals; + Shifts: typeof WasmFpShifts | typeof WasmFqShifts; + VerifierIndex: typeof WasmFpPlonkVerifierIndex | typeof WasmFqPlonkVerifierIndex; + LookupVerifierIndex: typeof WasmFpLookupVerifierIndex | typeof WasmFqLookupVerifierIndex; + LookupSelector: typeof WasmFpLookupSelectors | typeof WasmFqLookupSelectors; +}; + +function verifierIndexConversion(wasm: wasm, core: ConversionCores) { + return { + fp: verifierIndexConversionPerField(wasm, core.fp, { + Domain: wasm.WasmFpDomain, + VerificationEvals: wasm.WasmFpPlonkVerificationEvals, + Shifts: wasm.WasmFpShifts, + VerifierIndex: wasm.WasmFpPlonkVerifierIndex, + LookupVerifierIndex: wasm.WasmFpLookupVerifierIndex, + LookupSelector: wasm.WasmFpLookupSelectors, + }), + fq: verifierIndexConversionPerField(wasm, core.fq, { + Domain: wasm.WasmFqDomain, + VerificationEvals: wasm.WasmFqPlonkVerificationEvals, + Shifts: wasm.WasmFqShifts, + VerifierIndex: wasm.WasmFqPlonkVerifierIndex, + LookupVerifierIndex: wasm.WasmFqLookupVerifierIndex, + LookupSelector: wasm.WasmFqLookupSelectors, + }), + }; +} + +function verifierIndexConversionPerField( + wasm: wasm, + core: ConversionCore, + { + Domain, + VerificationEvals, + Shifts, + VerifierIndex, + LookupVerifierIndex, + LookupSelector, + }: WasmClasses +) { + function domainToRust([, logSizeOfGroup, groupGen]: Domain): WasmDomain { + return new Domain(logSizeOfGroup, fieldToRust(groupGen)); + } + function domainFromRust(domain: WasmDomain): Domain { + let logSizeOfGroup = domain.log_size_of_group; + let groupGen = fieldFromRust(domain.group_gen); + domain.free(); + return [0, logSizeOfGroup, groupGen]; + } + + function verificationEvalsToRust(evals: VerificationEvals): WasmVerificationEvals { + let sigmaComm = 
core.polyCommsToRust(evals[1]); + let coefficientsComm = core.polyCommsToRust(evals[2]); + let genericComm = core.polyCommToRust(evals[3]); + let psmComm = core.polyCommToRust(evals[4]); + let completeAddComm = core.polyCommToRust(evals[5]); + let mulComm = core.polyCommToRust(evals[6]); + let emulComm = core.polyCommToRust(evals[7]); + let endomulScalarComm = core.polyCommToRust(evals[8]); + let xorComm = MlOption.mapFrom(evals[9], core.polyCommToRust); + let rangeCheck0Comm = MlOption.mapFrom(evals[10], core.polyCommToRust); + let rangeCheck1Comm = MlOption.mapFrom(evals[11], core.polyCommToRust); + let foreignFieldAddComm = MlOption.mapFrom(evals[12], core.polyCommToRust); + let foreignFieldMulComm = MlOption.mapFrom(evals[13], core.polyCommToRust); + let rotComm = MlOption.mapFrom(evals[14], core.polyCommToRust); + return new VerificationEvals( + sigmaComm, + coefficientsComm, + genericComm, + psmComm, + completeAddComm, + mulComm, + emulComm, + endomulScalarComm, + xorComm, + rangeCheck0Comm, + rangeCheck1Comm, + foreignFieldAddComm, + foreignFieldMulComm, + rotComm + ); + } + function verificationEvalsFromRust(evals: WasmVerificationEvals): VerificationEvals { + let mlEvals: VerificationEvals = [ + 0, + core.polyCommsFromRust(evals.sigma_comm), + core.polyCommsFromRust(evals.coefficients_comm), + core.polyCommFromRust(evals.generic_comm), + core.polyCommFromRust(evals.psm_comm), + core.polyCommFromRust(evals.complete_add_comm), + core.polyCommFromRust(evals.mul_comm), + core.polyCommFromRust(evals.emul_comm), + core.polyCommFromRust(evals.endomul_scalar_comm), + MlOption.mapTo(evals.xor_comm, core.polyCommFromRust), + MlOption.mapTo(evals.range_check0_comm, core.polyCommFromRust), + MlOption.mapTo(evals.range_check1_comm, core.polyCommFromRust), + MlOption.mapTo(evals.foreign_field_add_comm, core.polyCommFromRust), + MlOption.mapTo(evals.foreign_field_mul_comm, core.polyCommFromRust), + MlOption.mapTo(evals.rot_comm, core.polyCommFromRust), + ]; + evals.free(); + return mlEvals; + } + + function lookupVerifierIndexToRust(lookup: Lookup): WasmLookupVerifierIndex { + let [ + , + joint_lookup_used, + lookup_table, + selectors, + table_ids, + lookup_info, + runtime_tables_selector, + ] = lookup; + return new LookupVerifierIndex( + MlBool.from(joint_lookup_used), + core.polyCommsToRust(lookup_table), + lookupSelectorsToRust(selectors), + MlOption.mapFrom(table_ids, core.polyCommToRust), + lookupInfoToRust(lookup_info), + MlOption.mapFrom(runtime_tables_selector, core.polyCommToRust) + ); + } + function lookupVerifierIndexFromRust(lookup: WasmLookupVerifierIndex): Lookup { + let mlLookup: Lookup = [ + 0, + MlBool(lookup.joint_lookup_used), + core.polyCommsFromRust(lookup.lookup_table), + lookupSelectorsFromRust(lookup.lookup_selectors), + MlOption.mapTo(lookup.table_ids, core.polyCommFromRust), + lookupInfoFromRust(lookup.lookup_info), + MlOption.mapTo(lookup.runtime_tables_selector, core.polyCommFromRust), + ]; + lookup.free(); + return mlLookup; + } + + function lookupSelectorsToRust([ + , + lookup, + xor, + range_check, + ffmul, + ]: LookupSelectors): WasmLookupSelector { + return new LookupSelector( + MlOption.mapFrom(xor, core.polyCommToRust), + MlOption.mapFrom(lookup, core.polyCommToRust), + MlOption.mapFrom(range_check, core.polyCommToRust), + MlOption.mapFrom(ffmul, core.polyCommToRust) + ); + } + function lookupSelectorsFromRust(selector: WasmLookupSelector): LookupSelectors { + let lookup = MlOption.mapTo(selector.lookup, core.polyCommFromRust); + let xor = 
MlOption.mapTo(selector.xor, core.polyCommFromRust); + let range_check = MlOption.mapTo(selector.range_check, core.polyCommFromRust); + let ffmul = MlOption.mapTo(selector.ffmul, core.polyCommFromRust); + selector.free(); + return [0, lookup, xor, range_check, ffmul]; + } + + function lookupInfoToRust([, maxPerRow, maxJointSize, features]: LookupInfo): WasmLookupInfo { + let [, patterns, joint_lookup_used, uses_runtime_tables] = features; + let [, xor, lookup, range_check, foreign_field_mul] = patterns; + let wasmPatterns = new wasm.LookupPatterns( + MlBool.from(xor), + MlBool.from(lookup), + MlBool.from(range_check), + MlBool.from(foreign_field_mul) + ); + let wasmFeatures = new wasm.LookupFeatures( + wasmPatterns, + MlBool.from(joint_lookup_used), + MlBool.from(uses_runtime_tables) + ); + return new wasm.LookupInfo(maxPerRow, maxJointSize, wasmFeatures); + } + function lookupInfoFromRust(info: WasmLookupInfo): LookupInfo { + let features = info.features; + let patterns = features.patterns; + let mlInfo: LookupInfo = [ + 0, + info.max_per_row, + info.max_joint_size, + [ + 0, + [ + 0, + MlBool(patterns.xor), + MlBool(patterns.lookup), + MlBool(patterns.range_check), + MlBool(patterns.foreign_field_mul), + ], + MlBool(features.joint_lookup_used), + MlBool(features.uses_runtime_tables), + ], + ]; + info.free(); + return mlInfo; + } + + let self = { + shiftsToRust([, ...shifts]: MlArray): WasmShifts { + let s = shifts.map((s) => fieldToRust(s)); + return new Shifts(s[0], s[1], s[2], s[3], s[4], s[5], s[6]); + }, + shiftsFromRust(s: WasmShifts): MlArray { + let shifts = [s.s0, s.s1, s.s2, s.s3, s.s4, s.s5, s.s6]; + s.free(); + return [0, ...shifts.map(fieldFromRust)]; + }, + + verifierIndexToRust(vk: VerifierIndex): WasmVerifierIndex { + let domain = domainToRust(vk[1]); + let maxPolySize = vk[2]; + let nPublic = vk[3]; + let prevChallenges = vk[4]; + let srs = vk[5]; + let evals = verificationEvalsToRust(vk[6]); + let shifts = self.shiftsToRust(vk[7]); + let lookupIndex = MlOption.mapFrom(vk[8], lookupVerifierIndexToRust); + let zkRows = vk[9]; + return new VerifierIndex( + domain, + maxPolySize, + nPublic, + prevChallenges, + srs, + evals, + shifts, + lookupIndex, + zkRows + ); + }, + verifierIndexFromRust(vk: WasmVerifierIndex): VerifierIndex { + let mlVk: VerifierIndex = [ + 0, + domainFromRust(vk.domain), + vk.max_poly_size, + vk.public_, + vk.prev_challenges, + freeOnFinalize(vk.srs), + verificationEvalsFromRust(vk.evals), + self.shiftsFromRust(vk.shifts), + MlOption.mapTo(vk.lookup_index, lookupVerifierIndexFromRust), + vk.zk_rows, + ]; + vk.free(); + return mlVk; + }, + }; + + return self; +} diff --git a/src/bindings/crypto/native/curve.ts b/src/bindings/crypto/native/curve.ts new file mode 100644 index 000000000..3a11c6a2a --- /dev/null +++ b/src/bindings/crypto/native/curve.ts @@ -0,0 +1,91 @@ +/** + * TS implementation of Pasta_bindings.{Pallas, Vesta} + */ +import { MlPair } from '../../../lib/ml/base.js'; +import { Field } from './field.js'; +import { Pallas, Vesta, ProjectiveCurve, GroupProjective, GroupAffine } from '../elliptic-curve.js'; +import { withPrefix } from './util.js'; + +export { + VestaBindings, + PallasBindings, + Infinity, + OrInfinity, + OrInfinityJson, + toMlOrInfinity, + fromMlOrInfinity, +}; + +const VestaBindings = withPrefix('caml_vesta', createCurveBindings(Vesta)); +const PallasBindings = withPrefix('caml_pallas', createCurveBindings(Pallas)); + +function createCurveBindings(Curve: ProjectiveCurve) { + return { + one(): GroupProjective { + return 
Curve.one; + }, + add: Curve.add, + sub: Curve.sub, + negate: Curve.negate, + double: Curve.double, + scale(g: GroupProjective, [, s]: Field): GroupProjective { + return Curve.scale(g, s); + }, + random(): GroupProjective { + throw Error('random not implemented'); + }, + rng(i: number): GroupProjective { + throw Error('rng not implemented'); + }, + endo_base(): Field { + return [0, Curve.endoBase]; + }, + endo_scalar(): Field { + return [0, Curve.endoScalar]; + }, + to_affine(g: GroupProjective): OrInfinity { + return toMlOrInfinity(Curve.toAffine(g)); + }, + of_affine(g: OrInfinity): GroupProjective { + return Curve.fromAffine(fromMlOrInfinity(g)); + }, + of_affine_coordinates(x: Field, y: Field): GroupProjective { + // allows to create in points not on the curve - matches Rust impl + return { x: x[1], y: y[1], z: 1n }; + }, + affine_deep_copy(g: OrInfinity): OrInfinity { + return toMlOrInfinity(fromMlOrInfinity(g)); + }, + }; +} + +const affineZero = { x: 0n, y: 0n, infinity: true }; + +// Kimchi_types.or_infinity +type Infinity = 0; +const Infinity = 0; +type Finite = [0, T]; +type OrInfinity = Infinity | Finite>; + +function toMlOrInfinity(g: GroupAffine): OrInfinity { + if (g.infinity) return 0; + return [0, [0, [0, g.x], [0, g.y]]]; +} + +function fromMlOrInfinity(g: OrInfinity): GroupAffine { + if (g === 0) return affineZero; + return { x: g[1][1][1], y: g[1][2][1], infinity: false }; +} + +type OrInfinityJson = 'Infinity' | { x: string; y: string }; + +const OrInfinity = { + toJSON(g: OrInfinity): OrInfinityJson { + if (g === 0) return 'Infinity'; + return { x: g[1][1][1].toString(), y: g[1][2][1].toString() }; + }, + fromJSON(g: OrInfinityJson): OrInfinity { + if (g === 'Infinity') return 0; + return [0, [0, [0, BigInt(g.x)], [0, BigInt(g.y)]]]; + }, +}; diff --git a/src/bindings/crypto/native/env.ts b/src/bindings/crypto/native/env.ts new file mode 100644 index 000000000..fa457f6d0 --- /dev/null +++ b/src/bindings/crypto/native/env.ts @@ -0,0 +1 @@ +export const jsEnvironment = 'node'; diff --git a/src/bindings/crypto/native/field.ts b/src/bindings/crypto/native/field.ts new file mode 100644 index 000000000..03994a63a --- /dev/null +++ b/src/bindings/crypto/native/field.ts @@ -0,0 +1,143 @@ +/** + * TS implementation of Pasta_bindings.{Fp, Fq} + */ +import { FiniteField, Fp, Fq, mod } from '../finite-field.js'; +import { + Bigint256Bindings, + Bigint256, + MlBytes, + fromMlString, + toMlStringAscii, +} from './bigint256.js'; +import { MlOption, MlBool } from '../../../lib/ml/base.js'; +import { withPrefix } from './util.js'; + +type Field = [0, bigint]; + +export { FpBindings, FqBindings, Field }; + +const FpBindings = withPrefix('caml_pasta_fp', createFieldBindings(Fp)); +const FqBindings = withPrefix('caml_pasta_fq', createFieldBindings(Fq)); + +function createFieldBindings(Field: FiniteField) { + return { + size_in_bits(): number { + return Field.sizeInBits; + }, + size(): Bigint256 { + return [0, Field.modulus]; + }, + add([, x]: Field, [, y]: Field): Field { + return [0, Field.add(x, y)]; + }, + sub([, x]: Field, [, y]: Field): Field { + return [0, Field.sub(x, y)]; + }, + negate([, x]: Field): Field { + return [0, Field.negate(x)]; + }, + mul([, x]: Field, [, y]: Field): Field { + return [0, Field.mul(x, y)]; + }, + div([, x]: Field, [, y]: Field): Field { + let z = Field.div(x, y); + if (z === undefined) throw Error('division by zero'); + return [0, z]; + }, + inv([, x]: Field): MlOption { + return toMlOption(Field.inverse(x)); + }, + square([, x]: Field): Field { + return 
[0, Field.square(x)]; + }, + is_square([, x]: Field): MlBool { + return MlBool(Field.isSquare(x)); + }, + sqrt([, x]: Field): MlOption<Field> { + return toMlOption(Field.sqrt(x)); + }, + of_int(x: number): Field { + // avoid unnatural behaviour in Rust which treats negative numbers as uint64, + // e.g. -1 becomes 2^64 - 1 + if (x < 0) throw Error('of_int: inputs must be non-negative'); + return [0, Field.fromNumber(x)]; + }, + to_string([, x]: Field): MlBytes { + return toMlStringAscii(x.toString()); + }, + of_string(s: MlBytes): Field { + return [0, Field.fromBigint(BigInt(fromMlString(s)))]; + }, + print(x: Field): void { + console.log(x[1].toString()); + }, + copy(x: Field, [, y]: Field): void { + x[1] = y; + }, + mut_add(x: Field, [, y]: Field): void { + x[1] = Field.add(x[1], y); + }, + mut_sub(x: Field, [, y]: Field): void { + x[1] = Field.sub(x[1], y); + }, + mut_mul(x: Field, [, y]: Field): void { + x[1] = Field.mul(x[1], y); + }, + mut_square(x: Field): void { + x[1] = Field.square(x[1]); + }, + compare(x: Field, y: Field): number { + return Bigint256Bindings.caml_bigint_256_compare(x, y); + }, + equal([, x]: Field, [, y]: Field): MlBool { + return MlBool(x === y); + }, + random(): Field { + return [0, Field.random()]; + }, + rng(i: number): Field { + // not used in js + throw Error('rng: not implemented'); + }, + to_bigint([, x]: Field): Bigint256 { + // copying to a new array to break mutable reference + return [0, x]; + }, + of_bigint([, x]: Bigint256): Field { + if (x >= Field.modulus) throw Error('of_bigint: input exceeds field size'); + // copying to a new array to break mutable reference + return [0, x]; + }, + two_adic_root_of_unity(): Field { + return [0, Field.twoadicRoot]; + }, + domain_generator(i: number): Field { + // this takes an integer i and returns a 2^ith root of unity, i.e.
a number `w` with + // w^(2^i) = 1, w^(2^(i-1)) = -1 + // computed by taking the 2^32th root and squaring 32-i times + if (i > 32 || i < 0) + throw Error('log2 size of evaluation domain must be in [0, 32], got ' + i); + if (i === 0) return [0, 1n]; + let generator = Field.twoadicRoot; + for (let j = 32; j > i; j--) { + generator = mod(generator * generator, Field.modulus); + } + return [0, generator]; + }, + to_bytes(x: Field): MlBytes { + return Bigint256Bindings.caml_bigint_256_to_bytes(x); + }, + of_bytes(bytes: MlBytes): Field { + // not used in js + throw Error('of_bytes: not implemented'); + }, + deep_copy([, x]: Field): Field { + return [0, x]; + }, + }; +} + +function toMlOption(x: undefined | T): MlOption<[0, T]> { + if (x === undefined) return 0; // None + return [0, [0, x]]; // Some(x) +} diff --git a/src/bindings/crypto/native/kimchi-types.ts b/src/bindings/crypto/native/kimchi-types.ts new file mode 100644 index 000000000..ba6a7394e --- /dev/null +++ b/src/bindings/crypto/native/kimchi-types.ts @@ -0,0 +1,189 @@ +/** + * This file is a TS representation of kimchi_types.ml + */ +import type { Lookup } from './lookup.js'; +import type { MlArray, MlOption, MlTuple } from '../../../lib/ml/base.js'; +import type { OrInfinity } from './curve.js'; +import type { Field } from './field.js'; +import type { WasmFpSrs, WasmFqSrs } from '../../compiled/node_bindings/plonk_wasm.cjs'; + +export { + Field, + OrInfinity, + Wire, + Gate, + PolyComm, + Domain, + VerificationEvals, + VerifierIndex, + ScalarChallenge, + RandomOracles, + Oracles, + ProverCommitments, + OpeningProof, + PointEvaluations, + ProofEvaluations, + RecursionChallenge, + ProverProof, + ProofWithPublic, + LookupCommitments, + RuntimeTableCfg, + LookupTable, + RuntimeTable, +}; + +// wasm types + +type WasmSrs = WasmFpSrs | WasmFqSrs; + +// ml types from kimchi_types.ml + +type GateType = number; +type Wire = [_: 0, row: number, col: number]; +type Gate = [ + _: 0, + typ: GateType, + wires: [0, Wire, Wire, Wire, Wire, Wire, Wire, Wire], + coeffs: MlArray, +]; + +type PolyComm = [_: 0, elems: MlArray]; + +// verifier index + +type Domain = [_: 0, log_size_of_group: number, group_gen: Field]; + +type VerificationEvals = [ + _: 0, + sigma_comm: MlArray, + coefficients_comm: MlArray, + generic_comm: PolyComm, + psm_comm: PolyComm, + complete_add_comm: PolyComm, + mul_comm: PolyComm, + emul_comm: PolyComm, + endomul_scalar_comm: PolyComm, + xor_comm: MlOption, + range_check0_comm: MlOption, + range_check1_comm: MlOption, + foreign_field_add_comm: MlOption, + foreign_field_mul_comm: MlOption, + rot_comm: MlOption, +]; + +type VerifierIndex = [ + _: 0, + domain: Domain, + max_poly_size: number, + public_: number, + prev_challenges: number, + srs: WasmSrs, + evals: VerificationEvals, + shifts: MlArray, + lookup_index: MlOption>, + zkRows: number, +]; + +// oracles + +type ScalarChallenge = [_: 0, inner: Field]; +type RandomOracles = [ + _: 0, + joint_combiner: MlOption<[0, ScalarChallenge, Field]>, + beta: Field, + gamma: Field, + alpha_chal: ScalarChallenge, + alpha: Field, + zeta: Field, + v: Field, + u: Field, + zeta_chal: ScalarChallenge, + v_chal: ScalarChallenge, + u_chal: ScalarChallenge, +]; +type Oracles = [ + _: 0, + o: RandomOracles, + p_eval: [0, Field, Field], + opening_prechallenges: MlArray, + digest_before_evaluations: Field, +]; + +// proof + +type LookupCommitments = [ + _: 0, + sorted: MlArray, + aggreg: PolyComm, + runtime: MlOption, +]; +type ProverCommitments = [ + _: 0, + w_comm: MlTuple, + z_comm: PolyComm, + 
t_comm: PolyComm, + lookup: MlOption, +]; +type OpeningProof = [ + _: 0, + lr: MlArray<[0, OrInfinity, OrInfinity]>, + delta: OrInfinity, + z1: Field, + z2: Field, + sg: OrInfinity, +]; +type PointEvaluations = [_: 0, zeta: MlArray, zeta_omega: MlArray]; + +type nColumns = 15; +type nPermutsMinus1 = 6; + +type ProofEvaluations = [ + _: 0, + w: MlTuple, nColumns>, + z: PointEvaluations, + s: MlTuple, nPermutsMinus1>, + coefficients: MlTuple, nColumns>, + generic_selector: PointEvaluations, + poseidon_selector: PointEvaluations, + complete_add_selector: PointEvaluations, + mul_selector: PointEvaluations, + emul_selector: PointEvaluations, + endomul_scalar_selector: PointEvaluations, + range_check0_selector: MlOption>, + range_check1_selector: MlOption>, + foreign_field_add_selector: MlOption>, + foreign_field_mul_selector: MlOption>, + xor_selector: MlOption>, + rot_selector: MlOption>, + lookup_aggregation: MlOption>, + lookup_table: MlOption>, + lookup_sorted: MlArray>>, + runtime_lookup_table: MlOption>, + runtime_lookup_table_selector: MlOption>, + xor_lookup_selector: MlOption>, + lookup_gate_lookup_selector: MlOption>, + range_check_lookup_selector: MlOption>, + foreign_field_mul_lookup_selector: MlOption>, +]; + +type RecursionChallenge = [_: 0, chals: MlArray, comm: PolyComm]; + +type ProverProof = [ + _: 0, + commitments: ProverCommitments, + proof: OpeningProof, + evals: ProofEvaluations, + ft_eval1: Field, + public_: MlArray, + prev_challenges: MlArray, +]; + +type ProofWithPublic = [_: 0, public_evals: MlOption>, proof: ProverProof]; + +// tables + +type RuntimeTableCfg = [_: 0, id: number, first_column: MlArray]; + +type LookupTable = [_: 0, id: number, data: MlArray>]; + +type RuntimeTable = [_: 0, id: number, data: MlArray]; diff --git a/src/bindings/crypto/native/lookup.ts b/src/bindings/crypto/native/lookup.ts new file mode 100644 index 000000000..9180acfcc --- /dev/null +++ b/src/bindings/crypto/native/lookup.ts @@ -0,0 +1,34 @@ +import { MlArray, MlBool, MlOption } from '../../../lib/ml/base.js'; + +export { Lookup, LookupInfo, LookupPatterns, LookupFeatures, LookupSelectors }; + +type LookupPatterns = [ + _: 0, + xor: MlBool, + lookup: MlBool, + range_check: MlBool, + foreign_field_mul: MlBool, +]; +type LookupFeatures = [ + _: 0, + patterns: LookupPatterns, + joint_lookup_used: MlBool, + uses_runtime_tables: MlBool, +]; +type LookupInfo = [_: 0, max_per_row: number, max_joint_size: number, features: LookupFeatures]; +type LookupSelectors = [ + _: 0, + lookup: MlOption, + xor: MlOption, + range_check: MlOption, + ffmul: MlOption, +]; +type Lookup = [ + _: 0, + joint_lookup_used: MlBool, + lookup_table: MlArray, + selectors: LookupSelectors, + table_ids: MlOption, + lookup_info: LookupInfo, + runtime_tables_selector: MlOption, +]; diff --git a/src/bindings/crypto/native/srs.ts b/src/bindings/crypto/native/srs.ts new file mode 100644 index 000000000..b97ef6b0a --- /dev/null +++ b/src/bindings/crypto/native/srs.ts @@ -0,0 +1,272 @@ +import type { Wasm, RustConversion } from '../bindings.js'; +import { type WasmFpSrs, type WasmFqSrs } from '../../compiled/node_bindings/plonk_wasm.cjs'; +import { PolyComm } from './kimchi-types.js'; +import { + type CacheHeader, + type Cache, + withVersion, + writeCache, + readCache, +} from '../../../lib/proof-system/cache.js'; +import { assert } from '../../../lib/util/errors.js'; +import { MlArray } from '../../../lib/ml/base.js'; +import { OrInfinity, OrInfinityJson } from './curve.js'; + +export { srs, setSrsCache, unsetSrsCache }; + 
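+// SRS handling: instances are cached in memory per field and size (`srsStore` below); when a `Cache` is +// registered via `setSrsCache`, SRS points and Lagrange bases are additionally persisted as JSON so they +// can be reloaded instead of recomputed.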
+type WasmSrs = WasmFpSrs | WasmFqSrs; + +type SrsStore = Record; + +function empty(): SrsStore { + return {}; +} + +const srsStore = { fp: empty(), fq: empty() }; + +const CacheReadRegister = new Map(); + +let cache: Cache | undefined; + +function setSrsCache(c: Cache) { + cache = c; +} +function unsetSrsCache() { + cache = undefined; +} + +const srsVersion = 1; + +function cacheHeaderLagrange(f: 'fp' | 'fq', domainSize: number): CacheHeader { + let id = `lagrange-basis-${f}-${domainSize}`; + return withVersion( + { + kind: 'lagrange-basis', + persistentId: id, + uniqueId: id, + dataType: 'string', + }, + srsVersion + ); +} +function cacheHeaderSrs(f: 'fp' | 'fq', domainSize: number): CacheHeader { + let id = `srs-${f}-${domainSize}`; + return withVersion( + { + kind: 'srs', + persistentId: id, + uniqueId: id, + dataType: 'string', + }, + srsVersion + ); +} + +function srs(wasm: Wasm, conversion: RustConversion) { + return { + fp: srsPerField('fp', wasm, conversion), + fq: srsPerField('fq', wasm, conversion), + }; +} + +function srsPerField(f: 'fp' | 'fq', wasm: Wasm, conversion: RustConversion) { + // note: these functions are properly typed, thanks to TS template literal types + let createSrs = (s: number) => wasm[`caml_${f}_srs_create_parallel`](s); + let getSrs = wasm[`caml_${f}_srs_get`]; + let setSrs = wasm[`caml_${f}_srs_set`]; + + let maybeLagrangeCommitment = wasm[`caml_${f}_srs_maybe_lagrange_commitment`]; + let lagrangeCommitment = (srs: WasmFpSrs, domain_size: number, i: number) => + wasm[`caml_${f}_srs_lagrange_commitment`](srs, domain_size, i); + let lagrangeCommitmentsWholeDomainPtr = (srs: WasmSrs, domain_size: number) => + wasm[`caml_${f}_srs_lagrange_commitments_whole_domain_ptr`](srs, domain_size); + let setLagrangeBasis = wasm[`caml_${f}_srs_set_lagrange_basis`]; + let getLagrangeBasis = (srs: WasmSrs, n: number) => + wasm[`caml_${f}_srs_get_lagrange_basis`](srs, n); + let getCommitmentsWholeDomainByPtr = + wasm[`caml_${f}_srs_lagrange_commitments_whole_domain_read_from_ptr`]; + return { + /** + * returns existing stored SRS or falls back to creating a new one + */ + create(size: number): WasmSrs { + let srs = srsStore[f][size] satisfies WasmSrs as WasmSrs | undefined; + + if (srs === undefined) { + if (cache === undefined) { + // if there is no cache, create SRS in memory + srs = createSrs(size); + } else { + let header = cacheHeaderSrs(f, size); + + // try to read SRS from cache / recompute and write if not found + srs = readCache(cache, header, (bytes) => { + // TODO: this takes a bit too long, about 300ms for 2^16 + // `pointsToRust` is the clear bottleneck + let jsonSrs: OrInfinityJson[] = JSON.parse(new TextDecoder().decode(bytes)); + let mlSrs = MlArray.mapTo(jsonSrs, OrInfinity.fromJSON); + let wasmSrs = conversion[f].pointsToRust(mlSrs); + return setSrs(wasmSrs); + }); + + if (srs === undefined) { + // not in cache + srs = createSrs(size); + + if (cache.canWrite) { + let wasmSrs = getSrs(srs); + let mlSrs = conversion[f].pointsFromRust(wasmSrs); + let jsonSrs = MlArray.mapFrom(mlSrs, OrInfinity.toJSON); + let bytes = new TextEncoder().encode(JSON.stringify(jsonSrs)); + + writeCache(cache, header, bytes); + } + } + } + + srsStore[f][size] = srs; + } + + // TODO should we call freeOnFinalize() and expose a function to clean the SRS cache? 
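+ // note: the SRS is kept alive in `srsStore` for the lifetime of the process and is never freed here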
+ return srsStore[f][size]; + }, + + /** + * returns ith Lagrange basis commitment for a given domain size + */ + lagrangeCommitment(srs: WasmSrs, domainSize: number, i: number): PolyComm { + // happy, fast case: if basis is already stored on the srs, return the ith commitment + let commitment = maybeLagrangeCommitment(srs, domainSize, i); + + if (commitment === undefined) { + if (cache === undefined) { + // if there is no cache, recompute and store basis in memory + commitment = lagrangeCommitment(srs, domainSize, i); + } else { + // try to read Lagrange basis from cache / recompute and write if not found + let header = cacheHeaderLagrange(f, domainSize); + let didRead = readCacheLazy( + cache, + header, + conversion, + f, + srs, + domainSize, + setLagrangeBasis + ); + if (didRead !== true) { + // not in cache + if (cache.canWrite) { + // TODO: this code path will throw on the web since `caml_${f}_srs_get_lagrange_basis` is not properly implemented + // using a writable cache in the browser seems to be fairly uncommon though, so it's at least an 80/20 solution + let wasmComms = getLagrangeBasis(srs, domainSize); + let mlComms = conversion[f].polyCommsFromRust(wasmComms); + let comms = polyCommsToJSON(mlComms); + let bytes = new TextEncoder().encode(JSON.stringify(comms)); + writeCache(cache, header, bytes); + } else { + lagrangeCommitment(srs, domainSize, i); + } + } + // here, basis is definitely stored on the srs + let c = maybeLagrangeCommitment(srs, domainSize, i); + assert(c !== undefined, 'commitment exists after setting'); + commitment = c; + } + } + + // edge case for when we have a writable cache and the basis was already stored on the srs + // but we didn't store it in the cache separately yet + if (commitment && cache && cache.canWrite) { + let header = cacheHeaderLagrange(f, domainSize); + let didRead = readCacheLazy( + cache, + header, + conversion, + f, + srs, + domainSize, + setLagrangeBasis + ); + // only proceed for entries we haven't written to the cache yet + if (didRead !== true) { + // same code as above - write the Lagrange basis to the cache if it wasn't there already + // currently we re-generate the basis via `getLagrangeBasis` - we could derive this from the + // already existing `commitment` instead, but this is simpler and the performance impact is negligible + let wasmComms = getLagrangeBasis(srs, domainSize); + let mlComms = conversion[f].polyCommsFromRust(wasmComms); + let comms = polyCommsToJSON(mlComms); + let bytes = new TextEncoder().encode(JSON.stringify(comms)); + + writeCache(cache, header, bytes); + } + } + return conversion[f].polyCommFromRust(commitment); + }, + + /** + * Returns the Lagrange basis commitments for the whole domain + */ + lagrangeCommitmentsWholeDomain(srs: WasmSrs, domainSize: number) { + // instead of getting the entire commitment directly (which works for nodejs/servers), we get a pointer to the commitment + // and then read the commitment from the pointer + // this is because the web worker implementation currently does not support returning UintXArray's directly + // hence we return a pointer from wasm, funnel it through the web worker + // and then read the commitment from the pointer in the main thread (where UintXArray's are supported) + // see https://github.com/o1-labs/o1js-bindings/blob/09e17b45e0c2ca2b51cd9ed756106e17ca1cf36d/js/web/worker-spec.js#L110-L115 + let ptr = lagrangeCommitmentsWholeDomainPtr(srs, domainSize); + let wasmComms = getCommitmentsWholeDomainByPtr(ptr); + let mlComms =
conversion[f].polyCommsFromRust(wasmComms); + return mlComms; + }, + + /** + * adds Lagrange basis for a given domain size + */ + addLagrangeBasis(srs: WasmSrs, logSize: number) { + // this ensures that basis is stored on the srs, no need to duplicate caching logic + this.lagrangeCommitment(srs, 1 << logSize, 0); + }, + }; +} + +type PolyCommJson = { + shifted: OrInfinityJson[]; + unshifted: OrInfinityJson | undefined; +}; + +function polyCommsToJSON(comms: MlArray<PolyComm>): PolyCommJson[] { + return MlArray.mapFrom(comms, ([, elems]) => { + return { + shifted: MlArray.mapFrom(elems, OrInfinity.toJSON), + unshifted: undefined, + }; + }); +} + +function polyCommsFromJSON(json: PolyCommJson[]): MlArray<PolyComm> { + return MlArray.mapTo(json, ({ shifted, unshifted }) => { + return [0, MlArray.mapTo(shifted, OrInfinity.fromJSON)]; + }); +} + +function readCacheLazy( + cache: Cache, + header: CacheHeader, + conversion: RustConversion, + f: 'fp' | 'fq', + srs: WasmSrs, + domainSize: number, + setLagrangeBasis: (srs: WasmSrs, domainSize: number, comms: Uint32Array) => void +) { + if (CacheReadRegister.get(header.uniqueId) === true) return true; + return readCache(cache, header, (bytes) => { + let comms: PolyCommJson[] = JSON.parse(new TextDecoder().decode(bytes)); + let mlComms = polyCommsFromJSON(comms); + let wasmComms = conversion[f].polyCommsToRust(mlComms); + + setLagrangeBasis(srs, domainSize, wasmComms); + CacheReadRegister.set(header.uniqueId, true); + return true; + }); +} diff --git a/src/bindings/crypto/native/util.ts b/src/bindings/crypto/native/util.ts new file mode 100644 index 000000000..af3695627 --- /dev/null +++ b/src/bindings/crypto/native/util.ts @@ -0,0 +1,20 @@ +export { withPrefix, mapTuple }; + +function withPrefix<prefix extends string, T extends Record<string, any>>(prefix: prefix, obj: T) { + return Object.fromEntries( + Object.entries(obj).map(([k, v]) => { + return [`${prefix}_${k}`, v]; + }) + ) as { + [k in keyof T & string as `${prefix}_${k}`]: T[k]; + }; +} + +type Tuple<T> = [T, ...T[]] | []; + +function mapTuple<T extends Tuple<any>, B>( + tuple: T, + f: (a: T[number]) => B ): { [i in keyof T]: B } { + return tuple.map(f) as any; +} diff --git a/src/bindings/crypto/native/vector.ts b/src/bindings/crypto/native/vector.ts new file mode 100644 index 000000000..c7edb98e0 --- /dev/null +++ b/src/bindings/crypto/native/vector.ts @@ -0,0 +1,38 @@ +/** + * TS implementation of Kimchi_bindings.FieldVectors + */ +import { MlArray } from '../../../lib/ml/base.js'; +import { Field } from './field.js'; +import { withPrefix } from './util.js'; + +export { FpVectorBindings, FqVectorBindings }; +export { FieldVector }; + +type FieldVector = MlArray<Field>; + +const FieldVectorBindings = { + create(): FieldVector { + // OCaml tag for arrays, so that we can use the same utility fns on both + return [0]; + }, + length(v: FieldVector): number { + return v.length - 1; + }, + emplace_back(v: FieldVector, x: Field): void { + v.push(x); + }, + get(v: FieldVector, i: number): Field { + let value = v[i + 1] as Field | undefined; + if (value === undefined) { + throw Error(`FieldVector.get(): Index out of bounds, got ${i}/${v.length - 1}`); + } + // copying to a new array to break mutable reference + return [...value]; + }, + set(v: FieldVector, i: number, x: Field): void { + v[i + 1] = x; + }, +}; + +const FpVectorBindings = withPrefix('caml_fp_vector', FieldVectorBindings); +const FqVectorBindings = withPrefix('caml_fq_vector', FieldVectorBindings);
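Usage sketch (illustrative, not part of the diff): these bindings exchange values in the js_of_ocaml layout, where an OCaml block is a JS array with a leading `0` tag — a field element is `[0, bigint]`, an `MlArray` is `[0, ...elements]`, and an option is `0` (None) or `[0, value]` (Some). The `caml_pasta_fp_*` / `caml_fp_vector_*` / `caml_pallas_*` names below come from `withPrefix`; the relative import paths are assumptions for the example only.

import { FpBindings } from './field.js';
import { FpVectorBindings } from './vector.js';
import { PallasBindings, OrInfinity } from './curve.js';

// field elements are MlTuples of the form [0, bigint]
let x = FpBindings.caml_pasta_fp_of_int(2); // [0, 2n]
let y = FpBindings.caml_pasta_fp_add(x, x); // [0, 4n]

// field vectors are MlArrays: [0, ...elements]
let v = FpVectorBindings.caml_fp_vector_create(); // [0]
FpVectorBindings.caml_fp_vector_emplace_back(v, y);
FpBindings.caml_pasta_fp_print(FpVectorBindings.caml_fp_vector_get(v, 0)); // logs "4"

// curve points convert to the OrInfinity shape that the SRS JSON cache also uses
let g = PallasBindings.caml_pallas_to_affine(PallasBindings.caml_pallas_one());
console.log(OrInfinity.toJSON(g)); // logs the affine point as { x, y } decimal strings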