4
4
import type {
5
5
Store ,
6
6
IncrementResponse ,
7
+ ClientRateLimitInfo ,
7
8
Options as RateLimitConfiguration ,
8
9
} from 'express-rate-limit'
9
- import { type Options , type SendCommandFn } from './types.js'
10
+ import scripts from './scripts.js'
11
+ import type { Options , SendCommandFn , RedisReply } from './types.js'
12
+
13
+ /**
14
+ * Converts a string/number to a number.
15
+ *
16
+ * @param input {string | number | undefined} - The input to convert to a number.
17
+ *
18
+ * @return {number } - The parsed integer.
19
+ * @throws {Error } - Thrown if the string does not contain a valid number.
20
+ */
21
+ const toInt = ( input : string | number | boolean | undefined ) : number => {
22
+ if ( typeof input === 'number' ) return input
23
+ return Number . parseInt ( ( input ?? '' ) . toString ( ) , 10 )
24
+ }
25
+
26
+ /**
27
+ * Parses the response from the script.
28
+ *
29
+ * Note that the responses returned by the `get` and `increment` scripts are
30
+ * the same, so this function can be used with both.
31
+ */
32
+ const parseScriptResponse = ( results : RedisReply ) : ClientRateLimitInfo => {
33
+ if ( ! Array . isArray ( results ) )
34
+ throw new TypeError ( 'Expected result to be array of values' )
35
+ if ( results . length !== 2 )
36
+ throw new Error ( `Expected 2 replies, got ${ results . length } ` )
37
+
38
+ const totalHits = toInt ( results [ 0 ] )
39
+ const timeToExpire = toInt ( results [ 1 ] )
40
+
41
+ const resetTime = new Date ( Date . now ( ) + timeToExpire )
42
+ return { totalHits, resetTime }
43
+ }
10
44
11
45
/**
12
46
* A `Store` for the `express-rate-limit` package that stores hit counts in
@@ -30,9 +64,11 @@ class RedisStore implements Store {
30
64
resetExpiryOnChange : boolean
31
65
32
66
/**
33
- * Stores the loaded SHA1 of the LUA script for executing the increment operations.
67
+ * Stores the loaded SHA1s of the LUA scripts used for executing the increment
68
+ * and get key operations.
34
69
*/
35
- loadedScriptSha1 : Promise < string >
70
+ incrementScriptSha : Promise < string >
71
+ getScriptSha : Promise < string >
36
72
37
73
/**
38
74
* The number of milliseconds to remember that user's requests.
@@ -51,32 +87,18 @@ class RedisStore implements Store {
51
87
52
88
// So that the script loading can occur non-blocking, this will send
53
89
// the script to be loaded, and will capture the value within the
54
- // promise return. This way, if increments start being called before
90
+ // promise return. This way, if increment/get start being called before
55
91
// the script has finished loading, it will wait until it is loaded
56
92
// before it continues.
57
- this . loadedScriptSha1 = this . loadScript ( )
93
+ this . incrementScriptSha = this . loadIncrementScript ( )
94
+ this . getScriptSha = this . loadGetScript ( )
58
95
}
59
96
60
- async loadScript ( ) : Promise < string > {
61
- const result = await this . sendCommand (
62
- 'SCRIPT' ,
63
- 'LOAD' ,
64
- `
65
- local totalHits = redis.call("INCR", KEYS[1])
66
- local timeToExpire = redis.call("PTTL", KEYS[1])
67
- if timeToExpire <= 0 or ARGV[1] == "1"
68
- then
69
- redis.call("PEXPIRE", KEYS[1], tonumber(ARGV[2]))
70
- timeToExpire = tonumber(ARGV[2])
71
- end
72
-
73
- return { totalHits, timeToExpire }
74
- `
75
- // Ensure that code changes that affect whitespace do not affect
76
- // the script contents.
77
- . replaceAll ( / ^ \s + / gm, '' )
78
- . trim ( ) ,
79
- )
97
+ /**
98
+ * Loads the script used to increment a client's hit count.
99
+ */
100
+ async loadIncrementScript ( ) : Promise < string > {
101
+ const result = await this . sendCommand ( 'SCRIPT' , 'LOAD' , scripts . increment )
80
102
81
103
if ( typeof result !== 'string' ) {
82
104
throw new TypeError ( 'unexpected reply from redis client' )
@@ -86,30 +108,26 @@ class RedisStore implements Store {
86
108
}
87
109
88
110
/**
89
- * Method to prefix the keys with the given text.
90
- *
91
- * @param key {string} - The key.
92
- *
93
- * @returns {string } - The text + the key.
111
+ * Loads the script used to fetch a client's hit count and expiry time.
94
112
*/
95
- prefixKey ( key : string ) : string {
96
- return `${ this . prefix } ${ key } `
113
+ async loadGetScript ( ) : Promise < string > {
114
+ const result = await this . sendCommand ( 'SCRIPT' , 'LOAD' , scripts . get )
115
+
116
+ if ( typeof result !== 'string' ) {
117
+ throw new TypeError ( 'unexpected reply from redis client' )
118
+ }
119
+
120
+ return result
97
121
}
98
122
99
123
/**
100
- * Method that actually initializes the store.
101
- *
102
- * @param options {RateLimitConfiguration} - The options used to setup the middleware.
124
+ * Runs the increment command, and retries it if the script is not loaded.
103
125
*/
104
- init ( options : RateLimitConfiguration ) {
105
- this . windowMs = options . windowMs
106
- }
107
-
108
- async runCommandWithRetry ( key : string ) {
126
+ async retryableIncrement ( key : string ) : Promise < RedisReply > {
109
127
const evalCommand = async ( ) =>
110
128
this . sendCommand (
111
129
'EVALSHA' ,
112
- await this . loadedScriptSha1 ,
130
+ await this . incrementScriptSha ,
113
131
'1' ,
114
132
this . prefixKey ( key ) ,
115
133
this . resetExpiryOnChange ? '1' : '0' ,
@@ -121,44 +139,59 @@ class RedisStore implements Store {
121
139
return result
122
140
} catch {
123
141
// TODO: distinguish different error types
124
- this . loadedScriptSha1 = this . loadScript ( )
142
+ this . incrementScriptSha = this . loadIncrementScript ( )
125
143
return evalCommand ( )
126
144
}
127
145
}
128
146
129
147
/**
130
- * Method to increment a client's hit counter .
148
+ * Method to prefix the keys with the given text .
131
149
*
132
- * @param key {string} - The identifier for a client
150
+ * @param key {string} - The key.
133
151
*
134
- * @returns {IncrementResponse } - The number of hits and reset time for that client
152
+ * @returns {string } - The text + the key.
135
153
*/
136
- async increment ( key : string ) : Promise < IncrementResponse > {
137
- const results = await this . runCommandWithRetry ( key )
138
-
139
- if ( ! Array . isArray ( results ) ) {
140
- throw new TypeError ( 'Expected result to be array of values' )
141
- }
154
+ prefixKey ( key : string ) : string {
155
+ return `${ this . prefix } ${ key } `
156
+ }
142
157
143
- if ( results . length !== 2 ) {
144
- throw new Error ( `Expected 2 replies, got ${ results . length } ` )
145
- }
158
+ /**
159
+ * Method that actually initializes the store.
160
+ *
161
+ * @param options {RateLimitConfiguration} - The options used to setup the middleware.
162
+ */
163
+ init ( options : RateLimitConfiguration ) {
164
+ this . windowMs = options . windowMs
165
+ }
146
166
147
- const totalHits = results [ 0 ]
148
- if ( typeof totalHits !== 'number' ) {
149
- throw new TypeError ( 'Expected value to be a number' )
150
- }
167
+ /**
168
+ * Method to fetch a client's hit count and reset time.
169
+ *
170
+ * @param key {string} - The identifier for a client.
171
+ *
172
+ * @returns {ClientRateLimitInfo | undefined } - The number of hits and reset time for that client.
173
+ */
174
+ async get ( key : string ) : Promise < ClientRateLimitInfo | undefined > {
175
+ const results = await this . sendCommand (
176
+ 'EVALSHA' ,
177
+ await this . getScriptSha ,
178
+ '1' ,
179
+ this . prefixKey ( key ) ,
180
+ )
151
181
152
- const timeToExpire = results [ 1 ]
153
- if ( typeof timeToExpire !== 'number' ) {
154
- throw new TypeError ( 'Expected value to be a number' )
155
- }
182
+ return parseScriptResponse ( results )
183
+ }
156
184
157
- const resetTime = new Date ( Date . now ( ) + timeToExpire )
158
- return {
159
- totalHits,
160
- resetTime,
161
- }
185
+ /**
186
+ * Method to increment a client's hit counter.
187
+ *
188
+ * @param key {string} - The identifier for a client
189
+ *
190
+ * @returns {IncrementResponse } - The number of hits and reset time for that client
191
+ */
192
+ async increment ( key : string ) : Promise < IncrementResponse > {
193
+ const results = await this . retryableIncrement ( key )
194
+ return parseScriptResponse ( results )
162
195
}
163
196
164
197
/**
@@ -180,4 +213,5 @@ class RedisStore implements Store {
180
213
}
181
214
}
182
215
216
+ // Export it to the world!
183
217
export default RedisStore
0 commit comments