diff --git a/packages/module/src/method/MethodParameterEncoder.ts b/packages/module/src/method/MethodParameterEncoder.ts index c9132655d..1d47dca5c 100644 --- a/packages/module/src/method/MethodParameterEncoder.ts +++ b/packages/module/src/method/MethodParameterEncoder.ts @@ -62,7 +62,7 @@ function getAllPropertyNamesOfPrototypeChain(type: unknown): string[] { ); } -function isFlexibleProvablePure( +export function isFlexibleProvablePure( type: unknown ): type is FlexibleProvablePure { // The required properties are defined on the prototype for Structs and CircuitValues @@ -73,35 +73,43 @@ function isFlexibleProvablePure( return mandatory.every((prop) => props.includes(prop)); } -export class MethodParameterEncoder { - public static fromMethod(target: RuntimeModule, methodName: string) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - const paramtypes: ArgTypeArray = Reflect.getMetadata( - "design:paramtypes", - target, - methodName +export function checkArgsProvable( + target: RuntimeModule, + methodName: string +) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const paramtypes: ArgTypeArray = Reflect.getMetadata( + "design:paramtypes", + target, + methodName + ); + + if (paramtypes === undefined) { + throw new Error( + `Method with name ${methodName} doesn't exist on this module` ); + } - if (paramtypes === undefined) { - throw new Error( - `Method with name ${methodName} doesn't exist on this module` - ); - } + const indizes = paramtypes + .map((type, index) => { + if (isProofBaseType(type) || isFlexibleProvablePure(type)) { + return undefined; + } + return `${index}`; + }) + .filter(filterNonUndefined); + if (indizes.length > 0) { + const indexString = indizes.reduce((a, b) => `${a}, ${b}`); + throw new Error( + `Not all arguments of method '${target.name}.${methodName}' are provable types or proofs (indizes: [${indexString}])` + ); + } + return paramtypes; +} - const indizes = paramtypes - .map((type, index) => { - if (isProofBaseType(type) || isFlexibleProvablePure(type)) { - return undefined; - } - return `${index}`; - }) - .filter(filterNonUndefined); - if (indizes.length > 0) { - const indexString = indizes.reduce((a, b) => `${a}, ${b}`); - throw new Error( - `Not all arguments of method '${target.name}.${methodName}' are provable types or proofs (indizes: [${indexString}])` - ); - } +export class MethodParameterEncoder { + public static fromMethod(target: RuntimeModule, methodName: string) { + const paramtypes = checkArgsProvable(target, methodName); return new MethodParameterEncoder(paramtypes); } diff --git a/packages/module/src/method/runtimeMethod.ts b/packages/module/src/method/runtimeMethod.ts index c793ef3dc..4688ea670 100644 --- a/packages/module/src/method/runtimeMethod.ts +++ b/packages/module/src/method/runtimeMethod.ts @@ -17,7 +17,10 @@ import { import type { RuntimeModule } from "../runtime/RuntimeModule.js"; -import { MethodParameterEncoder } from "./MethodParameterEncoder"; +import { + MethodParameterEncoder, + checkArgsProvable, +} from "./MethodParameterEncoder"; const errors = { runtimeNotProvided: (name: string) => @@ -196,11 +199,9 @@ function runtimeMethodInternal(options: { return ( target: RuntimeModule, methodName: string, - descriptor: TypedPropertyDescriptor< - // TODO Limit possible parameter types - (...args: any[]) => Promise - > + descriptor: TypedPropertyDescriptor<(...args: any[]) => Promise> ) => { + checkArgsProvable(target, methodName); const executionContext = container.resolve( RuntimeMethodExecutionContext ); diff --git a/packages/module/test/method/MethodParameterEncoder.test.ts b/packages/module/test/method/MethodParameterEncoder.test.ts index a356b087f..9a4c86a28 100644 --- a/packages/module/test/method/MethodParameterEncoder.test.ts +++ b/packages/module/test/method/MethodParameterEncoder.test.ts @@ -7,14 +7,9 @@ import { ZkProgram, Proof, } from "o1js"; -import { NonMethods, noop } from "@proto-kit/common"; +import { NonMethods } from "@proto-kit/common"; -import { - MethodParameterEncoder, - RuntimeModule, - runtimeModule, - runtimeMethod, -} from "../../src"; +import { MethodParameterEncoder } from "../../src"; class TestStruct extends Struct({ a: Field, @@ -124,29 +119,3 @@ describe("MethodParameterEncoder", () => { ); }, 30000); }); - -class TieredStruct extends TestStruct {} - -@runtimeModule() -class TestModule extends RuntimeModule { - @runtimeMethod() - public async foo( - a: TieredStruct, - b: PublicKey, - c: Field, - d: TestProof, - e: string - ) { - noop(); - } -} - -describe("MethodParameterEncoder construction", () => { - it("should throw on non-provable method signature", () => { - const module = new TestModule(); - module.name = "testModule"; - expect(() => MethodParameterEncoder.fromMethod(module, "foo")).toThrowError( - "'testModule.foo' are provable types or proofs (indizes: [4])" - ); - }); -}); diff --git a/packages/module/test/method/runtimeMethod-fail.test.ts b/packages/module/test/method/runtimeMethod-fail.test.ts new file mode 100644 index 000000000..60ec39658 --- /dev/null +++ b/packages/module/test/method/runtimeMethod-fail.test.ts @@ -0,0 +1,50 @@ +import { Bool, Field, PublicKey, Struct, ZkProgram } from "o1js"; +import { noop } from "@proto-kit/common"; + +import { runtimeMethod, RuntimeModule, runtimeModule } from "../../src"; + +class TestStruct extends Struct({ + a: Field, + b: Bool, +}) {} +class TieredStruct extends TestStruct {} +const TestProgram = ZkProgram({ + name: "TestProgram", + publicInput: PublicKey, + publicOutput: TestStruct, + methods: { + foo: { + privateInputs: [], + method: async (input: PublicKey) => { + return { + a: Field(input.x), + b: Bool(input.isOdd), + }; + }, + }, + }, +}); +class TestProof extends ZkProgram.Proof(TestProgram) {} + +describe("Creating module with non-provable method argument", () => { + it("should throw on non-provable method signature", () => { + expect(() => { + @runtimeModule() + // eslint-disable-next-line @typescript-eslint/no-unused-vars + class TestModule extends RuntimeModule { + @runtimeMethod() + public async foo( + a: TieredStruct, + b: PublicKey, + c: Field, + d: TestProof, + e: string + ) { + noop(); + } + } + }).toThrow( + "Not all arguments of method 'undefined.foo' are provable types or proofs (indizes: [4])" + ); + }); +});