Skip to content

Commit de53b2b

Browse files
authored
Handle PubSub commands routing (#176)
--------- Signed-off-by: shohame <[email protected]>
1 parent b8c921d commit de53b2b

File tree

3 files changed

+125
-12
lines changed

3 files changed

+125
-12
lines changed

redis/src/cluster.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,13 @@ where
661661
_ => crate::cluster_routing::combine_array_results(results),
662662
}
663663
}
664+
Some(ResponsePolicy::CombineMaps) => {
665+
let results = results
666+
.into_iter()
667+
.map(|res| res.map(|(_, val)| val))
668+
.collect::<RedisResult<Vec<_>>>()?;
669+
crate::cluster_routing::combine_map_results(results)
670+
}
664671
Some(ResponsePolicy::Special) | None => {
665672
// This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user.
666673
// TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes.

redis/src/cluster_async/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,11 @@ where
12901290
_ => crate::cluster_routing::combine_array_results(results),
12911291
})
12921292
}
1293+
Some(ResponsePolicy::CombineMaps) => {
1294+
future::try_join_all(receivers.into_iter().map(get_receiver))
1295+
.await
1296+
.and_then(crate::cluster_routing::combine_map_results)
1297+
}
12931298
Some(ResponsePolicy::Special) | None => {
12941299
// This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user.
12951300
// TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes.

redis/src/cluster_routing.rs

Lines changed: 113 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ pub enum ResponsePolicy {
4848
CombineArrays,
4949
/// Handling is not defined by the Redis standard. Will receive a special case
5050
Special,
51+
/// Combines multiple map responses into a single map.
52+
CombineMaps,
5153
}
5254

5355
/// Defines whether a request should be routed to a single node, or multiple ones.
@@ -187,8 +189,42 @@ pub fn logical_aggregate(values: Vec<Value>, op: LogicalAggregateOp) -> RedisRes
187189
.collect(),
188190
))
189191
}
192+
/// Aggregate array responses into a single map.
193+
pub fn combine_map_results(values: Vec<Value>) -> RedisResult<Value> {
194+
let mut map: HashMap<Vec<u8>, i64> = HashMap::new();
190195

191-
/// Aggreagte arrau responses into a single array.
196+
for value in values {
197+
match value {
198+
Value::Array(elements) => {
199+
let mut iter = elements.into_iter();
200+
201+
while let Some(key) = iter.next() {
202+
if let Value::BulkString(key_bytes) = key {
203+
if let Some(Value::Int(value)) = iter.next() {
204+
*map.entry(key_bytes).or_insert(0) += value;
205+
} else {
206+
return Err((ErrorKind::TypeError, "expected integer value").into());
207+
}
208+
} else {
209+
return Err((ErrorKind::TypeError, "expected string key").into());
210+
}
211+
}
212+
}
213+
_ => {
214+
return Err((ErrorKind::TypeError, "expected array of values as response").into());
215+
}
216+
}
217+
}
218+
219+
let result_vec: Vec<(Value, Value)> = map
220+
.into_iter()
221+
.map(|(k, v)| (Value::BulkString(k), Value::Int(v)))
222+
.collect();
223+
224+
Ok(Value::Map(result_vec))
225+
}
226+
227+
/// Aggregate array responses into a single array.
192228
pub fn combine_array_results(values: Vec<Value>) -> RedisResult<Value> {
193229
let mut results = Vec::new();
194230

@@ -302,7 +338,9 @@ impl ResponsePolicy {
302338
b"SCRIPT EXISTS" => Some(ResponsePolicy::AggregateLogical(LogicalAggregateOp::And)),
303339

304340
b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK"
305-
| b"LATENCY RESET" => Some(ResponsePolicy::Aggregate(AggregateOp::Sum)),
341+
| b"LATENCY RESET" | b"PUBSUB NUMPAT" => {
342+
Some(ResponsePolicy::Aggregate(AggregateOp::Sum))
343+
}
306344

307345
b"WAIT" => Some(ResponsePolicy::Aggregate(AggregateOp::Min)),
308346

@@ -314,7 +352,10 @@ impl ResponsePolicy {
314352
Some(ResponsePolicy::AllSucceeded)
315353
}
316354

317-
b"KEYS" | b"MGET" | b"SLOWLOG GET" => Some(ResponsePolicy::CombineArrays),
355+
b"KEYS" | b"MGET" | b"SLOWLOG GET" | b"PUBSUB CHANNELS" | b"PUBSUB SHARDCHANNELS" => {
356+
Some(ResponsePolicy::CombineArrays)
357+
}
358+
b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps),
318359

319360
b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded),
320361

