Skip to content

Commit f1880d9

Browse files
authored
feat: add get method to store (#195)
1 parent 102bbc2 commit f1880d9

File tree

4 files changed

+171
-83
lines changed

4 files changed

+171
-83
lines changed

source/lib.ts

Lines changed: 101 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,43 @@
44
import type {
55
Store,
66
IncrementResponse,
7+
ClientRateLimitInfo,
78
Options as RateLimitConfiguration,
89
} from 'express-rate-limit'
9-
import { type Options, type SendCommandFn } from './types.js'
10+
import scripts from './scripts.js'
11+
import type { Options, SendCommandFn, RedisReply } from './types.js'
12+
13+
/**
14+
* Converts a string/number to a number.
15+
*
16+
* @param input {string | number | undefined} - The input to convert to a number.
17+
*
18+
* @return {number} - The parsed integer.
19+
* @throws {Error} - Thrown if the string does not contain a valid number.
20+
*/
21+
const toInt = (input: string | number | boolean | undefined): number => {
22+
if (typeof input === 'number') return input
23+
return Number.parseInt((input ?? '').toString(), 10)
24+
}
25+
26+
/**
27+
* Parses the response from the script.
28+
*
29+
* Note that the responses returned by the `get` and `increment` scripts are
30+
* the same, so this function can be used with both.
31+
*/
32+
const parseScriptResponse = (results: RedisReply): ClientRateLimitInfo => {
33+
if (!Array.isArray(results))
34+
throw new TypeError('Expected result to be array of values')
35+
if (results.length !== 2)
36+
throw new Error(`Expected 2 replies, got ${results.length}`)
37+
38+
const totalHits = toInt(results[0])
39+
const timeToExpire = toInt(results[1])
40+
41+
const resetTime = new Date(Date.now() + timeToExpire)
42+
return { totalHits, resetTime }
43+
}
1044

1145
/**
1246
* A `Store` for the `express-rate-limit` package that stores hit counts in
@@ -30,9 +64,11 @@ class RedisStore implements Store {
3064
resetExpiryOnChange: boolean
3165

3266
/**
33-
* Stores the loaded SHA1 of the LUA script for executing the increment operations.
67+
* Stores the loaded SHA1s of the LUA scripts used for executing the increment
68+
* and get key operations.
3469
*/
35-
loadedScriptSha1: Promise<string>
70+
incrementScriptSha: Promise<string>
71+
getScriptSha: Promise<string>
3672

3773
/**
3874
* The number of milliseconds to remember that user's requests.
@@ -51,32 +87,18 @@ class RedisStore implements Store {
5187

5288
// So that the script loading can occur non-blocking, this will send
5389
// the script to be loaded, and will capture the value within the
54-
// promise return. This way, if increments start being called before
90+
// promise return. This way, if increment/get start being called before
5591
// the script has finished loading, it will wait until it is loaded
5692
// before it continues.
57-
this.loadedScriptSha1 = this.loadScript()
93+
this.incrementScriptSha = this.loadIncrementScript()
94+
this.getScriptSha = this.loadGetScript()
5895
}
5996

60-
async loadScript(): Promise<string> {
61-
const result = await this.sendCommand(
62-
'SCRIPT',
63-
'LOAD',
64-
`
65-
local totalHits = redis.call("INCR", KEYS[1])
66-
local timeToExpire = redis.call("PTTL", KEYS[1])
67-
if timeToExpire <= 0 or ARGV[1] == "1"
68-
then
69-
redis.call("PEXPIRE", KEYS[1], tonumber(ARGV[2]))
70-
timeToExpire = tonumber(ARGV[2])
71-
end
72-
73-
return { totalHits, timeToExpire }
74-
`
75-
// Ensure that code changes that affect whitespace do not affect
76-
// the script contents.
77-
.replaceAll(/^\s+/gm, '')
78-
.trim(),
79-
)
97+
/**
98+
* Loads the script used to increment a client's hit count.
99+
*/
100+
async loadIncrementScript(): Promise<string> {
101+
const result = await this.sendCommand('SCRIPT', 'LOAD', scripts.increment)
80102

81103
if (typeof result !== 'string') {
82104
throw new TypeError('unexpected reply from redis client')
@@ -86,30 +108,26 @@ class RedisStore implements Store {
86108
}
87109

88110
/**
89-
* Method to prefix the keys with the given text.
90-
*
91-
* @param key {string} - The key.
92-
*
93-
* @returns {string} - The text + the key.
111+
* Loads the script used to fetch a client's hit count and expiry time.
94112
*/
95-
prefixKey(key: string): string {
96-
return `${this.prefix}${key}`
113+
async loadGetScript(): Promise<string> {
114+
const result = await this.sendCommand('SCRIPT', 'LOAD', scripts.get)
115+
116+
if (typeof result !== 'string') {
117+
throw new TypeError('unexpected reply from redis client')
118+
}
119+
120+
return result
97121
}
98122

99123
/**
100-
* Method that actually initializes the store.
101-
*
102-
* @param options {RateLimitConfiguration} - The options used to setup the middleware.
124+
* Runs the increment command, and retries it if the script is not loaded.
103125
*/
104-
init(options: RateLimitConfiguration) {
105-
this.windowMs = options.windowMs
106-
}
107-
108-
async runCommandWithRetry(key: string) {
126+
async retryableIncrement(key: string): Promise<RedisReply> {
109127
const evalCommand = async () =>
110128
this.sendCommand(
111129
'EVALSHA',
112-
await this.loadedScriptSha1,
130+
await this.incrementScriptSha,
113131
'1',
114132
this.prefixKey(key),
115133
this.resetExpiryOnChange ? '1' : '0',
@@ -121,44 +139,59 @@ class RedisStore implements Store {
121139
return result
122140
} catch {
123141
// TODO: distinguish different error types
124-
this.loadedScriptSha1 = this.loadScript()
142+
this.incrementScriptSha = this.loadIncrementScript()
125143
return evalCommand()
126144
}
127145
}
128146

129147
/**
130-
* Method to increment a client's hit counter.
148+
* Method to prefix the keys with the given text.
131149
*
132-
* @param key {string} - The identifier for a client
150+
* @param key {string} - The key.
133151
*
134-
* @returns {IncrementResponse} - The number of hits and reset time for that client
152+
* @returns {string} - The text + the key.
135153
*/
136-
async increment(key: string): Promise<IncrementResponse> {
137-
const results = await this.runCommandWithRetry(key)
138-
139-
if (!Array.isArray(results)) {
140-
throw new TypeError('Expected result to be array of values')
141-
}
154+
prefixKey(key: string): string {
155+
return `${this.prefix}${key}`
156+
}
142157

143-
if (results.length !== 2) {
144-
throw new Error(`Expected 2 replies, got ${results.length}`)
145-
}
158+
/**
159+
* Method that actually initializes the store.
160+
*
161+
* @param options {RateLimitConfiguration} - The options used to setup the middleware.
162+
*/
163+
init(options: RateLimitConfiguration) {
164+
this.windowMs = options.windowMs
165+
}
146166

147-
const totalHits = results[0]
148-
if (typeof totalHits !== 'number') {
149-
throw new TypeError('Expected value to be a number')
150-
}
167+
/**
168+
* Method to fetch a client's hit count and reset time.
169+
*
170+
* @param key {string} - The identifier for a client.
171+
*
172+
* @returns {ClientRateLimitInfo | undefined} - The number of hits and reset time for that client.
173+
*/
174+
async get(key: string): Promise<ClientRateLimitInfo | undefined> {
175+
const results = await this.sendCommand(
176+
'EVALSHA',
177+
await this.getScriptSha,
178+
'1',
179+
this.prefixKey(key),
180+
)
151181

152-
const timeToExpire = results[1]
153-
if (typeof timeToExpire !== 'number') {
154-
throw new TypeError('Expected value to be a number')
155-
}
182+
return parseScriptResponse(results)
183+
}
156184

157-
const resetTime = new Date(Date.now() + timeToExpire)
158-
return {
159-
totalHits,
160-
resetTime,
161-
}
185+
/**
186+
* Method to increment a client's hit counter.
187+
*
188+
* @param key {string} - The identifier for a client
189+
*
190+
* @returns {IncrementResponse} - The number of hits and reset time for that client
191+
*/
192+
async increment(key: string): Promise<IncrementResponse> {
193+
const results = await this.retryableIncrement(key)
194+
return parseScriptResponse(results)
162195
}
163196

164197
/**
@@ -180,4 +213,5 @@ class RedisStore implements Store {
180213
}
181214
}
182215

216+
// Export it to the world!
183217
export default RedisStore

source/scripts.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// /source/scripts.ts
2+
// The lua scripts for the increment and get operations.
3+
4+
/**
5+
* The lua scripts, used to make consecutive queries on the same key and avoid
6+
* race conditions by doing all the work on the redis server.
7+
*/
8+
const scripts = {
9+
increment: `
10+
local totalHits = redis.call("INCR", KEYS[1])
11+
local timeToExpire = redis.call("PTTL", KEYS[1])
12+
if timeToExpire <= 0 or ARGV[1] == "1"
13+
then
14+
redis.call("PEXPIRE", KEYS[1], tonumber(ARGV[2]))
15+
timeToExpire = tonumber(ARGV[2])
16+
end
17+
18+
return { totalHits, timeToExpire }
19+
`
20+
// Ensure that code changes that affect whitespace do not affect
21+
// the script contents.
22+
.replaceAll(/^\s+/gm, '')
23+
.trim(),
24+
get: `
25+
local totalHits = redis.call("GET", KEYS[1])
26+
local timeToExpire = redis.call("PTTL", KEYS[1])
27+
28+
return { totalHits, timeToExpire }
29+
`
30+
.replaceAll(/^\s+/gm, '')
31+
.trim(),
32+
}
33+
34+
// Export them so we can use them in the `lib.ts` file.
35+
export default scripts

source/types.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
/**
55
* The type of data Redis might return to us.
66
*/
7-
export type RedisReply = number | string
7+
type Data = boolean | number | string
8+
export type RedisReply = Data | Data[]
89

910
/**
1011
* The library sends Redis raw commands, so all we need to know are the
1112
* 'raw-command-sending' functions for each redis client.
1213
*/
13-
export type SendCommandFn = (
14-
...args: string[]
15-
) => Promise<RedisReply | RedisReply[]>
14+
export type SendCommandFn = (...args: string[]) => Promise<RedisReply>
1615

1716
/**
1817
* The configuration options for the store.

test/store-test.ts

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
// The tests for the store.
33

44
import { createHash } from 'node:crypto'
5-
import { jest } from '@jest/globals'
5+
import { expect, jest } from '@jest/globals'
66
import { type Options } from 'express-rate-limit'
77
import MockRedisClient from 'ioredis-mock'
88
import RedisStore, { type RedisReply } from '../source/index.js'
99

10-
// The SHA of the script to evaluate
11-
let scriptSha: string | undefined
1210
// The mock redis client to use.
1311
const client = new MockRedisClient()
1412

@@ -18,29 +16,33 @@ const client = new MockRedisClient()
1816
*
1917
* @param {string[]} ...args - The raw command to send.
2018
*
21-
* @return {RedisReply | RedisReply[]} The reply returned by Redis.
19+
* @return {RedisReply} The reply returned by Redis.
2220
*/
23-
const sendCommand = async (
24-
...args: string[]
25-
): Promise<RedisReply | RedisReply[]> => {
21+
const sendCommand = async (...args: string[]): Promise<RedisReply> => {
2622
// `SCRIPT LOAD`, called when the store is initialized. This loads the lua script
2723
// for incrementing a client's hit counter.
2824
if (args[0] === 'SCRIPT') {
2925
// `ioredis-mock` doesn't have a `SCRIPT LOAD` function, so we have to compute
3026
// the SHA manually and `EVAL` the script to get it saved.
3127
const shasum = createHash('sha1')
3228
shasum.update(args[2])
33-
scriptSha = shasum.digest('hex')
34-
await client.eval(args[2], 1, '__test', '0', '100')
29+
const sha = shasum.digest('hex')
30+
31+
const testArgs = args[2].includes('INCR')
32+
? ['__test_incr', '0', '10']
33+
: ['__test_get']
34+
await client.eval(args[2], 1, ...testArgs)
3535

3636
// Return the SHA to the store.
37-
return scriptSha
37+
return sha
3838
}
3939

4040
// `EVALSHA` executes the script that was loaded already with the given arguments
41-
if (args[0] === 'EVALSHA')
41+
if (args[0] === 'EVALSHA') {
4242
// @ts-expect-error Wrong types :/
43-
return client.evalsha(scriptSha!, ...args.slice(2)) as number[]
43+
return client.evalsha(...args.slice(1)) as number[]
44+
}
45+
4446
// `DECR` decrements the count for a client.
4547
if (args[0] === 'DECR') return client.decr(args[1])
4648
// `DEL` resets the count for a client by deleting the key.
@@ -128,6 +130,7 @@ describe('redis store test', () => {
128130
const key = 'test-store'
129131

130132
await store.increment(key) // => 1
133+
await store.increment(key) // => 2
131134
await store.resetKey(key) // => undefined
132135

133136
const { totalHits } = await store.increment(key) // => 1
@@ -139,6 +142,23 @@ describe('redis store test', () => {
139142
expect(Number(await client.pttl('rl:test-store'))).toEqual(10)
140143
})
141144

145+
it('fetches the count for a key in the store when `getKey` is called', async () => {
146+
const store = new RedisStore({ sendCommand })
147+
store.init({ windowMs: 10 } as Options)
148+
149+
const key = 'test-store'
150+
151+
await store.increment(key) // => 1
152+
await store.increment(key) // => 2
153+
const info = await store.get(key)
154+
155+
// Ensure the hit count is 1, and that `resetTime` is a date.
156+
expect(info).toMatchObject({
157+
totalHits: 2,
158+
resetTime: expect.any(Date),
159+
})
160+
})
161+
142162
it('resets expiry time on change if `resetExpiryOnChange` is set to `true`', async () => {
143163
const store = new RedisStore({ sendCommand, resetExpiryOnChange: true })
144164
store.init({ windowMs: 60 } as Options)

0 commit comments

Comments
 (0)