Skip to content

Commit a5c15ad

Browse files
authored
Merge pull request #1932 from o1-labs/feature/conditional-recursive-proving
Conditional recursion from within ZkProgram
2 parents f7d522e + dec0558 commit a5c15ad

File tree

4 files changed

+153
-22
lines changed

4 files changed

+153
-22
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
2222
### Added
2323

2424
- `ZkProgram` to support non-pure provable types as inputs and outputs https://github.com/o1-labs/o1js/pull/1828
25-
- API for recursively proving a ZkProgram method from within another https://github.com/o1-labs/o1js/pull/1931
25+
- APIs for recursively proving a ZkProgram method from within another https://github.com/o1-labs/o1js/pull/1931 https://github.com/o1-labs/o1js/pull/1932
2626
- `let recursive = Experimental.Recursive(program);`
2727
- `recursive.<methodName>(...args): Promise<PublicOutput>`
28+
- `recursive.<methodName>.if(condition, ...args): Promise<PublicOutput>`
2829
- This also works within the same program, as long as the return value is type-annotated
2930
- Add `enforceTransactionLimits` parameter on Network https://github.com/o1-labs/o1js/issues/1910
3031
- Method for optional types to assert none https://github.com/o1-labs/o1js/pull/1922
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/**
2+
* This shows how to prove an arbitrarily long chain of hashes using ZkProgram, i.e.
3+
* `hash^n(x) = y`.
4+
*
5+
* We implement this as a self-recursive ZkProgram, using `proveRecursivelyIf()`
6+
*/
7+
import {
8+
assert,
9+
Bool,
10+
Experimental,
11+
Field,
12+
Poseidon,
13+
Provable,
14+
Struct,
15+
ZkProgram,
16+
} from 'o1js';
17+
18+
const HASHES_PER_PROOF = 30;
19+
20+
class HashChainSpec extends Struct({ x: Field, n: Field }) {}
21+
22+
const hashChain = ZkProgram({
23+
name: 'hash-chain',
24+
publicInput: HashChainSpec,
25+
publicOutput: Field,
26+
27+
methods: {
28+
chain: {
29+
privateInputs: [],
30+
31+
async method({ x, n }: HashChainSpec) {
32+
Provable.log('hashChain (start method)', n);
33+
let y = x;
34+
let k = Field(0);
35+
let reachedN = Bool(false);
36+
37+
for (let i = 0; i < HASHES_PER_PROOF; i++) {
38+
reachedN = k.equals(n);
39+
y = Provable.if(reachedN, y, Poseidon.hash([y]));
40+
k = Provable.if(reachedN, n, k.add(1));
41+
}
42+
43+
// we have y = hash^k(x)
44+
// now do z = hash^(n-k)(y) = hash^n(x) by calling this method recursively
45+
// except if we have k = n, then ignore the output and use y
46+
let z: Field = await hashChainRecursive.chain.if(reachedN.not(), {
47+
x: y,
48+
n: n.sub(k),
49+
});
50+
z = Provable.if(reachedN, y, z);
51+
Provable.log('hashChain (start proving)', n);
52+
return { publicOutput: z };
53+
},
54+
},
55+
},
56+
});
57+
let hashChainRecursive = Experimental.Recursive(hashChain);
58+
59+
await hashChain.compile();
60+
61+
let n = 100;
62+
let x = Field.random();
63+
64+
let { proof } = await hashChain.chain({ x, n: Field(n) });
65+
66+
assert(await hashChain.verify(proof), 'Proof invalid');
67+
68+
// check that the output is correct
69+
let z = Array.from({ length: n }, () => 0).reduce((y) => Poseidon.hash([y]), x);
70+
proof.publicOutput.assertEquals(z, 'Output is incorrect');
71+
72+
console.log('Finished hash chain proof');

src/lib/proof-system/recursive.ts

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { Tuple } from '../util/types.js';
55
import { Proof } from './proof.js';
66
import { mapObject, mapToObject, zip } from '../util/arrays.js';
77
import { Undefined, Void } from './zkprogram.js';
8+
import { Bool } from '../provable/bool.js';
89

910
export { Recursive };
1011

