Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions src/bindings/crypto/bindings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ import { verifierIndexConversion } from './bindings/conversion-verifier-index.js
import { oraclesConversion } from './bindings/conversion-oracles.js';
import { jsEnvironment } from './bindings/env.js';
import { srs } from './bindings/srs.js';
// native
import { conversionCore as conversionCoreNative } from './native/conversion-core.js';
import { fieldsFromRustFlat as fieldsFromRustFlatNative, fieldsToRustFlat as fieldsToRustFlatNative } from './native/conversion-base.js';
import { proofConversion as proofConversionNative } from './native/conversion-proof.js';
import { verifierIndexConversion as verifierIndexConversionNative } from './native/conversion-verifier-index.js';
import { oraclesConversion as oraclesConversionNative } from './native/conversion-oracles.js';
import { srs as srsNative } from './native/srs.js';

export { getRustConversion, RustConversion, Wasm };
export { getRustConversion, type RustConversion, type NativeConversion, type Wasm };

const tsBindings = {
jsEnvironment,
Expand All @@ -31,19 +38,28 @@ const tsBindings = {
...FpVectorBindings,
...FqVectorBindings,
rustConversion: createRustConversion,
srs: (wasm: Wasm) => srs(wasm, getRustConversion(wasm)),
srs: (wasm: Wasm) => {
const bundle = getConversionBundle(wasm);
return bundle.srsFactory(wasm, bundle.conversion);
},
};

// this is put in a global variable so that mina/src/lib/crypto/kimchi_bindings/js/bindings.js finds it
(globalThis as any).__snarkyTsBindings = tsBindings;

type Wasm = typeof wasmNamespace;

function createRustConversion(wasm: Wasm) {
let core = conversionCore(wasm);
let verifierIndex = verifierIndexConversion(wasm, core);
let oracles = oraclesConversion(wasm);
let proof = proofConversion(wasm, core);
function createRustConversion(wasm: Wasm): RustConversion {
return shouldUseNativeConversion(wasm)
? createNativeConversion(wasm)
: createWasmConversion(wasm);
}

function createWasmConversion(wasm: Wasm) {
const core = conversionCore(wasm);
const verifierIndex = verifierIndexConversion(wasm, core);
const oracles = oraclesConversion(wasm);
const proof = proofConversion(wasm, core);

return {
fp: { ...core.fp, ...verifierIndex.fp, ...oracles.fp, ...proof.fp },
Expand All @@ -55,10 +71,45 @@ function createRustConversion(wasm: Wasm) {
};
}

type RustConversion = ReturnType<typeof createRustConversion>;
type WasmConversion = ReturnType<typeof createWasmConversion>;
type NativeConversion = ReturnType<typeof createNativeConversion>;
type RustConversion = WasmConversion | NativeConversion;

function getRustConversion(wasm: Wasm): RustConversion {
return createRustConversion(wasm);
}

function createNativeConversion(wasm: Wasm) {
const core = conversionCoreNative(wasm);
const verifierIndex = verifierIndexConversionNative(wasm, core);
const oracles = oraclesConversionNative(wasm);
const proof = proofConversionNative(wasm, core);

return {
fp: { ...core.fp, ...verifierIndex.fp, ...oracles.fp, ...proof.fp },
fq: { ...core.fq, ...verifierIndex.fq, ...oracles.fq, ...proof.fq },
fieldsToRustFlatNative,
fieldsFromRustFlatNative,
wireToRust: core.wireToRust,
mapMlArrayToRustVector: core.mapMlArrayToRustVector,
};
}

function shouldUseNativeConversion(wasm: Wasm): boolean {
const marker = (wasm as any).__kimchi_use_native;
const globalMarker =
typeof globalThis !== 'undefined' &&
(globalThis as any).__kimchi_use_native;
return Boolean(marker || globalMarker);
}

let rustConversion: RustConversion | undefined;
type ConversionBundle =
| { conversion: WasmConversion; srsFactory: typeof srs }
| { conversion: NativeConversion; srsFactory: typeof srsNative };

function getRustConversion(wasm: Wasm) {
return rustConversion ?? (rustConversion = createRustConversion(wasm));
function getConversionBundle(wasm: Wasm): ConversionBundle {
if (shouldUseNativeConversion(wasm)) {
return { conversion: createNativeConversion(wasm), srsFactory: srsNative };
}
return { conversion: createWasmConversion(wasm), srsFactory: srs };
}
150 changes: 150 additions & 0 deletions src/bindings/crypto/native/conversion-core.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import { Buffer } from 'buffer';
import type * as napiNamespace from '../../compiled/node_bindings/plonk_wasm.cjs';
import { MlArray } from '../../../lib/ml/base.js';
import { OrInfinity, Gate, PolyComm, Wire } from './kimchi-types.js';
import {
WasmAffine as NapiAffine,
affineFromRust,
affineToRust,
fieldsFromRustFlat,
fieldsToRustFlat,
} from './conversion-base.js';
import { mapTuple } from './util.js';


type Napi = typeof napiNamespace;

type NapiPolyComm = napiNamespace.WasmFpPolyComm | napiNamespace.WasmFqPolyComm;

type NapiClasses = {
makeAffine: () => NapiAffine;
PolyComm: typeof napiNamespace.WasmFpPolyComm | typeof napiNamespace.WasmFqPolyComm;
};

export function conversionCore(napi: Napi) {
const fp = conversionCorePerField({
makeAffine: napi.caml_vesta_affine_one,
PolyComm: napi.WasmFpPolyComm,
});
const fq = conversionCorePerField({
makeAffine: napi.caml_pallas_affine_one,
PolyComm: napi.WasmFqPolyComm,
});

return {
fp,
fq,
wireToRust: fp.wireToRust, // doesn't depend on the field
mapMlArrayToRustVector<TMl, TRust extends {}>(
[, ...array]: MlArray<TMl>,
map: (x: TMl) => TRust
): TRust[] {
return array.map(map);
},
};
}

function conversionCorePerField({ makeAffine, PolyComm: PolyCommClass }: NapiClasses) {
const self = {
wireToRust([, row, col]: Wire) {
return { row, col };
},

vectorToRust: fieldsToRustFlat,
vectorFromRust: fieldsFromRustFlat,

gateToRust(gate: Gate): any {
const [, typ, [, ...wires], coeffs] = gate;
const mapped = mapTuple(wires, self.wireToRust);
const nativeWires = {
w0: mapped[0],
w1: mapped[1],
w2: mapped[2],
w3: mapped[3],
w4: mapped[4],
w5: mapped[5],
w6: mapped[6],
} as const;
return {
typ,
wires: nativeWires,
coeffs: toBuffer(fieldsToRustFlat(coeffs)),
};
},
gateFromRust(nativeGate: any): Gate {
const { typ, wires, coeffs } = nativeGate;
const mlWires: Gate[2] = [
0,
[0, wires.w0.row, wires.w0.col],
[0, wires.w1.row, wires.w1.col],
[0, wires.w2.row, wires.w2.col],
[0, wires.w3.row, wires.w3.col],
[0, wires.w4.row, wires.w4.col],
[0, wires.w5.row, wires.w5.col],
[0, wires.w6.row, wires.w6.col],
];
const mlCoeffs = fieldsFromRustFlat(toUint8Array(coeffs));
return [0, typ, mlWires, mlCoeffs];
},

pointToRust(point: OrInfinity) {
return affineToRust(point, makeAffine);
},
pointFromRust(point: any): OrInfinity {
return affineFromRust(point);
},

pointsToRust([, ...points]: MlArray<OrInfinity>): NapiAffine[] {
return points.map(self.pointToRust);
},
pointsFromRust(points: unknown): MlArray<OrInfinity> {
const list = asArray<any>(points, 'pointsFromRust');
return [0, ...list.map(self.pointFromRust)];
},

polyCommToRust(polyComm: PolyComm): NapiPolyComm {
const [, camlElems] = polyComm;
const rustUnshifted = self.pointsToRust(camlElems);
return new PolyCommClass(rustUnshifted as any, undefined as any);
},
polyCommFromRust(polyComm: NapiPolyComm): PolyComm {
const rustUnshifted = asArray<any>((polyComm as any).unshifted, 'polyComm.unshifted');
const mlUnshifted = rustUnshifted.map(self.pointFromRust);
return [0, [0, ...mlUnshifted]];
},

polyCommsToRust([, ...comms]: MlArray<PolyComm>): NapiPolyComm[] {
return comms.map(self.polyCommToRust);
},
polyCommsFromRust(rustComms: unknown): MlArray<PolyComm> {
const list = asArray<NapiPolyComm>(rustComms, 'polyCommsFromRust');
return [0, ...list.map((comm) => self.polyCommFromRust(comm))];
},
};

return self;
}

function asArray<T>(value: unknown, context: string): T[] {
if (value == null) return [];
if (Array.isArray(value)) return value as T[];
throw Error(`${context}: expected array of native values`);
}

function toUint8Array(value: any): Uint8Array {
if (value instanceof Uint8Array) return value;
if (Array.isArray(value)) return Uint8Array.from(value);
if (value && typeof value === 'object') {
if (ArrayBuffer.isView(value)) {
const view = value as ArrayBufferView;
return new Uint8Array(view.buffer, view.byteOffset, view.byteLength);
}
if (value instanceof ArrayBuffer) return new Uint8Array(value);
}
throw Error('Expected byte array');
}

function toBuffer(bytes: Uint8Array): Buffer {
if (Buffer.isBuffer(bytes)) return bytes;
return Buffer.from(bytes.buffer, bytes.byteOffset, bytes.byteLength);
}
17 changes: 17 additions & 0 deletions src/bindings/scripts/build-o1js-node-artifacts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ run_cmd cp "${MINA_PATH}"/src/config.mlh "src"
run_cmd cp -r "${MINA_PATH}"/src/config "src/config"
ok "Mina config files copied"

info "Building Kimchi native bindings for Node.js..."
run_cmd dune b "${KIMCHI_BINDINGS}"/js/native
ok "Kimchi native bindings built"

info "Building Kimchi bindings for Node.js..."
run_cmd dune b "${KIMCHI_BINDINGS}"/js/node_js
ok "Kimchi bindings built"
Expand Down Expand Up @@ -96,6 +100,14 @@ run_cmd mkdir -p "${BINDINGS_PATH}"
run_cmd chmod -R 777 "${BINDINGS_PATH}"
ok "Output directory prepared"

info "Preparing native bindings directory..."
run_cmd mkdir -p src/bindings/compiled/native
ok "Native bindings directory prepared"

info "Copying N-API bindings..."
run_cmd cp _build/default/"${KIMCHI_BINDINGS}"/js/native/plonk_napi* "${BINDINGS_PATH}"
ok "N-API bindings copied"

info "Copying WASM bindings..."
run_cmd cp _build/default/"${KIMCHI_BINDINGS}"/js/node_js/plonk_wasm* "${BINDINGS_PATH}"
run_cmd mv -f "${BINDINGS_PATH}"/plonk_wasm.js "${BINDINGS_PATH}"/plonk_wasm.cjs
Expand All @@ -114,6 +126,11 @@ fi
run_cmd mv -f "${BINDINGS_PATH}"/o1js_node.bc.js "${BINDINGS_PATH}"/o1js_node.bc.cjs
ok "Node.js bindings copied"

info "Copying native bindings..."
run_cmd cp _build/default/"${KIMCHI_BINDINGS}"/js/native/plonk_napi.node src/bindings/compiled/native/
run_cmd chmod 777 src/bindings/compiled/native/plonk_napi.node
ok "Native bindings copied"

info "Updating WASM references in bindings..."
run_cmd sed -i 's/plonk_wasm.js/plonk_wasm.cjs/' "${BINDINGS_PATH}"/o1js_node.bc.cjs
ok "WASM references updated"
Expand Down
1 change: 1 addition & 0 deletions src/build/copy-to-dist.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ await copyFromTo(
[
'src/bindings.d.ts',
'src/bindings/compiled/_node_bindings',
'src/bindings/compiled/native',
'src/bindings/compiled/node_bindings/plonk_wasm.d.cts',
],
'src/',
Expand Down
Loading