diff --git a/packages/library/src/hooks/TransactionFeeHook.ts b/packages/library/src/hooks/TransactionFeeHook.ts index d22b2fc9d..871d2d0a3 100644 --- a/packages/library/src/hooks/TransactionFeeHook.ts +++ b/packages/library/src/hooks/TransactionFeeHook.ts @@ -51,7 +51,8 @@ const errors = { export class TransactionFeeHook extends ProvableTransactionHook { public constructor( // dependency on runtime, since balances are part of runtime logic - @inject("Runtime") public runtime: Runtime + @inject("Runtime") public runtime: Runtime, + @inject("Balances") public balances: Balances ) { super(); } @@ -91,10 +92,6 @@ export class TransactionFeeHook extends ProvableTransactionHook("Balances"); - } - public get feeAnalyzer() { if (this.persistedFeeAnalyzer === undefined) { throw new Error("TransactionFeeHook.start not called by protocol"); diff --git a/packages/protocol/src/protocol/Protocol.ts b/packages/protocol/src/protocol/Protocol.ts index aa62328d1..2542a3b56 100644 --- a/packages/protocol/src/protocol/Protocol.ts +++ b/packages/protocol/src/protocol/Protocol.ts @@ -187,6 +187,28 @@ export class Protocol< ); } }); + + // Cross-register all runtime modules to the protocol container for easier + // access of runtime modules inside protocol hooks + if (this.container.isRegistered("Runtime", true)) { + const runtimeContainer: ModuleContainer = + this.container.resolve("Runtime"); + + runtimeContainer.moduleNames.forEach((runtimeModuleName) => { + this.container.register(runtimeModuleName, { + useFactory: (dependencyContainer) => { + // Prevents creation of closure + const runtime: ModuleContainer = + dependencyContainer.resolve("Runtime"); + return runtime.resolve(runtimeModuleName); + }, + }); + }); + } else { + log.warn( + "Couldn't resolve Runtime reference in Protocol, resolving RuntimeModules in hooks won't be available" + ); + } } public async start() { diff --git a/packages/protocol/test/TestingProtocol.ts b/packages/protocol/test/TestingProtocol.ts index 09e82cbd1..568f26c95 100644 --- a/packages/protocol/test/TestingProtocol.ts +++ b/packages/protocol/test/TestingProtocol.ts @@ -1,21 +1,17 @@ -import { WithZkProgrammable, ZkProgrammable } from "@proto-kit/common"; import { container } from "tsyringe"; +import { Runtime } from "@proto-kit/module"; +import { Balance } from "@proto-kit/sequencer/test/integration/mocks/Balance"; +import { NoopRuntime } from "@proto-kit/sequencer/test/integration/mocks/NoopRuntime"; import { AccountStateHook, BlockHeightHook, BlockProver, LastStateRootBlockHook, - MethodPublicOutput, Protocol, StateTransitionProver, } from "../src"; -class RuntimeMock implements WithZkProgrammable { - zkProgrammable: ZkProgrammable = - undefined as unknown as ZkProgrammable; -} - export function createAndInitTestingProtocol() { const ProtocolClass = Protocol.from({ modules: { @@ -35,11 +31,22 @@ export function createAndInitTestingProtocol() { StateTransitionProver: {}, LastStateRoot: {}, }); - protocol.create(() => container.createChildContainer()); - protocol.registerValue({ - Runtime: new RuntimeMock(), + const appChain = container.createChildContainer(); + + appChain.register("Runtime", { + useClass: Runtime.from({ + modules: { + Balance, + NoopRuntime, + }, + config: { + Balance: {}, + NoopRuntime: {}, + }, + }), }); + protocol.create(() => appChain.createChildContainer()); return protocol; } diff --git a/packages/sdk/test/modularization.test.ts b/packages/sdk/test/modularization.test.ts index 118842f5f..981a1de91 100644 --- a/packages/sdk/test/modularization.test.ts +++ b/packages/sdk/test/modularization.test.ts @@ -2,7 +2,10 @@ import "reflect-metadata"; import { MethodIdResolver, Runtime, RuntimeModule } from "@proto-kit/module"; import { ChildContainerProvider } from "@proto-kit/common"; import { Protocol, ProtocolModule } from "@proto-kit/protocol"; -import { VanillaProtocolModules } from "@proto-kit/library"; +import { + VanillaProtocolModules, + VanillaRuntimeModules, +} from "@proto-kit/library"; import { Sequencer, SequencerModule } from "@proto-kit/sequencer"; import { PrivateKey } from "o1js"; @@ -13,7 +16,6 @@ class TestRuntimeModule extends RuntimeModule { public create(childContainerProvider: ChildContainerProvider) { super.create(childContainerProvider); - // Just to test if it doesn't throw childContainerProvider(); @@ -47,9 +49,9 @@ describe("modularization", () => { it("should initialize all modules correctly", async () => { const appChain = AppChain.from({ Runtime: Runtime.from({ - modules: { + modules: VanillaRuntimeModules.with({ TestRuntimeModule, - }, + }), }), Protocol: Protocol.from({ modules: VanillaProtocolModules.with({ @@ -66,6 +68,7 @@ describe("modularization", () => { appChain.configurePartial({ Runtime: { + Balances: {}, TestRuntimeModule: {}, }, Protocol: {