@@ -25,6 +26,7 @@ function Recursive<
2526
...args: any
2627
) => Promise<{ publicOutput: InferProvable<PublicOutputType> }>;
2728
};
29+
maxProofsVerified: () => Promise<0 | 1 | 2>;
2830
} & {
2931
[Key in keyof PrivateInputs]: (...args: any) => Promise<{
3032
proof: Proof<
@@ -38,7 +40,13 @@ function Recursive<
3840
InferProvable<PublicInputType>,
3941
InferProvable<PublicOutputType>,
4042
PrivateInputs[Key]
41-
>;
43+
> & {
44+
if: ConditionalRecursiveProver<
45+
InferProvable<PublicInputType>,
46+
InferProvable<PublicOutputType>,
47+
PrivateInputs[Key]
48+
>;
49+
};
4250
} {
4351
type PublicInput = InferProvable<PublicInputType>;
4452
type PublicOutput = InferProvable<PublicOutputType>;
@@ -64,9 +72,15 @@ function Recursive<
6472

6573
let regularRecursiveProvers = mapToObject(methodKeys, (key) => {
6674
return async function proveRecursively_(
75+
conditionAndConfig: Bool | { condition: Bool; domainLog2?: number },
6776
publicInput: PublicInput,
6877
...args: TupleToInstances<PrivateInputs[MethodKey]>
69-
) {
78+
): Promise<PublicOutput> {
79+
let condition =
80+
conditionAndConfig instanceof Bool
81+
? conditionAndConfig
82+
: conditionAndConfig.condition;
83+
7084
// create the base proof in a witness block
7185
let proof = await Provable.witnessAsync(SelfProof, async () => {
7286
// move method args to constants
@@ -78,6 +92,20 @@ function Recursive<
7892
Provable.toConstant(type, arg)
7993
);
8094

95+
if (!condition.toBoolean()) {
96+
let publicOutput: PublicOutput =
97+
ProvableType.synthesize(publicOutputType);
98+
let maxProofsVerified = await zkprogram.maxProofsVerified();
99+
return SelfProof.dummy(
100+
publicInput,
101+
publicOutput,
102+
maxProofsVerified,
103+
conditionAndConfig instanceof Bool
104+
? undefined
105+
: conditionAndConfig.domainLog2
106+
);
107+
}
108+
81109
let prover = zkprogram[key];
82110

83111
if (hasPublicInput) {
@@ -96,32 +124,48 @@ function Recursive<
96124

97125
// declare and verify the proof, and return its public output
98126
proof.declare();
99-
proof.verify();
127+
proof.verifyIf(condition);
100128
return proof.publicOutput;
101129
};
102130
});
103131

104-
type RecursiveProver_<K extends MethodKey> = RecursiveProver<
105-
PublicInput,
106-
PublicOutput,
107-
PrivateInputs[K]
108-
>;
109-
type RecursiveProvers = {
110-
[K in MethodKey]: RecursiveProver_<K>;
111-
};
112-
let proveRecursively: RecursiveProvers = mapToObject(
113-
methodKeys,
114-
(key: MethodKey) => {
132+
return mapObject(
133+
regularRecursiveProvers,
134+
(
135+
prover
136+
): RecursiveProver<PublicInput, PublicOutput, PrivateInputs[MethodKey]> & {
137+
if: ConditionalRecursiveProver<
138+
PublicInput,
139+
PublicOutput,
140+
PrivateInputs[MethodKey]
141+
>;
142+
} => {
115143
if (!hasPublicInput) {
116-
return ((...args: any) =>
117-
regularRecursiveProvers[key](undefined as any, ...args)) as any;
144+
return Object.assign(
145+
((...args: any) =>
146+
prover(new Bool(true), undefined as any, ...args)) as any,
147+
{
148+
if: (
149+
condition: Bool | { condition: Bool; domainLog2?: number },
150+
...args: any
151+
) => prover(condition, undefined as any, ...args),
152+
}
153+
);
118154
} else {
119-
return regularRecursiveProvers[key] as any;
155+
return Object.assign(
156+
((pi: PublicInput, ...args: any) =>
157+
prover(new Bool(true), pi, ...args)) as any,
158+
{
159+
if: (
160+
condition: Bool | { condition: Bool; domainLog2?: number },
161+
pi: PublicInput,
162+
...args: any
163+
) => prover(condition, pi, ...args),
164+
}
165+
);
120166
}
121167
}
122168
);
123-
124-
return proveRecursively;
125169
}
126170

127171
type RecursiveProver<
@@ -135,6 +179,21 @@ type RecursiveProver<
135179
...args: TupleToInstances<Args>
136180
) => Promise<PublicOutput>;
137181

182+
type ConditionalRecursiveProver<
183+
PublicInput,
184+
PublicOutput,
185+
Args extends Tuple<ProvableType>
186+
> = PublicInput extends undefined
187+
? (
188+
condition: Bool | { condition: Bool; domainLog2?: number },
189+
...args: TupleToInstances<Args>
190+
) => Promise<PublicOutput>
191+
: (
192+
condition: Bool | { condition: Bool; domainLog2?: number },
193+
publicInput: PublicInput,
194+
...args: TupleToInstances<Args>
195+
) => Promise<PublicOutput>;
196+
138197
type TupleToInstances<T> = {
139198
[I in keyof T]: InferProvable<T[I]>;
140199
};

src/lib/proof-system/zkprogram.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import {
3030
unsetSrsCache,
3131
} from '../../bindings/crypto/bindings/srs.js';
3232
import {
33-
ProvablePure,
3433
ProvableType,
3534
ProvableTypePure,
3635
ToProvable,
@@ -55,7 +54,7 @@ import {
5554
import { emptyWitness } from '../provable/types/util.js';
5655
import { InferValue } from '../../bindings/lib/provable-generic.js';
5756
import { DeclaredProof, ZkProgramContext } from './zkprogram-context.js';
58-
import { mapObject, mapToObject, zip } from '../util/arrays.js';
57+
import { mapObject, mapToObject } from '../util/arrays.js';
5958

6059
// public API
6160
export {

0 commit comments

Comments
 (0)