diff --git a/packages/open-next/src/core/routing/middleware.ts b/packages/open-next/src/core/routing/middleware.ts index 42387522c..8245b61b8 100644 --- a/packages/open-next/src/core/routing/middleware.ts +++ b/packages/open-next/src/core/routing/middleware.ts @@ -54,13 +54,18 @@ export async function handleMiddleware( // We bypass the middleware if the request is internal if (internalEvent.headers["x-isr"]) return internalEvent; + // Retrieve the protocol: + // - In lambda, the url only contains the rawPath and the query - default to https + // - In cloudflare, the protocol is usually http in dev and https in production + const protocol = internalEvent.url.startsWith("http://") ? "http:" : "https:"; + const host = internalEvent.headers.host - ? `https://${internalEvent.headers.host}` + ? `${protocol}//${internalEvent.headers.host}` : "http://localhost:3000"; + const initialUrl = new URL(normalizedPath, host); initialUrl.search = convertToQueryString(query); const url = initialUrl.toString(); - // console.log("url", url, normalizedPath); const middleware = await middlewareLoader(); @@ -125,7 +130,7 @@ export async function handleMiddleware( .get("location") ?.replace( "http://localhost:3000", - `https://${internalEvent.headers.host}`, + `${protocol}//${internalEvent.headers.host}`, ) ?? resHeaders.location; // res.setHeader("Location", location); return { diff --git a/packages/tests-unit/tests/core/routing/middleware.test.ts b/packages/tests-unit/tests/core/routing/middleware.test.ts index 998098d28..be074b541 100644 --- a/packages/tests-unit/tests/core/routing/middleware.test.ts +++ b/packages/tests-unit/tests/core/routing/middleware.test.ts @@ -1,5 +1,8 @@ import { handleMiddleware } from "@opennextjs/aws/core/routing/middleware.js"; -import { convertFromQueryString } from "@opennextjs/aws/core/routing/util.js"; +import { + convertFromQueryString, + isExternal, +} from "@opennextjs/aws/core/routing/util.js"; import type { InternalEvent } from "@opennextjs/aws/types/open-next.js"; import { toReadableStream } from "@opennextjs/aws/utils/stream.js"; import { vi } from "vitest"; @@ -48,7 +51,17 @@ type PartialEvent = Partial< > & { body?: string }; function createEvent(event: PartialEvent): InternalEvent { - const [rawPath, qs] = (event.url ?? "/").split("?", 2); + let rawPath: string; + let qs: string; + if (isExternal(event.url)) { + const url = new URL(event.url!); + rawPath = url.pathname; + qs = url.search; + } else { + const parts = (event.url ?? "/").split("?", 2); + rawPath = parts[0]; + qs = parts[1] ?? ""; + } return { type: "core", method: event.method ?? "GET", @@ -56,7 +69,7 @@ function createEvent(event: PartialEvent): InternalEvent { url: event.url ?? "/", body: Buffer.from(event.body ?? ""), headers: event.headers ?? {}, - query: convertFromQueryString(qs ?? ""), + query: convertFromQueryString(qs), cookies: event.cookies ?? {}, remoteAddress: event.remoteAddress ?? "::1", }; @@ -70,7 +83,7 @@ beforeEach(() => { * Ideally these tests would be broken up and tests smaller parts of the middleware rather than the entire function. */ describe("handleMiddleware", () => { - it("should bypass middlware for internal requests", async () => { + it("should bypass middleware for internal requests", async () => { const event = createEvent({ headers: { "x-isr": "1", @@ -78,11 +91,11 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).not.toBeCalled(); + expect(middlewareLoader).not.toHaveBeenCalled(); expect(result).toEqual(event); }); - it("should invoke middlware with redirect", async () => { + it("should invoke middleware with redirect", async () => { const event = createEvent({}); middleware.mockResolvedValue({ status: 302, @@ -92,12 +105,12 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).toBeCalled(); + expect(middlewareLoader).toHaveBeenCalled(); expect(result.statusCode).toEqual(302); expect(result.headers.location).toEqual("/redirect"); }); - it("should invoke middlware with external redirect", async () => { + it("should invoke middleware with external redirect", async () => { const event = createEvent({}); middleware.mockResolvedValue({ status: 302, @@ -107,12 +120,12 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).toBeCalled(); + expect(middlewareLoader).toHaveBeenCalled(); expect(result.statusCode).toEqual(302); expect(result.headers.location).toEqual("http://external/redirect"); }); - it("should invoke middlware with rewrite", async () => { + it("should invoke middleware with rewrite", async () => { const event = createEvent({ headers: { host: "localhost", @@ -125,7 +138,7 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).toBeCalled(); + expect(middlewareLoader).toHaveBeenCalled(); expect(result).toEqual({ ...event, rawPath: "/rewrite", @@ -137,7 +150,7 @@ describe("handleMiddleware", () => { }); }); - it("should invoke middlware with rewrite with __nextDataReq", async () => { + it("should invoke middleware with rewrite with __nextDataReq", async () => { const event = createEvent({ url: "/rewrite?__nextDataReq=1&key=value", headers: { @@ -151,7 +164,7 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).toBeCalled(); + expect(middlewareLoader).toHaveBeenCalled(); expect(result).toEqual({ ...event, rawPath: "/rewrite", @@ -167,7 +180,7 @@ describe("handleMiddleware", () => { }); }); - it("should invoke middlware with external rewrite", async () => { + it("should invoke middleware with external rewrite", async () => { const event = createEvent({ headers: { host: "localhost", @@ -180,7 +193,7 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).toBeCalled(); + expect(middlewareLoader).toHaveBeenCalled(); expect(result).toEqual({ ...event, rawPath: "http://external/rewrite", @@ -201,7 +214,7 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).toBeCalled(); + expect(middlewareLoader).toHaveBeenCalled(); expect(result).toEqual({ ...event, headers: { @@ -223,7 +236,7 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).toBeCalled(); + expect(middlewareLoader).toHaveBeenCalled(); expect(result).toEqual({ type: "core", statusCode: 200, @@ -246,7 +259,7 @@ describe("handleMiddleware", () => { }); const result = await handleMiddleware(event, middlewareLoader); - expect(middlewareLoader).toBeCalled(); + expect(middlewareLoader).toHaveBeenCalled(); expect(result).toEqual({ type: "core", statusCode: 200, @@ -257,4 +270,49 @@ describe("handleMiddleware", () => { isBase64Encoded: false, }); }); + + it("should use the http event protocol when specified", async () => { + const event = createEvent({ + url: "http://test.me/path", + headers: { + host: "test.me", + }, + }); + await handleMiddleware(event, middlewareLoader); + expect(middleware).toHaveBeenCalledWith( + expect.objectContaining({ + url: "http://test.me/path", + }), + ); + }); + + it("should use the https event protocol when specified", async () => { + const event = createEvent({ + url: "https://test.me/path", + headers: { + host: "test.me/path", + }, + }); + await handleMiddleware(event, middlewareLoader); + expect(middleware).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://test.me/path", + }), + ); + }); + + it("should default to https protocol", async () => { + const event = createEvent({ + url: "/path", + headers: { + host: "test.me", + }, + }); + await handleMiddleware(event, middlewareLoader); + expect(middleware).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://test.me/path", + }), + ); + }); }); diff --git a/packages/tests-unit/tests/core/routing/util.test.ts b/packages/tests-unit/tests/core/routing/util.test.ts index 3e9feb4c4..b0f6fbb56 100644 --- a/packages/tests-unit/tests/core/routing/util.test.ts +++ b/packages/tests-unit/tests/core/routing/util.test.ts @@ -146,11 +146,11 @@ describe("getUrlParts", () => { describe("external", () => { it("throws for empty url", () => { - expect(() => getUrlParts("", true)).toThrowError(); + expect(() => getUrlParts("", true)).toThrow(); }); it("throws for invalid url", () => { - expect(() => getUrlParts("/relative", true)).toThrowError(); + expect(() => getUrlParts("/relative", true)).toThrow(); }); it("returns url parts for /", () => { @@ -581,7 +581,7 @@ describe("revalidateIfRequired", () => { const headers: Record = {}; await revalidateIfRequired("localhost", "/path", headers); - expect(sendMock).not.toBeCalled(); + expect(sendMock).not.toHaveBeenCalled(); }); it("should send to queue when x-nextjs-cache is STALE", async () => { @@ -590,7 +590,7 @@ describe("revalidateIfRequired", () => { }; await revalidateIfRequired("localhost", "/path", headers); - expect(sendMock).toBeCalledWith({ + expect(sendMock).toHaveBeenCalledWith({ MessageBody: { host: "localhost", url: "/path" }, MessageDeduplicationId: expect.any(String), MessageGroupId: expect.any(String), @@ -604,7 +604,7 @@ describe("revalidateIfRequired", () => { sendMock.mockRejectedValueOnce(new Error("Failed to send")); await revalidateIfRequired("localhost", "/path", headers); - expect(sendMock).toBeCalledWith({ + expect(sendMock).toHaveBeenCalledWith({ MessageBody: { host: "localhost", url: "/path" }, MessageDeduplicationId: expect.any(String), MessageGroupId: expect.any(String),