Skip to content

Commit 2aa0fe7

Browse files
committed
key commands
1 parent 7f5de93 commit 2aa0fe7

File tree

12 files changed

+419
-92
lines changed

12 files changed

+419
-92
lines changed

eng/StackExchange.Redis.Build/RespCommandGenerator.cs

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ private readonly record struct MethodTuple(
7070
string Context,
7171
string? Formatter,
7272
string? Parser,
73-
string DebugNotes);
73+
MethodFlags Flags,
74+
string DebugNotes)
75+
{
76+
public bool IsRespOperation => (Flags & MethodFlags.RespOperation) != 0;
77+
}
7478

7579
private static string GetFullName(ITypeSymbol type) =>
7680
type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
@@ -82,6 +86,7 @@ private enum RESPite
8286
RespKeyAttribute,
8387
RespPrefixAttribute,
8488
RespSuffixAttribute,
89+
RespOperation,
8590
}
8691

8792
private static bool IsRESPite(ITypeSymbol? symbol, RESPite type)
@@ -93,6 +98,7 @@ private static bool IsRESPite(ITypeSymbol? symbol, RESPite type)
9398
RESPite.RespKeyAttribute => nameof(RESPite.RespKeyAttribute),
9499
RESPite.RespPrefixAttribute => nameof(RESPite.RespPrefixAttribute),
95100
RESPite.RespSuffixAttribute => nameof(RESPite.RespSuffixAttribute),
101+
RESPite.RespOperation => nameof(RESPite.RespOperation),
96102
_ => type.ToString(),
97103
};
98104

@@ -187,6 +193,7 @@ private MethodTuple Transform(
187193
if (ctx.SemanticModel.GetDeclaredSymbol(ctx.Node) is not IMethodSymbol method) return default;
188194
if (!(method is { IsPartialDefinition: true, PartialImplementationPart: null })) return default;
189195

196+
MethodFlags methodFlags = 0;
190197
string returnType, debugNote = "";
191198
if (method.ReturnsVoid)
192199
{
@@ -199,7 +206,12 @@ private MethodTuple Transform(
199206
}
200207
else
201208
{
202-
returnType = GetFullName(method.ReturnType);
209+
ITypeSymbol? rt = method.ReturnType;
210+
if (IsRespOperation(ref rt))
211+
{
212+
methodFlags |= MethodFlags.RespOperation;
213+
}
214+
returnType = rt is null ? "" : GetFullName(rt);
203215
}
204216

205217
string ns = "", parentType = "";
@@ -380,8 +392,16 @@ static bool IsIndirectRespContext(ITypeSymbol type, out string memberName)
380392
if (IsSERedis(param.Type, SERedis.CommandFlags))
381393
{
382394
flags |= ParameterFlags.CommandFlags;
383-
// magic pattern; we *demand* a method called Context that takes the flags
384-
context = $"Context({param.Name})";
395+
// magic pattern; we *demand* a method called Context that takes the flags; if this is an extension
396+
// method, assume it is on the first parameter
397+
if ((methodFlags & MethodFlags.ExtensionMethod) != 0)
398+
{
399+
context = $"{method.Parameters[0].Name}.Context({param.Name})";
400+
}
401+
else
402+
{
403+
context = $"Context({param.Name})";
404+
}
385405
}
386406
else if (IsRESPite(param.Type, RESPite.RespContext))
387407
{
@@ -407,6 +427,7 @@ static bool IsIndirectRespContext(ITypeSymbol type, out string memberName)
407427

408428
if (param.Ordinal == 0 && method.IsExtensionMethod)
409429
{
430+
methodFlags |= MethodFlags.ExtensionMethod;
410431
modifiers = "this " + modifiers;
411432
}
412433

@@ -463,6 +484,7 @@ void AddLiteral(string token, LiteralFlags literalFlags)
463484
context ?? "",
464485
formatter,
465486
parser,
487+
methodFlags,
466488
debugNote);
467489

468490
static string TypeModifiers(ITypeSymbol type)
@@ -486,6 +508,24 @@ static string TypeModifiers(ITypeSymbol type)
486508
}
487509
}
488510

