diff --git a/.changeset/funny-eggs-divide.md b/.changeset/funny-eggs-divide.md new file mode 100644 index 000000000000..6401b4269a64 --- /dev/null +++ b/.changeset/funny-eggs-divide.md @@ -0,0 +1,5 @@ +--- +"miniflare": minor +--- + +Add Miniflare Workers KV bulk get support diff --git a/packages/miniflare/src/workers/kv/constants.ts b/packages/miniflare/src/workers/kv/constants.ts index c211def10763..8529b1ff8808 100644 --- a/packages/miniflare/src/workers/kv/constants.ts +++ b/packages/miniflare/src/workers/kv/constants.ts @@ -34,6 +34,7 @@ export const SiteBindings = { // This ensures edge caching of Workers Sites files is disabled, and the latest // local version is always served. export const SITES_NO_CACHE_PREFIX = "$__MINIFLARE_SITES__$/"; +export const MAX_BULK_GET_KEYS = 100; export function encodeSitesKey(key: string): string { // `encodeURIComponent()` ensures `ETag`s used by `@cloudflare/kv-asset-handler` diff --git a/packages/miniflare/src/workers/kv/namespace.worker.ts b/packages/miniflare/src/workers/kv/namespace.worker.ts index 1359cd37acf6..f29362a6b26f 100644 --- a/packages/miniflare/src/workers/kv/namespace.worker.ts +++ b/packages/miniflare/src/workers/kv/namespace.worker.ts @@ -4,13 +4,15 @@ import { DELETE, GET, HttpError, + KeyValueEntry, KeyValueStorage, maybeApply, MiniflareDurableObject, + POST, PUT, RouteHandler, } from "miniflare:shared"; -import { KVHeaders, KVLimits, KVParams } from "./constants"; +import { KVHeaders, KVLimits, KVParams, MAX_BULK_GET_KEYS } from "./constants"; import { decodeKey, decodeListOptions, @@ -73,6 +75,42 @@ function secondsToMillis(seconds: number): number { return seconds * 1000; } +async function processKeyValue( + obj: KeyValueEntry | null, + type: "text" | "json" = "text", + withMetadata = false +) { + const decoder = new TextDecoder(); + let decodedValue = ""; + if (obj?.value) { + for await (const chunk of obj?.value) { + decodedValue += decoder.decode(chunk, { stream: true }); + } + decodedValue += decoder.decode(); + } + + let val = null; + try { + val = !obj?.value + ? null + : type === "json" + ? JSON.parse(decodedValue) + : decodedValue; + } catch (err: any) { + throw new HttpError( + 400, + "At least one of the requested keys corresponds to a non-JSON value" + ); + } + if (val && withMetadata) { + return { + value: val, + metadata: obj?.metadata ?? null, + }; + } + return val; +} + export class KVNamespaceObject extends MiniflareDurableObject { #storage?: KeyValueStorage; get storage() { @@ -81,13 +119,46 @@ export class KVNamespaceObject extends MiniflareDurableObject { } @GET("/:key") + @POST("/bulk/get") get: RouteHandler = async (req, params, url) => { + if (req.method === "POST" && req.body != null) { + let decodedBody = ""; + const decoder = new TextDecoder(); + for await (const chunk of req.body) { + decodedBody += decoder.decode(chunk, { stream: true }); + } + decodedBody += decoder.decode(); + const parsedBody = JSON.parse(decodedBody); + const keys: string[] = parsedBody.keys; + const type = parsedBody?.type; + if (type && type !== "text" && type !== "json") { + return new Response(`Type ${type} is invalid`, { status: 400 }); + } + const obj: { [key: string]: any } = {}; + if (keys.length > MAX_BULK_GET_KEYS) { + return new Response(`Accepting a max of 100 keys, got ${keys.length}`, { + status: 400, + }); + } + for (const key of keys) { + validateGetOptions(key, { cacheTtl: parsedBody?.cacheTtl }); + const entry = await this.storage.get(key); + const value = await processKeyValue( + entry, + parsedBody?.type, + parsedBody?.withMetadata + ); + obj[key] = value; + } + + return new Response(JSON.stringify(obj)); + } + // Decode URL parameters const key = decodeKey(params, url.searchParams); const cacheTtlParam = url.searchParams.get(KVParams.CACHE_TTL); const cacheTtl = cacheTtlParam === null ? undefined : parseInt(cacheTtlParam); - // Get value from storage validateGetOptions(key, { cacheTtl }); const entry = await this.storage.get(key); @@ -114,7 +185,6 @@ export class KVNamespaceObject extends MiniflareDurableObject { const rawExpiration = url.searchParams.get(KVParams.EXPIRATION); const rawExpirationTtl = url.searchParams.get(KVParams.EXPIRATION_TTL); const rawMetadata = req.headers.get(KVHeaders.METADATA); - // Validate key, expiration and metadata const now = millisToSeconds(this.timers.now()); const { expiration, metadata } = validatePutOptions(key, { diff --git a/packages/miniflare/test/plugins/kv/index.spec.ts b/packages/miniflare/test/plugins/kv/index.spec.ts index 6c7efcbc1fea..3560192c5a09 100644 --- a/packages/miniflare/test/plugins/kv/index.spec.ts +++ b/packages/miniflare/test/plugins/kv/index.spec.ts @@ -6,6 +6,7 @@ import consumers from "stream/consumers"; import { Macro, ThrowsExpectation } from "ava"; import { KV_PLUGIN_NAME, + MAX_BULK_GET_KEYS, Miniflare, MiniflareOptions, ReplaceWorkersTypes, @@ -122,6 +123,107 @@ test("get: returns value", async (t) => { const result = await kv.get("key"); t.is(result, "value"); }); + +test("bulk get: returns value", async (t) => { + const { kv } = t.context; + await kv.put("key1", "value1"); + const result: any = await kv.get(["key1", "key2"]); + const expectedResult = new Map([ + ["key1", "value1"], + ["key2", null], + ]); + + t.deepEqual(result, expectedResult); +}); + +test("bulk get: check max keys", async (t) => { + const { kv } = t.context; + await kv.put("key1", "value1"); + const keyArray = []; + for (let i = 0; i <= MAX_BULK_GET_KEYS; i++) { + keyArray.push(`key${i}`); + } + try { + await kv.get(keyArray); + } catch (error: any) { + t.is(error.message, "KV GET_BULK failed: 400 Bad Request"); + } +}); + +test("bulk get: request json type", async (t) => { + const { kv } = t.context; + await kv.put("key1", '{"example": "ex"}'); + await kv.put("key2", "example"); + let result: any = await kv.get(["key1"]); + let expectedResult: any = new Map([["key1", '{"example": "ex"}']]); + expectedResult = new Map([["key1", '{"example": "ex"}']]); + t.deepEqual(result, expectedResult); + + result = await kv.get(["key1"], "json"); + expectedResult = new Map([["key1", { example: "ex" }]]); + t.deepEqual(result, expectedResult); + + try { + await kv.get(["key1", "key2"], "json"); + } catch (error: any) { + t.is( + error.message, + "KV GET_BULK failed: 400 At least one of the requested keys corresponds to a non-JSON value" + ); + } +}); + +test("bulk get: check metadata", async (t) => { + const { kv } = t.context; + await kv.put("key1", "value1", { + expiration: TIME_FUTURE, + metadata: { testing: true }, + }); + + await kv.put("key2", "value2"); + const result: any = await kv.getWithMetadata(["key1", "key2"]); + const expectedResult: any = new Map([ + ["key1", { value: "value1", metadata: { testing: true } }], + ["key2", { value: "value2", metadata: null }], + ]); + t.deepEqual(result, expectedResult); +}); + +test("bulk get: check metadata with int", async (t) => { + const { kv } = t.context; + await kv.put("key1", "value1", { + expiration: TIME_FUTURE, + metadata: 123, + }); + + const result: any = await kv.getWithMetadata(["key1"]); + const expectedResult: any = new Map([ + ["key1", { value: "value1", metadata: 123 }], + ]); + t.deepEqual(result, expectedResult); +}); + +test("bulk get: check metadata as string", async (t) => { + const { kv } = t.context; + await kv.put("key1", "value1", { + expiration: TIME_FUTURE, + metadata: "example", + }); + const result: any = await kv.getWithMetadata(["key1"]); + const expectedResult: any = new Map([ + ["key1", { value: "value1", metadata: "example" }], + ]); + t.deepEqual(result, expectedResult); +}); + +test("bulk get: get with metadata for 404", async (t) => { + const { kv } = t.context; + + const result: any = await kv.getWithMetadata(["key1"]); + const expectedResult: any = new Map([["key1", null]]); + t.deepEqual(result, expectedResult); +}); + test("get: returns null for non-existent keys", async (t) => { const { kv } = t.context; t.is(await kv.get("key"), null); diff --git a/packages/miniflare/test/test-shared/miniflare.ts b/packages/miniflare/test/test-shared/miniflare.ts index c88bf57e13d7..3c0c6ec330ca 100644 --- a/packages/miniflare/test/test-shared/miniflare.ts +++ b/packages/miniflare/test/test-shared/miniflare.ts @@ -38,7 +38,25 @@ export function namespace(ns: string, binding: T): Namespaced { return (keys: unknown, ...args: unknown[]) => { if (typeof keys === "string") keys = ns + keys; if (Array.isArray(keys)) keys = keys.map((key) => ns + key); - return (value as (...args: unknown[]) => unknown)(keys, ...args); + const result = (value as (...args: unknown[]) => unknown)( + keys, + ...args + ); + if (result instanceof Promise) { + return result.then((res) => { + // KV.get([a,b,c]) would be prefixed with ns, so we strip this prefix from response. + // Map keys => [{ns}{a}, {ns}{b}, {ns}{b}] -> [a,b,c] + if (res instanceof Map) { + const newResult = new Map(); + for (const [key, value] of res) { + newResult.set(key.slice(ns.length), value); + } + return newResult; + } + return res; + }); + } + return result; }; } return value; @@ -83,7 +101,7 @@ export function miniflareTest< status: 500, headers: { "MF-Experimental-Error-Stack": "true" }, }); - } + } } } `;