@@ -354,11 +395,30 @@ enum RouteBy {
354395

355396
fn base_routing(cmd: &[u8]) -> RouteBy {
356397
match cmd {
357-
b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"CLIENT SETNAME" | b"CLIENT SETINFO"
358-
| b"SLOWLOG GET" | b"SLOWLOG LEN" | b"SLOWLOG RESET" | b"CONFIG SET"
359-
| b"CONFIG RESETSTAT" | b"CONFIG REWRITE" | b"SCRIPT FLUSH" | b"SCRIPT LOAD"
360-
| b"LATENCY RESET" | b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY"
361-
| b"LATENCY DOCTOR" | b"LATENCY LATEST" => RouteBy::AllNodes,
398+
b"ACL SETUSER"
399+
| b"ACL DELUSER"
400+
| b"ACL SAVE"
401+
| b"CLIENT SETNAME"
402+
| b"CLIENT SETINFO"
403+
| b"SLOWLOG GET"
404+
| b"SLOWLOG LEN"
405+
| b"SLOWLOG RESET"
406+
| b"CONFIG SET"
407+
| b"CONFIG RESETSTAT"
408+
| b"CONFIG REWRITE"
409+
| b"SCRIPT FLUSH"
410+
| b"SCRIPT LOAD"
411+
| b"LATENCY RESET"
412+
| b"LATENCY GRAPH"
413+
| b"LATENCY HISTOGRAM"
414+
| b"LATENCY HISTORY"
415+
| b"LATENCY DOCTOR"
416+
| b"LATENCY LATEST"
417+
| b"PUBSUB NUMPAT"
418+
| b"PUBSUB CHANNELS"
419+
| b"PUBSUB NUMSUB"
420+
| b"PUBSUB SHARDCHANNELS"
421+
| b"PUBSUB SHARDNUMSUB" => RouteBy::AllNodes,
362422

363423
b"DBSIZE"
364424
| b"FLUSHALL"
@@ -463,10 +523,6 @@ fn base_routing(cmd: &[u8]) -> RouteBy {
463523
| b"MODULE LOAD"
464524
| b"MODULE LOADEX"
465525
| b"MODULE UNLOAD"
466-
| b"PUBSUB CHANNELS"
467-
| b"PUBSUB NUMPAT"
468-
| b"PUBSUB NUMSUB"
469-
| b"PUBSUB SHARDCHANNELS"
470526
| b"READONLY"
471527
| b"READWRITE"
472528
| b"SAVE"
@@ -1233,4 +1289,49 @@ mod tests {
12331289
])
12341290
);
12351291
}
1292+
1293+
#[test]
1294+
fn test_combine_map_results() {
1295+
let input = vec![];
1296+
let result = super::combine_map_results(input).unwrap();
1297+
assert_eq!(result, Value::Map(vec![]));
1298+
1299+
let input = vec![
1300+
Value::Array(vec![
1301+
Value::BulkString(b"key1".to_vec()),
1302+
Value::Int(5),
1303+
Value::BulkString(b"key2".to_vec()),
1304+
Value::Int(10),
1305+
]),
1306+
Value::Array(vec![
1307+
Value::BulkString(b"key1".to_vec()),
1308+
Value::Int(3),
1309+
Value::BulkString(b"key3".to_vec()),
1310+
Value::Int(15),
1311+
]),
1312+
];
1313+
let result = super::combine_map_results(input).unwrap();
1314+
let mut expected = vec![
1315+
(Value::BulkString(b"key1".to_vec()), Value::Int(8)),
1316+
(Value::BulkString(b"key2".to_vec()), Value::Int(10)),
1317+
(Value::BulkString(b"key3".to_vec()), Value::Int(15)),
1318+
];
1319+
expected.sort_unstable_by(|a, b| match (&a.0, &b.0) {
1320+
(Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes),
1321+
_ => std::cmp::Ordering::Equal,
1322+
});
1323+
let mut result_vec = match result {
1324+
Value::Map(v) => v,
1325+
_ => panic!("Expected Map"),
1326+
};
1327+
result_vec.sort_unstable_by(|a, b| match (&a.0, &b.0) {
1328+
(Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes),
1329+
_ => std::cmp::Ordering::Equal,
1330+
});
1331+
assert_eq!(result_vec, expected);
1332+
1333+
let input = vec![Value::Int(5)];
1334+
let result = super::combine_map_results(input);
1335+
assert!(result.is_err());
1336+
}
12361337
}

0 commit comments

Comments
 (0)