511+
private bool IsRespOperation(ref ITypeSymbol? type) // identify RespOperation[<T>]
512+
{
513+
if (type is INamedTypeSymbol named && IsRESPite(type, RESPite.RespOperation))
514+
{
515+
if (named.IsGenericType)
516+
{
517+
if (named.TypeArguments.Length != 1) return false; // unexpected
518+
type = named.TypeArguments[0];
519+
}
520+
else
521+
{
522+
type = null;
523+
}
524+
return true;
525+
}
526+
return false;
527+
}
528+
489529
private static ParameterFlags GetTypeFlags(ref ITypeSymbol paramType)
490530
{
491531
var flags = ParameterFlags.None;
@@ -710,7 +750,10 @@ private void Generate(
710750
var csValue = CodeLiteral(method.Command);
711751

712752
WriteMethod(false);
713-
WriteMethod(true);
753+
if ((method.Flags & MethodFlags.RespOperation) == 0)
754+
{
755+
WriteMethod(true); // also write async half
756+
}
714757

715758
void WriteMethod(bool asAsync)
716759
{
@@ -724,6 +767,14 @@ void WriteMethod(bool asAsync)
724767
sb.Append('<').Append(method.ReturnType).Append('>');
725768
}
726769
}
770+
else if (method.IsRespOperation)
771+
{
772+
sb.Append("global::RESPite.RespOperation");
773+
if (!string.IsNullOrWhiteSpace(method.ReturnType))
774+
{
775+
sb.Append('<').Append(method.ReturnType).Append('>');
776+
}
777+
}
727778
else
728779
{
729780
sb.Append(string.IsNullOrEmpty(method.ReturnType) ? "void" : method.ReturnType);
@@ -768,8 +819,7 @@ void WriteMethod(bool asAsync)
768819
sb.Append(", ").Append(formatter);
769820
}
770821
}
771-
772-
sb.Append(asAsync ? ").Send" : ").Wait");
822+
sb.Append(asAsync | method.IsRespOperation ? ").Send" : ").Wait");
773823
if (!string.IsNullOrWhiteSpace(method.ReturnType))
774824
{
775825
sb.Append('<').Append(method.ReturnType).Append('>');
@@ -806,6 +856,10 @@ void WriteMethod(bool asAsync)
806856
? ".AsTask()"
807857
: ".AsValueTask()");
808858
}
859+
else if (method.IsRespOperation)
860+
{
861+
// nothing to do
862+
}
809863
else
810864
{
811865
sb.Append(".Wait(");
@@ -1244,6 +1298,7 @@ private static int DataParameterCount(
12441298
"double" => RespFormattersPrefix + "Double",
12451299
"" => RespFormattersPrefix + "Empty",
12461300
"global::StackExchange.Redis.RedisKey" => "global::RESPite.StackExchange.Redis.RespFormatters.RedisKey",
1301+
"global::StackExchange.Redis.RedisKey[]" => "global::RESPite.StackExchange.Redis.RespFormatters.RedisKeyArray",
12471302
"global::StackExchange.Redis.RedisValue" => "global::RESPite.StackExchange.Redis.RespFormatters.RedisValue",
12481303
_ => null,
12491304
};
@@ -1289,6 +1344,14 @@ private static string RemovePartial(string modifiers)
12891344
return modifiers.Replace(" partial ", " ");
12901345
}
12911346

