Skip to content

Commit c20cc2f

Browse files
authored
fix(middleware): respect host protocol (#611)
1 parent 811bdc0 commit c20cc2f

File tree

3 files changed

+89
-26
lines changed

3 files changed

+89
-26
lines changed

packages/open-next/src/core/routing/middleware.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,18 @@ export async function handleMiddleware(
5454
// We bypass the middleware if the request is internal
5555
if (internalEvent.headers["x-isr"]) return internalEvent;
5656

57+
// Retrieve the protocol:
58+
// - In lambda, the url only contains the rawPath and the query - default to https
59+
// - In cloudflare, the protocol is usually http in dev and https in production
60+
const protocol = internalEvent.url.startsWith("http://") ? "http:" : "https:";
61+
5762
const host = internalEvent.headers.host
58-
? `https://${internalEvent.headers.host}`
63+
? `${protocol}//${internalEvent.headers.host}`
5964
: "http://localhost:3000";
65+
6066
const initialUrl = new URL(normalizedPath, host);
6167
initialUrl.search = convertToQueryString(query);
6268
const url = initialUrl.toString();
63-
// console.log("url", url, normalizedPath);
6469

6570
const middleware = await middlewareLoader();
6671

@@ -125,7 +130,7 @@ export async function handleMiddleware(
125130
.get("location")
126131
?.replace(
127132
"http://localhost:3000",
128-
`https://${internalEvent.headers.host}`,
133+
`${protocol}//${internalEvent.headers.host}`,
129134
) ?? resHeaders.location;
130135
// res.setHeader("Location", location);
131136
return {

packages/tests-unit/tests/core/routing/middleware.test.ts

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import { handleMiddleware } from "@opennextjs/aws/core/routing/middleware.js";
2-
import { convertFromQueryString } from "@opennextjs/aws/core/routing/util.js";
2+
import {
3+
convertFromQueryString,
4+
isExternal,
5+
} from "@opennextjs/aws/core/routing/util.js";
36
import type { InternalEvent } from "@opennextjs/aws/types/open-next.js";
47
import { toReadableStream } from "@opennextjs/aws/utils/stream.js";
58
import { vi } from "vitest";
@@ -48,15 +51,25 @@ type PartialEvent = Partial<
4851
> & { body?: string };
4952

5053
function createEvent(event: PartialEvent): InternalEvent {
51-
const [rawPath, qs] = (event.url ?? "/").split("?", 2);
54+
let rawPath: string;
55+
let qs: string;
56+
if (isExternal(event.url)) {
57+
const url = new URL(event.url!);
58+
rawPath = url.pathname;
59+
qs = url.search;
60+
} else {
61+
const parts = (event.url ?? "/").split("?", 2);
62+
rawPath = parts[0];
63+
qs = parts[1] ?? "";
64+
}
5265
return {
5366
type: "core",
5467
method: event.method ?? "GET",
5568
rawPath,
5669
url: event.url ?? "/",
5770
body: Buffer.from(event.body ?? ""),
5871
headers: event.headers ?? {},
59-
query: convertFromQueryString(qs ?? ""),
72+
query: convertFromQueryString(qs),
6073
cookies: event.cookies ?? {},
6174
remoteAddress: event.remoteAddress ?? "::1",
6275
};
@@ -70,19 +83,19 @@ beforeEach(() => {
7083
* Ideally these tests would be broken up and tests smaller parts of the middleware rather than the entire function.
7184
*/
7285
describe("handleMiddleware", () => {
73-
it("should bypass middlware for internal requests", async () => {
86+
it("should bypass middleware for internal requests", async () => {
7487
const event = createEvent({
7588
headers: {
7689
"x-isr": "1",
7790
},
7891
});
7992
const result = await handleMiddleware(event, middlewareLoader);
8093

81-
expect(middlewareLoader).not.toBeCalled();
94+
expect(middlewareLoader).not.toHaveBeenCalled();
8295
expect(result).toEqual(event);
8396
});
8497

85-
it("should invoke middlware with redirect", async () => {
98+
it("should invoke middleware with redirect", async () => {
8699
const event = createEvent({});
87100
middleware.mockResolvedValue({
88101
status: 302,
@@ -92,12 +105,12 @@ describe("handleMiddleware", () => {
92105
});
93106
const result = await handleMiddleware(event, middlewareLoader);
94107

95-
expect(middlewareLoader).toBeCalled();
108+
expect(middlewareLoader).toHaveBeenCalled();
96109
expect(result.statusCode).toEqual(302);
97110
expect(result.headers.location).toEqual("/redirect");
98111
});
99112

100-
it("should invoke middlware with external redirect", async () => {
113+
it("should invoke middleware with external redirect", async () => {
101114
const event = createEvent({});
102115
middleware.mockResolvedValue({
103116
status: 302,
@@ -107,12 +120,12 @@ describe("handleMiddleware", () => {
107120
});
108121
const result = await handleMiddleware(event, middlewareLoader);
109122

110-
expect(middlewareLoader).toBeCalled();
123+
expect(middlewareLoader).toHaveBeenCalled();
111124
expect(result.statusCode).toEqual(302);
112125
expect(result.headers.location).toEqual("http://external/redirect");
113126
});
114127

115-
it("should invoke middlware with rewrite", async () => {
128+
it("should invoke middleware with rewrite", async () => {
116129
const event = createEvent({
117130
headers: {
118131
host: "localhost",
@@ -125,7 +138,7 @@ describe("handleMiddleware", () => {
125138
});
126139
const result = await handleMiddleware(event, middlewareLoader);
127140

128-
expect(middlewareLoader).toBeCalled();
141+
expect(middlewareLoader).toHaveBeenCalled();
129142
expect(result).toEqual({
130143
...event,
131144
rawPath: "/rewrite",
@@ -137,7 +150,7 @@ describe("handleMiddleware", () => {
137150
});
138151
});
139152

140-
it("should invoke middlware with rewrite with __nextDataReq", async () => {
153+
it("should invoke middleware with rewrite with __nextDataReq", async () => {
141154
const event = createEvent({
142155
url: "/rewrite?__nextDataReq=1&key=value",
143156
headers: {
@@ -151,7 +164,7 @@ describe("handleMiddleware", () => {
151164
});
152165
const result = await handleMiddleware(event, middlewareLoader);
153166

154-
expect(middlewareLoader).toBeCalled();
167+
expect(middlewareLoader).toHaveBeenCalled();
155168
expect(result).toEqual({
156169
...event,
157170
rawPath: "/rewrite",
@@ -167,7 +180,7 @@ describe("handleMiddleware", () => {
167180
});
168181
});
169182

170-
it("should invoke middlware with external rewrite", async () => {
183+
it("should invoke middleware with external rewrite", async () => {
171184
const event = createEvent({
172185
headers: {
173186
host: "localhost",
@@ -180,7 +193,7 @@ describe("handleMiddleware", () => {
180193
});
181194
const result = await handleMiddleware(event, middlewareLoader);
182195

183-
expect(middlewareLoader).toBeCalled();
196+
expect(middlewareLoader).toHaveBeenCalled();
184197
expect(result).toEqual({
185198
...event,
186199
rawPath: "http://external/rewrite",
@@ -201,7 +214,7 @@ describe("handleMiddleware", () => {
201214
});
202215
const result = await handleMiddleware(event, middlewareLoader);
203216

204-
expect(middlewareLoader).toBeCalled();
217+
expect(middlewareLoader).toHaveBeenCalled();
205218
expect(result).toEqual({
206219
...event,
207220
headers: {
@@ -223,7 +236,7 @@ describe("handleMiddleware", () => {
223236
});
224237
const result = await handleMiddleware(event, middlewareLoader);
225238

226-
expect(middlewareLoader).toBeCalled();
239+
expect(middlewareLoader).toHaveBeenCalled();
227240
expect(result).toEqual({
228241
type: "core",
229242
statusCode: 200,
@@ -246,7 +259,7 @@ describe("handleMiddleware", () => {
246259
});
247260
const result = await handleMiddleware(event, middlewareLoader);
248261

249-
expect(middlewareLoader).toBeCalled();
262+
expect(middlewareLoader).toHaveBeenCalled();
250263
expect(result).toEqual({
251264
type: "core",
252265
statusCode: 200,
@@ -257,4 +270,49 @@ describe("handleMiddleware", () => {
257270
isBase64Encoded: false,
258271
});
259272
});
273+
274+
it("should use the http event protocol when specified", async () => {
275+
const event = createEvent({
276+
url: "http://test.me/path",
277+
headers: {
278+
host: "test.me",
279+
},
280+
});
281+
await handleMiddleware(event, middlewareLoader);
282+
expect(middleware).toHaveBeenCalledWith(
283+
expect.objectContaining({
284+
url: "http://test.me/path",
285+
}),
286+
);
287+
});
288+
289+
it("should use the https event protocol when specified", async () => {
290+
const event = createEvent({
291+
url: "https://test.me/path",
292+
headers: {
293+
host: "test.me/path",
294+
},
295+
});
296+
await handleMiddleware(event, middlewareLoader);
297+
expect(middleware).toHaveBeenCalledWith(
298+
expect.objectContaining({
299+
url: "https://test.me/path",
300+
}),
301+
);
302+
});
303+
304+
it("should default to https protocol", async () => {
305+
const event = createEvent({
306+
url: "/path",
307+
headers: {
308+
host: "test.me",
309+
},
310+
});
311+
await handleMiddleware(event, middlewareLoader);
312+
expect(middleware).toHaveBeenCalledWith(
313+
expect.objectContaining({
314+
url: "https://test.me/path",
315+
}),
316+
);
317+
});
260318
});

packages/tests-unit/tests/core/routing/util.test.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,11 @@ describe("getUrlParts", () => {
146146

147147
describe("external", () => {
148148
it("throws for empty url", () => {
149-
expect(() => getUrlParts("", true)).toThrowError();
149+
expect(() => getUrlParts("", true)).toThrow();
150150
});
151151

152152
it("throws for invalid url", () => {
153-
expect(() => getUrlParts("/relative", true)).toThrowError();
153+
expect(() => getUrlParts("/relative", true)).toThrow();
154154
});
155155

156156
it("returns url parts for /", () => {
@@ -581,7 +581,7 @@ describe("revalidateIfRequired", () => {
581581
const headers: Record<string, string> = {};
582582
await revalidateIfRequired("localhost", "/path", headers);
583583

584-
expect(sendMock).not.toBeCalled();
584+
expect(sendMock).not.toHaveBeenCalled();
585585
});
586586

587587
it("should send to queue when x-nextjs-cache is STALE", async () => {
@@ -590,7 +590,7 @@ describe("revalidateIfRequired", () => {
590590
};
591591
await revalidateIfRequired("localhost", "/path", headers);
592592

593-
expect(sendMock).toBeCalledWith({
593+
expect(sendMock).toHaveBeenCalledWith({
594594
MessageBody: { host: "localhost", url: "/path" },
595595
MessageDeduplicationId: expect.any(String),
596596
MessageGroupId: expect.any(String),
@@ -604,7 +604,7 @@ describe("revalidateIfRequired", () => {
604604
sendMock.mockRejectedValueOnce(new Error("Failed to send"));
605605
await revalidateIfRequired("localhost", "/path", headers);
606606

607-
expect(sendMock).toBeCalledWith({
607+
expect(sendMock).toHaveBeenCalledWith({
608608
MessageBody: { host: "localhost", url: "/path" },
609609
MessageDeduplicationId: expect.any(String),
610610
MessageGroupId: expect.any(String),

0 commit comments

Comments
 (0)