Skip to content

Commit 5d6030b

Browse files
authored
fix(server): return correct context from middleware next (#245)
Recent changes not restrict middleware out context match with in context, so we need reflect it in runtime. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Summary by CodeRabbit - **Tests** - Enhanced middleware test cases with additional assertions to verify that each middleware component produces the expected output and context. - **Refactor** - Improved internal context management during middleware execution for more consistent state propagation and simplified structure. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent ba53a01 commit 5d6030b

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

packages/server/src/procedure-client.test.ts

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,25 +178,21 @@ describe.each(procedureCases)('createProcedureClient - case %s', async (_, proce
178178
})
179179

180180
preMid2.mockImplementationOnce(({ next }) => {
181-
return next({
182-
context: {
183-
extra2: '__extra2__',
184-
},
185-
})
181+
return next()
186182
})
187183

188184
postMid1.mockImplementationOnce(({ next }) => {
189185
return next({
190186
context: {
191-
extra3: '__extra3__',
187+
extra2: '__extra2__',
192188
},
193189
})
194190
})
195191

196192
postMid2.mockImplementationOnce(({ next }) => {
197193
return next({
198194
context: {
199-
extra4: '__extra4__',
195+
extra3: '__extra3__',
200196
},
201197
})
202198
})
@@ -205,36 +201,38 @@ describe.each(procedureCases)('createProcedureClient - case %s', async (_, proce
205201

206202
expect(preMid1).toBeCalledTimes(1)
207203
expect(preMid1).toHaveBeenCalledWith(expect.objectContaining({ context: {} }), expect.any(Object), expect.any(Function))
204+
expect(await preMid1.mock.results[0]!.value).toEqual({ output: { val: 123 }, context: { extra1: '__extra1__' } })
208205

209206
expect(preMid2).toBeCalledTimes(1)
210207
expect(preMid2).toHaveBeenCalledWith(expect.objectContaining({
211208
context: { extra1: '__extra1__' },
212209
}), expect.any(Object), expect.any(Function))
210+
expect(await preMid2.mock.results[0]!.value).toEqual({ output: { val: 123 }, context: { } })
213211

214212
expect(postMid1).toBeCalledTimes(1)
215213
expect(postMid1).toHaveBeenCalledWith(expect.objectContaining({
216214
context: {
217215
extra1: '__extra1__',
218-
extra2: '__extra2__',
219216
},
220217
}), expect.any(Object), expect.any(Function))
218+
expect(await postMid1.mock.results[0]!.value).toEqual({ output: { val: '123' }, context: { extra2: '__extra2__' } })
221219

222220
expect(postMid2).toBeCalledTimes(1)
223221
expect(postMid2).toHaveBeenCalledWith(expect.objectContaining({
224222
context: {
225223
extra1: '__extra1__',
226224
extra2: '__extra2__',
227-
extra3: '__extra3__',
228225
},
229226
}), expect.any(Object), expect.any(Function))
227+
expect(await postMid2.mock.results[0]!.value).toEqual({ output: { val: '123' }, context: { extra3: '__extra3__' } })
230228

231229
expect(handler).toBeCalledTimes(1)
232230
expect(handler).toHaveBeenCalledWith(expect.objectContaining({ context: {
233231
extra1: '__extra1__',
234232
extra2: '__extra2__',
235233
extra3: '__extra3__',
236-
extra4: '__extra4__',
237234
} }))
235+
expect(await handler.mock.results[0]!.value).toEqual({ val: '123' })
238236
})
239237

240238
it('middleware can override context', async () => {
@@ -280,24 +278,29 @@ describe.each(procedureCases)('createProcedureClient - case %s', async (_, proce
280278
expect(preMid1).toHaveBeenCalledWith(expect.objectContaining({
281279
context: expect.objectContaining({ userId: '123' }),
282280
}), expect.any(Object), expect.any(Function))
281+
expect(await preMid1.mock.results[0]!.value).toEqual({ output: { val: 123 }, context: { userId: '__override1__' } })
283282

284283
expect(preMid2).toBeCalledTimes(1)
285284
expect(preMid2).toHaveBeenCalledWith(expect.objectContaining({
286285
context: expect.objectContaining({ userId: '__override1__' }),
287286
}), expect.any(Object), expect.any(Function))
287+
expect(await preMid2.mock.results[0]!.value).toEqual({ output: { val: 123 }, context: { userId: '__override2__' } })
288288

289289
expect(postMid1).toBeCalledTimes(1)
290290
expect(postMid1).toHaveBeenCalledWith(expect.objectContaining({
291291
context: expect.objectContaining({ userId: '__override2__' }),
292292
}), expect.any(Object), expect.any(Function))
293+
expect(await postMid1.mock.results[0]!.value).toEqual({ output: { val: '123' }, context: { userId: '__override3__' } })
293294

294295
expect(postMid2).toBeCalledTimes(1)
295296
expect(postMid2).toHaveBeenCalledWith(expect.objectContaining({
296297
context: expect.objectContaining({ userId: '__override3__' }),
297298
}), expect.any(Object), expect.any(Function))
299+
expect(await postMid2.mock.results[0]!.value).toEqual({ output: { val: '123' }, context: { userId: '__override4__' } })
298300

299301
expect(handler).toBeCalledTimes(1)
300302
expect(handler).toHaveBeenCalledWith(expect.objectContaining({ context: expect.objectContaining({ userId: '__override4__' }) }))
303+
expect(await handler.mock.results[0]!.value).toEqual({ val: '123' })
301304
})
302305

303306
const contextCases = [

packages/server/src/procedure-client.ts

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import type { Client, ClientContext } from '@orpc/client'
22
import type { Interceptor, MaybeOptionalOptions, Value } from '@orpc/shared'
3-
import type { Context } from './context'
43
import type { Lazyable } from './lazy'
54
import type { MiddlewareNextFn } from './middleware'
65
import type { AnyProcedure, Procedure, ProcedureHandlerOptions } from './procedure'
76
import { ORPCError } from '@orpc/client'
87
import { type AnySchema, type ErrorFromErrorMap, type ErrorMap, type InferSchemaInput, type InferSchemaOutput, type Meta, ValidationError } from '@orpc/contract'
98
import { intercept, toError, value } from '@orpc/shared'
9+
import { type Context, mergeCurrentContext } from './context'
1010
import { createORPCErrorConstructorMap, type ORPCErrorConstructorMap, validateORPCError } from './error'
1111
import { unlazy } from './lazy'
1212
import { middlewareOutputFn } from './middleware'
@@ -171,8 +171,10 @@ async function executeProcedureInternal(procedure: AnyProcedure, options: Proced
171171

172172
const next: MiddlewareNextFn<any> = async (...[nextOptions]) => {
173173
const index = currentIndex
174+
const midContext = nextOptions?.context ?? {} as any
175+
174176
currentIndex += 1
175-
currentContext = { ...currentContext, ...nextOptions?.context }
177+
currentContext = mergeCurrentContext(currentContext, midContext)
176178

177179
if (index === inputValidationIndex) {
178180
currentInput = await validateInput(procedure, currentInput)
@@ -181,20 +183,26 @@ async function executeProcedureInternal(procedure: AnyProcedure, options: Proced
181183
const mid = middlewares[index]
182184

183185
const result = mid
184-
? await mid({ ...options, context: currentContext, next }, currentInput, middlewareOutputFn)
185-
: { output: await procedure['~orpc'].handler({ ...options, context: currentContext, input: currentInput }), context: currentContext }
186+
? {
187+
context: midContext,
188+
output: (await mid({ ...options, context: currentContext, next }, currentInput, middlewareOutputFn)).output,
189+
}
190+
: {
191+
context: midContext,
192+
output: await procedure['~orpc'].handler({ ...options, context: currentContext, input: currentInput }),
193+
}
186194

187195
if (index === outputValidationIndex) {
188196
const validatedOutput = await validateOutput(procedure, result.output)
189197

190198
return {
191-
...result,
199+
context: result.context,
192200
output: validatedOutput,
193201
}
194202
}
195203

196204
return result
197205
}
198206

199-
return (await next({})).output
207+
return (await next()).output
200208
}

0 commit comments

Comments
 (0)