1347+
[Flags]
1348+
private enum MethodFlags
1349+
{
1350+
None = 0,
1351+
RespOperation = 1 << 0,
1352+
ExtensionMethod = 1 << 1,
1353+
}
1354+
12921355
[Flags]
12931356
private enum ParameterFlags
12941357
{
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
using System.Runtime.CompilerServices;
2+
using System.Runtime.InteropServices.ComTypes;
3+
using RESPite.Messages;
4+
using StackExchange.Redis;
5+
6+
namespace RESPite.StackExchange.Redis;
7+
8+
internal static partial class RedisCommands
9+
{
10+
// this is just a "type pun" - it should be an invisible/magic pointer cast to the JIT
11+
public static ref readonly KeyCommands Keys(this in RespContext context)
12+
=> ref Unsafe.As<RespContext, KeyCommands>(ref Unsafe.AsRef(in context));
13+
}
14+
15+
public readonly struct KeyCommands(in RespContext context)
16+
{
17+
public readonly RespContext Context = context; // important: this is the only field
18+
}
19+
20+
internal static partial class KeyCommandsExtensions
21+
{
22+
[RespCommand]
23+
public static partial RespOperation<bool> Del(this in KeyCommands context, RedisKey key);
24+
25+
[RespCommand]
26+
public static partial RespOperation<long> Del(this in KeyCommands context, [RespKey] RedisKey[] keys);
27+
28+
[RespCommand]
29+
public static partial RespOperation<byte[]?> Dump(this in KeyCommands context, RedisKey key);
30+
31+
[RespCommand("object")]
32+
public static partial RespOperation<string?> ObjectEncoding(this in KeyCommands context, [RespPrefix("ENCODING")] RedisKey key);
33+
34+
[RespCommand("object", Parser = "RespParsers.TimeSpanFromSeconds")]
35+
public static partial RespOperation<TimeSpan?> ObjectIdleTime(this in KeyCommands context, [RespPrefix("IDLETIME")] RedisKey key);
36+
37+
[RespCommand("object")]
38+
public static partial RespOperation<long?> ObjectRefCount(this in KeyCommands context, [RespPrefix("REFCOUNT")] RedisKey key);
39+
40+
[RespCommand("object")]
41+
public static partial RespOperation<long?> ObjectFreq(this in KeyCommands context, [RespPrefix("FREQ")] RedisKey key);
42+
43+
[RespCommand(Parser = "RespParsers.TimeSpanFromSeconds")]
44+
public static partial RespOperation<TimeSpan?> Ttl(this in KeyCommands context, RedisKey key);
45+
46+
[RespCommand(Parser = "RespParsers.TimeSpanFromMilliseconds")]
47+
public static partial RespOperation<TimeSpan?> Pttl(this in KeyCommands context, RedisKey key);
48+
49+
[RespCommand(Parser = "RespParsers.DateTimeFromSeconds")]
50+
public static partial RespOperation<DateTime?> ExpireTime(this in KeyCommands context, RedisKey key);
51+
52+
[RespCommand(Parser = "RespParsers.DateTimeFromMilliseconds")]
53+
public static partial RespOperation<DateTime?> PExpireTime(this in KeyCommands context, RedisKey key);
54+
55+
[RespCommand]
56+
public static partial RespOperation<bool> Exists(this in KeyCommands context, RedisKey key);
57+
58+
[RespCommand]
59+
public static partial RespOperation<bool> Move(this in KeyCommands context, RedisKey key, int db);
60+
61+
[RespCommand]
62+
public static partial RespOperation<long> Exists(this in KeyCommands context, [RespKey] RedisKey[] keys);
63+
64+
public static RespOperation<bool> Expire(this in KeyCommands context, RedisKey key, TimeSpan? expiry, ExpireWhen when = ExpireWhen.Always)
65+
{
66+
if (expiry is null || expiry == TimeSpan.MaxValue)
67+
{
68+
if (when != ExpireWhen.Always) Throw(when);
69+
return Persist(context, key);
70+
static void Throw(ExpireWhen when) => throw new ArgumentException($"PERSIST cannot be used with {when}.");
71+
}
72+
var millis = (long)expiry.GetValueOrDefault().TotalMilliseconds;
73+
if (millis % 1000 == 0) // use seconds
74+
{
75+
return Expire(context, key, millis / 1000, when);
76+
}
77+
return PExpire(context, key, millis, when);
78+
}
79+
80+
public static RespOperation<bool> ExpireAt(this in KeyCommands context, RedisKey key, DateTime? expiry, ExpireWhen when = ExpireWhen.Always)
81+
{
82+
if (expiry is null || expiry == DateTime.MaxValue)
83+
{
84+
if (when != ExpireWhen.Always) Throw(when);
85+
return Persist(context, key);
86+
static void Throw(ExpireWhen when) => throw new ArgumentException($"PERSIST cannot be used with {when}.");
87+
}
88+
var millis = RedisDatabase.GetUnixTimeMilliseconds(expiry.GetValueOrDefault());
89+
if (millis % 1000 == 0) // use seconds
90+
{
91+
return ExpireAt(context, key, millis / 1000, when);
92+
}
93+
return PExpireAt(context, key, millis, when);
94+
}
95+
96+
[RespCommand]
97+
public static partial RespOperation<bool> Persist(this in KeyCommands context, RedisKey key);
98+
99+
[RespCommand]
100+
public static partial RespOperation<bool> Touch(this in KeyCommands context, RedisKey key);
101+
102+
[RespCommand]
103+
public static partial RespOperation<long> Touch(this in KeyCommands context, [RespKey] RedisKey[] keys);
104+
105+
[RespCommand(Parser = "RedisTypeParser.Instance")]
106+
public static partial RespOperation<RedisType> Type(this in KeyCommands context, RedisKey key);
107+
108+
private sealed class RedisTypeParser : IRespParser<RedisType>
109+
{
110+
public static readonly RedisTypeParser Instance = new();
111+
private RedisTypeParser() { }
112+
113+
public RedisType Parse(ref RespReader reader)
114+
{
115+
if (reader.IsNull) return RedisType.None;
116+
if (reader.Is("zset"u8)) return RedisType.SortedSet;
117+
return reader.ReadEnum(RedisType.Unknown);
118+
}
119+
}
120+
121+
[RespCommand]
122+
public static partial RespOperation<bool> Rename(this in KeyCommands context, RedisKey key, RedisKey newKey);
123+
124+
[RespCommand(Formatter = "RestoreFormatter.Instance")]
125+
public static partial RespOperation Restore(this in KeyCommands context, RedisKey key, TimeSpan? ttl, byte[] serializedValue);
126+
127+
private sealed class RestoreFormatter : IRespFormatter<(RedisKey Key, TimeSpan? Ttl, byte[] SerializedValue)>
128+
{
129+
public static readonly RestoreFormatter Instance = new();
130+
private RestoreFormatter() { }
131+
132+
public void Format(
133+
scoped ReadOnlySpan<byte> command,
134+
ref RespWriter writer,
135+
in (RedisKey Key, TimeSpan? Ttl, byte[] SerializedValue) request)
136+
{
137+
writer.WriteCommand(command, 3);
138+
writer.Write(request.Key);
139+
if (request.Ttl.HasValue)
140+
{
141+
writer.WriteBulkString((long)request.Ttl.Value.TotalMilliseconds);
142+
}
143+
else
144+
{
145+
writer.WriteRaw("$1\r\n0\r\n"u8);
146+
}
147+
writer.WriteBulkString(request.SerializedValue);
148+
}
149+
}
150+
151+
[RespCommand]
152+
public static partial RespOperation<RedisKey> RandomKey(this in KeyCommands context);
153+
154+
[RespCommand(Formatter = "ExpireFormatter.Instance")]
155+
public static partial RespOperation<bool> Expire(this in KeyCommands context, RedisKey key, long seconds, ExpireWhen when = ExpireWhen.Always);
156+
157+
[RespCommand(Formatter = "ExpireFormatter.Instance")]
158+
public static partial RespOperation<bool> PExpire(this in KeyCommands context, RedisKey key, long milliseconds, ExpireWhen when = ExpireWhen.Always);
159+
160+
[RespCommand(Formatter = "ExpireFormatter.Instance")]
161+
public static partial RespOperation<bool> ExpireAt(this in KeyCommands context, RedisKey key, long seconds, ExpireWhen when = ExpireWhen.Always);
162+
163+
[RespCommand(Formatter = "ExpireFormatter.Instance")]
164+
public static partial RespOperation<bool> PExpireAt(this in KeyCommands context, RedisKey key, long milliseconds, ExpireWhen when = ExpireWhen.Always);
165+
166+
private sealed class ExpireFormatter : IRespFormatter<(RedisKey Key, long Value, ExpireWhen When)>
167+
{
168+
public static readonly ExpireFormatter Instance = new();
169+
private ExpireFormatter() { }
170+
171+
public void Format(
172+
scoped ReadOnlySpan<byte> command,
173+
ref RespWriter writer,
174+
in (RedisKey Key, long Value, ExpireWhen When) request)
175+
{
176+
writer.WriteCommand(command, request.When == ExpireWhen.Always ? 2 : 3);
177+
writer.Write(request.Key);
178+
writer.Write(request.Value);
179+
switch (request.When)
180+
{
181+
case ExpireWhen.Always:
182+
break;
183+
case ExpireWhen.HasExpiry:
184+
writer.WriteRaw("$2\r\nXX\r\n"u8);
185+
break;
186+
case ExpireWhen.HasNoExpiry:
187+
writer.WriteRaw("$2\r\nNX\r\n"u8);
188+
break;
189+
case ExpireWhen.GreaterThanCurrentExpiry:
190+
writer.WriteRaw("$2\r\nGT\r\n"u8);
191+
break;
192+
case ExpireWhen.LessThanCurrentExpiry:
193+
writer.WriteRaw("$2\r\nLT\r\n"u8);
194+
break;
195+
default:
196+
Throw();
197+
static void Throw() => throw new ArgumentOutOfRangeException(nameof(request.When));
198+
break;
199+
}
200+
}
201+
}
202+
}

0 commit comments

Comments
 (0)