@@ -2,23 +2,31 @@ import { z } from 'zod'
22
33import { getClientErrorMessage } from './errors'
44import { registerBridgeEvents } from './events/bridge'
5+ import { sortedMiddlewares , isServiceWorker , sendResponseMessage } from './utils'
56
67import { generateUUIDv4 } from '../uuid'
78
89import type { RegisterBridgeEventsOptions } from './events/bridge'
910import type {
10- FrameId ,
11- FrameType ,
12- EventType ,
11+ ControllerState ,
12+ DataStorage ,
1313 EventName ,
1414 EventParams ,
15- ParamsValidator ,
16- HandlerResult ,
15+ EventType ,
16+ FrameId ,
17+ RegisteredFrame ,
18+ MessageSource ,
19+ FrameType ,
1720 Handler ,
1821 HandlerMethods ,
22+ HandlerResult ,
1923 HandlerScope ,
20- DataStorage ,
21- ControllerState ,
24+ Middleware ,
25+ MiddlewareFn ,
26+ MiddlewareNextFn ,
27+ MiddlewareId ,
28+ ParamsValidator ,
29+ RegisteredMiddleware ,
2230} from './types'
2331
2432// NOTE: taken from https://semver.org/#is-there-a-suggested-regular-expression-regex-to-check-a-semver-string
@@ -35,8 +43,10 @@ const MESSAGE_SCHEMA = z.object({
3543} ) . strict ( )
3644
3745export class PostMessageController extends EventTarget {
38- #registeredFrames: Record < FrameId , FrameType > = { }
46+ #registeredFrames: Record < FrameId , RegisteredFrame > = { }
3947 #registeredHandlers: Record < HandlerScope , Record < EventType , Record < EventName , HandlerMethods < EventParams , HandlerResult > > > > = { }
48+ #registeredMiddlewares: Record < HandlerScope , Array < RegisteredMiddleware < EventParams , HandlerResult > > > = { }
49+ #middlewaresIdsMap: Record < MiddlewareId , HandlerScope > = { }
4050 #storage: DataStorage = { }
4151 state : ControllerState = { isBridgeReady : false }
4252
@@ -47,30 +57,109 @@ export class PostMessageController extends EventTarget {
4757 this . addHandler = this . addHandler . bind ( this )
4858 this . eventListener = this . eventListener . bind ( this )
4959 this . registerBridgeEvents = this . registerBridgeEvents . bind ( this )
60+ this . addMiddleware = this . addMiddleware . bind ( this )
61+ this . removeMiddleware = this . removeMiddleware . bind ( this )
5062 }
5163
64+ // ---- PRIVATE UTILITIES METHODS ----
65+
5266 private updateState ( state : Partial < ControllerState > ) {
5367 this . state = { ...this . state , ...state }
5468 this . dispatchEvent ( new CustomEvent ( 'statechange' , { detail : this . state } ) )
5569 }
5670
71+ private getWrappedHandler ( eventType : EventType , eventName : EventName , scope : HandlerScope , handler : Handler < EventParams , HandlerResult > ) : Handler < EventParams , HandlerResult > {
72+ const globalMiddlewares = ( this . #registeredMiddlewares[ '*' ] ?? [ ] )
73+ . filter ( mw =>
74+ ( ! mw . eventType || mw . eventType === eventType ) &&
75+ ( ! mw . eventName || mw . eventName === eventName ) &&
76+ ( mw . scope === '*' )
77+ )
78+ const scopedMiddlewares = ( this . #registeredMiddlewares[ scope ] ?? [ ] )
79+ . filter ( mw =>
80+ ( ! mw . eventType || mw . eventType === eventType ) &&
81+ ( ! mw . eventName || mw . eventName === eventName ) &&
82+ ( mw . scope === scope )
83+ )
84+
85+ // NOTE: sort middlewares by execution order
86+ const middlewares = sortedMiddlewares ( [ ...globalMiddlewares , ...scopedMiddlewares ] )
87+
88+ return middlewares . reduceRight < Handler < EventParams , HandlerResult > > (
89+ ( nextHandler , mw ) => {
90+ return async ( args ) => {
91+ const nextMiddlewareFn : MiddlewareNextFn < EventParams , HandlerResult > = ( nextArgs ) => {
92+ return nextHandler ( {
93+ ...args ,
94+ params : nextArgs ?. params ?? args . params ,
95+ } )
96+ }
97+
98+ return mw . fn ( {
99+ ...args ,
100+ next : nextMiddlewareFn ,
101+ } )
102+ }
103+ } ,
104+ handler
105+ )
106+ }
107+
108+ private getMessageSource ( source : Window | ServiceWorker ) : MessageSource | null {
109+ if ( isServiceWorker ( source ) ) {
110+ if ( source !== navigator . serviceWorker . controller ) return null
111+ return {
112+ ref : source ,
113+ id : 'worker' ,
114+ type : 'worker' ,
115+ }
116+ }
117+ if ( source === window ) {
118+ return {
119+ ref : source ,
120+ id : 'parent' ,
121+ type : 'window' ,
122+ }
123+ }
124+
125+ const registeredFrame = Object . entries ( this . #registeredFrames)
126+ . find ( ( [ , frame ] ) => frame . ref . contentWindow === source )
127+
128+ if ( ! registeredFrame ) return null
129+
130+ return {
131+ ref : registeredFrame [ 1 ] . ref ,
132+ id : registeredFrame [ 0 ] ,
133+ type : 'frame' ,
134+ metadata : registeredFrame [ 1 ] . metadata ,
135+ }
136+ }
137+
138+ // ---- SOURCE REGISTRATION METHODS ----
139+
57140 addFrame ( frame : FrameType ) : FrameId {
58141 const registeredFrame = Object . entries ( this . #registeredFrames)
59- . find ( ( [ , ref ] ) => ref === frame )
142+ . find ( ( [ , existingFrame ] ) => existingFrame . ref === frame )
60143 if ( registeredFrame ) {
61144 return registeredFrame [ 0 ]
62145 }
63146
64147 const frameId = generateUUIDv4 ( )
65- this . #registeredFrames[ frameId ] = frame
148+ this . #registeredFrames[ frameId ] = { ref : frame }
66149 return frameId
67150 }
68151
69152 removeFrame ( frameId : FrameId ) {
153+ ( this . #registeredMiddlewares[ frameId ] ?? [ ] ) . forEach ( mw => {
154+ delete this . #middlewaresIdsMap[ mw . id ]
155+ } )
70156 delete this . #registeredFrames[ frameId ]
71157 delete this . #registeredHandlers[ frameId ]
158+ delete this . #registeredMiddlewares[ frameId ]
72159 }
73160
161+ // ---- HANDLER REGISTRATION METHODS ----
162+
74163 addHandler < Params extends EventParams , Result extends HandlerResult > (
75164 eventType : EventType ,
76165 eventName : EventName ,
@@ -89,78 +178,113 @@ export class PostMessageController extends EventTarget {
89178 eventHandlers [ eventName ] = { validator, handler } as HandlerMethods < EventParams , HandlerResult >
90179 }
91180
181+ addMiddleware < Params extends EventParams , Result extends HandlerResult > ( mw : Middleware < Params , Result > ) : MiddlewareId {
182+ const { scope } = mw
183+ if ( ! this . #registeredMiddlewares[ scope ] ) {
184+ this . #registeredMiddlewares[ scope ] = [ ]
185+ }
186+ const id = generateUUIDv4 ( )
187+ this . #registeredMiddlewares[ scope ] . push ( {
188+ ...mw ,
189+ fn : mw . fn as MiddlewareFn < EventParams , HandlerResult > ,
190+ id,
191+ } )
192+ this . #middlewaresIdsMap[ id ] = scope
193+
194+ return id
195+ }
196+
197+ removeMiddleware ( id : MiddlewareId ) : void {
198+ const scope = this . #middlewaresIdsMap[ id ]
199+ if ( scope && this . #registeredMiddlewares[ scope ] ) {
200+ this . #registeredMiddlewares[ scope ] = this . #registeredMiddlewares[ scope ] . filter ( mw => mw . id !== id )
201+ }
202+ delete this . #middlewaresIdsMap[ id ]
203+ }
204+
205+ // ---- EVENT LISTENERS ----
206+
92207 async eventListener ( event : MessageEvent ) {
93208 if ( typeof window === 'undefined' ) return
94- if ( ! event . isTrusted || ! event . source || ! ( 'self' in event . source ) ) return
209+ if ( ! event . isTrusted || ! event . source || ( ! ( 'self' in event . source ) && ! isServiceWorker ( event . source ) ) ) return
95210
96211 const { success : isValidMessage , data : message } = MESSAGE_SCHEMA . safeParse ( event . data )
97212 if ( ! isValidMessage ) return
98213
99214 const { handler : eventName , params : { requestId, ...handlerParams } , type : eventType } = message
100215
101- const sourceWindow = event . source
102-
103- let frame : FrameType | undefined = undefined
104- let frameId = 'parent'
105-
106- if ( sourceWindow !== window ) {
107- const registeredFrame = Object . entries ( this . #registeredFrames)
108- . find ( ( [ , ref ] ) => ref . contentWindow === sourceWindow )
109-
110- if ( ! registeredFrame ) {
111- return sourceWindow . postMessage (
112- getClientErrorMessage ( 'ACCESS_DENIED' , 0 , 'Message was received from unregistered origin / iframe' , requestId , eventName ) ,
113- event . origin ,
114- )
115- }
216+ const messageSource = this . getMessageSource ( event . source )
116217
117- frameId = registeredFrame [ 0 ]
118- frame = registeredFrame [ 1 ]
218+ if ( ! messageSource ) {
219+ return sendResponseMessage ( {
220+ data : getClientErrorMessage ( 'ACCESS_DENIED' , 0 , 'Message was received from unregistered origin / iframe' , requestId , eventName ) ,
221+ target : event . source ,
222+ origin : event . origin ,
223+ } )
119224 }
120225
121226 const handlerMethods = (
122- this . #registeredHandlers[ frameId ] ?. [ eventType ] ?. [ eventName ]
227+ this . #registeredHandlers[ messageSource . id ] ?. [ eventType ] ?. [ eventName ]
123228 ?? this . #registeredHandlers[ '*' ] ?. [ eventType ] ?. [ eventName ]
124229 ?? { }
125230 )
231+
126232 const { handler, validator } = handlerMethods
127233 if ( ! handler || ! validator ) {
128- return sourceWindow . postMessage (
129- getClientErrorMessage ( 'UNKNOWN_METHOD' , 2 , 'Unknown method was provided. Make sure your runtime environment supports it.' , requestId ) ,
130- event . origin ,
131- )
234+ return sendResponseMessage ( {
235+ data : getClientErrorMessage ( 'UNKNOWN_METHOD' , 2 , 'Unknown method was provided. Make sure your runtime environment supports it.' , requestId ) ,
236+ origin : event . origin ,
237+ target : event . source ,
238+ } )
132239 }
133240
134241 const validationResult = validator ( handlerParams )
135242 if ( ! validationResult . success ) {
136- return sourceWindow . postMessage (
137- getClientErrorMessage ( 'INVALID_PARAMETERS' , 3 , validationResult . error , requestId , eventName ) ,
138- event . origin ,
139- )
243+ return sendResponseMessage ( {
244+ data : getClientErrorMessage ( 'INVALID_PARAMETERS' , 3 , validationResult . error , requestId , eventName ) ,
245+ origin : event . origin ,
246+ target : event . source ,
247+ } )
140248 }
141249
142250 const validatedParams = validationResult . data
143251 this . #storage[ eventType ] ??= new Map ( )
144- const storage = this . #storage[ eventType ]
252+ const eventsStorage = this . #storage[ eventType ]
253+ const wrappedHandler = this . getWrappedHandler ( eventType , eventName , messageSource . id , handler )
145254
146255 try {
147- const result = await handler ( validatedParams , storage , frame )
148- return sourceWindow . postMessage ( {
149- type : `${ eventName } Result` ,
256+ const result = await wrappedHandler ( {
257+ eventType,
258+ eventName,
259+ params : validatedParams ,
260+ storage : { events : eventsStorage } ,
261+ source : messageSource ,
262+ } )
263+
264+ return sendResponseMessage ( {
150265 data : {
151- ...result ,
152- requestId,
266+ type : `${ eventName } Result` ,
267+ data : {
268+ ...result ,
269+ requestId,
270+ } ,
153271 } ,
154- } , event . origin )
272+ target : event . source ,
273+ origin : event . origin ,
274+ } )
155275 } catch ( err ) {
156276 const errorMessage = err instanceof Error ? err . message : String ( err )
157- return sourceWindow . postMessage (
158- getClientErrorMessage ( 'HANDLER_ERROR' , 4 , errorMessage , requestId , eventName ) ,
159- event . origin
160- )
277+
278+ return sendResponseMessage ( {
279+ data : getClientErrorMessage ( 'HANDLER_ERROR' , 4 , errorMessage , requestId , eventName ) ,
280+ target : event . source ,
281+ origin : event . origin ,
282+ } )
161283 }
162284 }
163285
286+ // ---- COMMON HANDLERS METHODS ----
287+
164288 registerBridgeEvents ( options : Omit < RegisterBridgeEventsOptions , 'addHandler' > ) {
165289 registerBridgeEvents ( {
166290 ...options ,
0 commit comments