diff --git a/benchmark/BDN.benchmark/Lua/LuaParams.cs b/benchmark/BDN.benchmark/Lua/LuaParams.cs
index bd9058f2680..40497074c19 100644
--- a/benchmark/BDN.benchmark/Lua/LuaParams.cs
+++ b/benchmark/BDN.benchmark/Lua/LuaParams.cs
@@ -29,7 +29,7 @@ public LuaParams(LuaMemoryManagementMode mode, bool memoryLimit, TimeSpan? timeo
/// Get the equivalent .
///
public LuaOptions CreateOptions()
- => new(Mode, MemoryLimit ? "2m" : "", Timeout ?? System.Threading.Timeout.InfiniteTimeSpan, LuaLoggingMode.Enable);
+ => new(Mode, MemoryLimit ? "2m" : "", Timeout ?? System.Threading.Timeout.InfiniteTimeSpan, LuaLoggingMode.Enable, []);
///
/// String representation
diff --git a/benchmark/BDN.benchmark/Operations/OperationsBase.cs b/benchmark/BDN.benchmark/Operations/OperationsBase.cs
index f9858260b9b..f57caaa032b 100644
--- a/benchmark/BDN.benchmark/Operations/OperationsBase.cs
+++ b/benchmark/BDN.benchmark/Operations/OperationsBase.cs
@@ -56,7 +56,7 @@ public virtual void GlobalSetup()
QuietMode = true,
EnableLua = true,
DisablePubSub = true,
- LuaOptions = new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Enable),
+ LuaOptions = new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Enable, []),
};
if (Params.useAof)
diff --git a/libs/host/Configuration/Options.cs b/libs/host/Configuration/Options.cs
index ba85044b825..f002f31f30a 100644
--- a/libs/host/Configuration/Options.cs
+++ b/libs/host/Configuration/Options.cs
@@ -560,6 +560,27 @@ internal sealed class Options
[Option("lua-logging-mode", Required = false, HelpText = "Behavior of redis.log(...) when called from Lua scripts. Defaults to Enable.")]
public LuaLoggingMode LuaLoggingMode { get; set; }
+ // Parsing is a tad tricky here as JSON wants to set to empty at certain points
+ //
+ // A bespoke union-on-set gets the desired semantics.
+ private readonly HashSet luaAllowedFunctions = [];
+
+ [OptionValidation]
+ [Option("lua-allowed-functions", Separator = ',', Required = false, HelpText = "If set, restricts the functions available in Lua scripts to given list.")]
+ public IEnumerable LuaAllowedFunctions
+ {
+ get => luaAllowedFunctions;
+ set
+ {
+ if (value == null)
+ {
+ return;
+ }
+
+ luaAllowedFunctions.UnionWith(value);
+ }
+ }
+
[FilePathValidation(false, true, false)]
[Option("unixsocket", Required = false, HelpText = "Unix socket address path to bind server to")]
public string UnixSocketPath { get; set; }
@@ -811,7 +832,7 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null)
LoadModuleCS = LoadModuleCS,
FailOnRecoveryError = FailOnRecoveryError.GetValueOrDefault(),
SkipRDBRestoreChecksumValidation = SkipRDBRestoreChecksumValidation.GetValueOrDefault(),
- LuaOptions = EnableLua.GetValueOrDefault() ? new LuaOptions(LuaMemoryManagementMode, LuaScriptMemoryLimit, LuaScriptTimeoutMs == 0 ? Timeout.InfiniteTimeSpan : TimeSpan.FromMilliseconds(LuaScriptTimeoutMs), LuaLoggingMode, logger) : null,
+ LuaOptions = EnableLua.GetValueOrDefault() ? new LuaOptions(LuaMemoryManagementMode, LuaScriptMemoryLimit, LuaScriptTimeoutMs == 0 ? Timeout.InfiniteTimeSpan : TimeSpan.FromMilliseconds(LuaScriptTimeoutMs), LuaLoggingMode, LuaAllowedFunctions, logger) : null,
UnixSocketPath = UnixSocketPath,
UnixSocketPermission = unixSocketPermissions
};
diff --git a/libs/host/defaults.conf b/libs/host/defaults.conf
index 4741f65e63e..cd1e1f4c95a 100644
--- a/libs/host/defaults.conf
+++ b/libs/host/defaults.conf
@@ -380,6 +380,9 @@
/* Allow redis.log(...) to write to the Garnet logs */
"LuaLoggingMode": "Enable",
+ /* Allow all built in and redis.* functions by default */
+ "LuaAllowedFunctions": null,
+
/* Unix socket address path to bind the server to */
"UnixSocketPath": null,
diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs
index 1a892820b9b..3e64a42a6f9 100644
--- a/libs/server/ArgSlice/ScratchBufferManager.cs
+++ b/libs/server/ArgSlice/ScratchBufferManager.cs
@@ -114,6 +114,27 @@ public ArgSlice CreateArgSlice(string str)
return retVal;
}
+ public ReadOnlySpan UTF8EncodeString(string str)
+ {
+ // We'll always need AT LEAST this many bytes
+ ExpandScratchBufferIfNeeded(str.Length);
+
+ var space = FullBuffer()[scratchBufferOffset..];
+
+ // Attempt to fit in the existing buffer first
+ if (!Encoding.UTF8.TryGetBytes(str, space, out var written))
+ {
+ // If that fails, figure out exactly how much space we need
+ var neededBytes = Encoding.UTF8.GetByteCount(str);
+ ExpandScratchBufferIfNeeded(neededBytes);
+
+ space = FullBuffer()[scratchBufferOffset..];
+ written = Encoding.UTF8.GetBytes(str, space);
+ }
+
+ return space[..written];
+ }
+
///
/// Create an ArgSlice that includes a header of specified size, followed by RESP Bulk-String formatted versions of the specified ArgSlice values (arg1 and arg2)
///
diff --git a/libs/server/Lua/LuaOptions.cs b/libs/server/Lua/LuaOptions.cs
index 24e99e51159..f78e77b65c8 100644
--- a/libs/server/Lua/LuaOptions.cs
+++ b/libs/server/Lua/LuaOptions.cs
@@ -2,6 +2,7 @@
// Licensed under the MIT license.
using System;
+using System.Collections.Generic;
using System.Runtime.InteropServices;
using Microsoft.Extensions.Logging;
@@ -18,6 +19,7 @@ public sealed class LuaOptions
public string MemoryLimit = "";
public TimeSpan Timeout = System.Threading.Timeout.InfiniteTimeSpan;
public LuaLoggingMode LogMode = LuaLoggingMode.Silent;
+ public HashSet AllowedFunctions = [];
///
/// Construct options with default options.
@@ -30,12 +32,13 @@ public LuaOptions(ILogger logger = null)
///
/// Construct options with specific settings.
///
- public LuaOptions(LuaMemoryManagementMode memoryMode, string memoryLimit, TimeSpan timeout, LuaLoggingMode logMode, ILogger logger = null) : this(logger)
+ public LuaOptions(LuaMemoryManagementMode memoryMode, string memoryLimit, TimeSpan timeout, LuaLoggingMode logMode, IEnumerable allowedFunctions, ILogger logger = null) : this(logger)
{
MemoryManagementMode = memoryMode;
MemoryLimit = memoryLimit;
Timeout = timeout;
LogMode = logMode;
+ AllowedFunctions = new HashSet(allowedFunctions, StringComparer.OrdinalIgnoreCase);
}
///
diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs
index 95375225759..a5d59ca47f9 100644
--- a/libs/server/Lua/LuaRunner.cs
+++ b/libs/server/Lua/LuaRunner.cs
@@ -3,12 +3,18 @@
using System;
using System.Buffers;
+using System.Buffers.Binary;
+using System.Collections.Generic;
+using System.Data;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
+using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
+using System.Text.Json;
+using System.Text.Json.Nodes;
using Garnet.common;
using KeraLua;
using Microsoft.Extensions.Logging;
@@ -235,44 +241,197 @@ private int ConstantStringToRegistry(ref LuaStateWrapper state, ReadOnlySpan
+ /// Simple cache of allowed functions to loader block.
+ ///
+ private sealed record LoaderBlockCache(HashSet AllowedFunctions, ReadOnlyMemory LoaderBlockBytes);
+
+ private const string LoaderBlock = @"
+-- globals to fill in on each invocation
KEYS = {}
ARGV = {}
+
+-- disable for sandboxing purposes
+import = function () end
+
+-- cutdown os for sandboxing purposes
+local osClockRef = os.clock
+os = {
+ clock = osClockRef
+}
+
+-- define cjson for (optional) inclusion into sandbox_env
+local cjson = {
+ encode = garnet_cjson_encode;
+ decode = garnet_cjson_decode;
+}
+
+-- define bit for (optional) inclusion into sandbox_env
+local bit = {
+ tobit = garnet_bit_tobit;
+ tohex = garnet_bit_tohex;
+ bnot = function(...) return garnet_bitop(0, ...); end;
+ bor = function(...) return garnet_bitop(1, ...); end;
+ band = function(...) return garnet_bitop(2, ...); end;
+ bxor = function(...) return garnet_bitop(3, ...); end;
+ lshift = function(...) return garnet_bitop(4, ...); end;
+ rshift = function(...) return garnet_bitop(5, ...); end;
+ arshift = function(...) return garnet_bitop(6, ...); end;
+ rol = function(...) return garnet_bitop(7, ...); end;
+ ror = function(...) return garnet_bitop(8, ...); end;
+ bswap = garnet_bit_bswap;
+}
+
+-- define cmsgpack for (optional) inclusion into sandbox_env
+local cmsgpack = {
+ pack = garnet_cmsgpack_pack;
+ unpack = garnet_cmsgpack_unpack;
+}
+
+-- define struct for (optional) inclusion into sandbox_env
+local struct = {
+ pack = string.pack;
+ unpack = string.unpack;
+ size = string.packsize;
+}
+
+-- define redis for (optional, but almost always) inclusion into sandbox_env
+local garnetCallRef = garnet_call
+local pCallRef = pcall
+local redis = {
+ status_reply = function(text)
+ return text
+ end,
+
+ error_reply = function(text)
+ return { err = 'ERR ' .. text }
+ end,
+
+ call = garnetCallRef,
+
+ pcall = function(...)
+ local success, errOrRes = pCallRef(garnetCallRef, ...)
+ if success then
+ return errOrRes
+ end
+
+ return { err = errOrRes }
+ end,
+
+ sha1hex = garnet_sha1hex,
+
+ LOG_DEBUG = 0,
+ LOG_VERBOSE = 1,
+ LOG_NOTICE = 2,
+ LOG_WARNING = 3,
+
+ log = garnet_log,
+
+ REPL_ALL = 3,
+ REPL_AOF = 1,
+ REPL_REPLICA = 2,
+ REPL_SLAVE = 2,
+ REPL_NONE = 0,
+
+ set_repl = function(...)
+ -- this is a giant footgun, straight up not implementing it
+ error('ERR redis.set_repl is not supported in Garnet', 0)
+ end,
+
+ replicate_commands = function(...)
+ return true
+ end,
+
+ breakpoint = function(...)
+ -- this is giant and weird, not implementing
+ error('ERR redis.breakpoint is not supported in Garnet', 0)
+ end,
+
+ debug = function(...)
+ -- this is giant and weird, not implementing
+ error('ERR redis.debug is not supported in Garnet', 0)
+ end,
+
+ acl_check_cmd = garnet_acl_check_cmd,
+ setresp = garnet_setresp,
+
+ REDIS_VERSION = garnet_REDIS_VERSION,
+ REDIS_VERSION_NUM = garnet_REDIS_VERSION_NUM
+}
+
+-- unpack moved after Lua 5.1, this provides Redis compat
+local unpack = table.unpack
+
+-- added after Lua 5.1, removing to maintain Redis compat
+string.pack = nil
+string.unpack = nil
+string.packsize = nil
+math.maxinteger = nil
+math.type = nil
+math.mininteger = nil
+math.tointeger = nil
+math.ult = nil
+table.pack = nil
+table.unpack = nil
+table.move = nil
+
+-- in Lua 5.1 but not 5.4, so implemented on the .NET side
+local loadstring = garnet_loadstring
+math.atan2 = garnet_atan2
+math.cosh = garnet_cosh
+math.frexp = garnet_frexp
+math.ldexp = garnet_ldexp
+math.log10 = garnet_log10
+math.pow = garnet_pow
+math.sinh = garnet_sinh
+math.tanh = garnet_tanh
+table.maxn = garnet_maxn
+
+local collectgarbageRef = collectgarbage
+local setMetatableRef = setmetatable
+local rawsetRef = rawset
+
+-- prevent modification to metatables for readonly tables
+-- Redis accomplishes this by patching Lua, we'd rather ship
+-- vanilla Lua and do it in code
+local setmetatable = function(table, metatable)
+ if table and table.__readonly then
+ error('Attempt to modify a readonly table', 0)
+ end
+
+ return setMetatableRef(table, metatable)
+end
+
+-- prevent bypassing metatables to update readonly tables
+-- as above, Redis prevents this with a patch to Lua
+local rawset = function(table, key, value)
+ if table and table.__readonly then
+ error('Attempt to modify a readonly table', 0)
+ end
+
+ return rawsetRef(table, key, value)
+end
+
+-- technically deprecated in 5.1, but available in Redis
+-- this is only 'sort of' correct as 5.4 doesn't expose the same
+-- gc primitives
+local gcinfo = function()
+ return collectgarbageRef('count'), 0
+end
+
+-- global object used for the sandbox environment
+--
+-- replacements are performed before VM initialization
+-- to allow configuring available functions
sandbox_env = {
_VERSION = _VERSION;
- assert = assert;
- collectgarbage = collectgarbage;
- coroutine = coroutine;
- error = error;
- gcinfo = gcinfo;
- -- explicitly not allowing getfenv
- getmetatable = getmetatable;
- ipairs = ipairs;
- load = load;
- loadstring = loadstring;
- math = math;
- next = next;
- pairs = pairs;
- pcall = pcall;
- rawequal = rawequal;
- rawget = rawget;
- -- rawset is proxied to implement readonly tables
- select = select;
- -- explicitly not allowing setfenv
- -- setmetatable is proxied to implement readonly tables
- string = string;
- table = table;
- tonumber = tonumber;
- tostring = tostring;
- type = type;
- unpack = table.unpack;
- xpcall = xpcall;
-
KEYS = KEYS;
ARGV = ARGV;
+
+!!SANDBOX_ENV REPLACEMENT TARGET!!
}
+
-- no reference to outermost set of globals (_G) should survive sandboxing
sandbox_env._G = sandbox_env
-- lock down a table, recursively doing the same to all table members
@@ -298,7 +457,7 @@ function recursively_readonly_table(table)
end
end
- setmetatable(table, readonly_metatable)
+ setMetatableRef(table, readonly_metatable)
end
-- do resets in the Lua side to minimize pinvokes
function reset_keys_and_argv(fromKey, fromArgv)
@@ -316,98 +475,6 @@ function reset_keys_and_argv(fromKey, fromArgv)
end
-- responsible for sandboxing user provided code
function load_sandboxed(source)
- -- move into a local to avoid global lookup
- local garnetCallRef = garnet_call
- local pCallRef = pcall
- local sha1hexRef = garnet_sha1hex
- local logRef = garnet_log
- local aclCheckCmdRef = garnet_acl_check_cmd
- local setRespRef = garnet_setresp
- local setMetatableRef = setmetatable
- local rawsetRaw = rawset
-
- sandbox_env.redis = {
- status_reply = function(text)
- return text
- end,
-
- error_reply = function(text)
- return { err = 'ERR ' .. text }
- end,
-
- call = garnetCallRef,
-
- pcall = function(...)
- local success, errOrRes = pCallRef(garnetCallRef, ...)
- if success then
- return errOrRes
- end
-
- return { err = errOrRes }
- end,
-
- sha1hex = sha1hexRef,
-
- LOG_DEBUG = 0,
- LOG_VERBOSE = 1,
- LOG_NOTICE = 2,
- LOG_WARNING = 3,
-
- log = logRef,
-
- REPL_ALL = 3,
- REPL_AOF = 1,
- REPL_REPLICA = 2,
- REPL_SLAVE = 2,
- REPL_NONE = 0,
-
- set_repl = function(...)
- -- this is a giant footgun, straight up not implementing it
- error('ERR redis.set_repl is not supported in Garnet', 0)
- end,
-
- replicate_commands = function(...)
- return true
- end,
-
- breakpoint = function(...)
- -- this is giant and weird, not implementing
- error('ERR redis.breakpoint is not supported in Garnet', 0)
- end,
-
- debug = function(...)
- -- this is giant and weird, not implementing
- error('ERR redis.debug is not supported in Garnet', 0)
- end,
-
- acl_check_cmd = aclCheckCmdRef,
- setresp = setRespRef,
-
- REDIS_VERSION = garnet_REDIS_VERSION,
- REDIS_VERSION_NUM = garnet_REDIS_VERSION_NUM
- }
-
- -- prevent modification to metatables for readonly tables
- -- Redis accomplishes this by patching Lua, we'd rather ship
- -- vanilla Lua and do it in code
- sandbox_env.setmetatable = function(table, metatable)
- if table and table.__readonly then
- error('Attempt to modify a readonly table', 0)
- end
-
- return setMetatableRef(table, metatable)
- end
-
- -- prevent bypassing metatables to update readonly tables
- -- as above, Redis prevents this with a patch to Lua
- sandbox_env.rawset = function(table, key, value)
- if table and table.__readonly then
- error('Attempt to modify a readonly table', 0)
- end
-
- return rawsetRef(table, key, value)
- end
-
recursively_readonly_table(sandbox_env)
local rawFunc, err = load(source, nil, nil, sandbox_env)
@@ -415,8 +482,53 @@ function load_sandboxed(source)
return err, rawFunc
end
";
-
- private static readonly ReadOnlyMemory LoaderBlockBytes = Encoding.UTF8.GetBytes(LoaderBlock);
+ private static readonly HashSet DefaultAllowedFunctions = [
+ // Built ins
+ "assert",
+ "collectgarbage",
+ "coroutine",
+ "error",
+ "gcinfo",
+ // Intentionally not supporting getfenv, as it's too weird to backport to Lua 5.4
+ "getmetatable",
+ "ipairs",
+ "load",
+ "loadstring",
+ "math",
+ "next",
+ "pairs",
+ "pcall",
+ "rawequal",
+ "rawget",
+ // Note rawset is proxied to implement readonly tables
+ "rawset",
+ "select",
+ // Intentionally not supporting setfenv, as it's too weird to backport to Lua 5.4
+ // Note setmetatable is proxied to implement readonly tables
+ "setmetatable",
+ "string",
+ "table",
+ "tonumber",
+ "tostring",
+ "type",
+ // Note unpack is actually table.unpack, and defined in the loader block
+ "unpack",
+ "xpcall",
+
+ // Runtime libs
+ "bit",
+ "cjson",
+ "cmsgpack",
+ // Note os only contains clock due to definition in the loader block
+ "os",
+ // Note struct is actually implemented by Lua 5.4's string.pack/unpack/size
+ "struct",
+
+ // Interface force communicating back with Garnet
+ "redis",
+ ];
+
+ private static LoaderBlockCache CachedLoaderBlock;
private static (int Start, ulong[] ByteMask) NoScriptDetails = InitializeNoScriptDetails();
@@ -432,6 +544,7 @@ function load_sandboxed(source)
readonly ConstantStringRegistryIndexes constStrs;
readonly LuaLoggingMode logMode;
+ readonly HashSet allowedFunctions;
readonly ReadOnlyMemory source;
readonly ScratchBufferNetworkSender scratchBufferNetworkSender;
readonly RespServerSession respServerSession;
@@ -467,6 +580,7 @@ public unsafe LuaRunner(
LuaMemoryManagementMode memMode,
int? memLimitBytes,
LuaLoggingMode logMode,
+ HashSet allowedFunctions,
ReadOnlyMemory source,
bool txnMode = false,
RespServerSession respServerSession = null,
@@ -480,6 +594,7 @@ public unsafe LuaRunner(
this.respServerSession = respServerSession;
this.scratchBufferNetworkSender = scratchBufferNetworkSender;
this.logMode = logMode;
+ this.allowedFunctions = allowedFunctions;
this.logger = logger;
scratchBufferManager = respServerSession?.scratchBufferManager ?? new();
@@ -518,10 +633,56 @@ public unsafe LuaRunner(
garnetCall = &LuaRunnerTrampolines.GarnetCallNoSession;
}
- var loadRes = state.LoadBuffer(LoaderBlockBytes.Span);
+ // Lua 5.4 does not provide these functions, but 5.1 does - so implement tehm
+ state.Register("garnet_atan2\0"u8, &LuaRunnerTrampolines.Atan2);
+ state.Register("garnet_cosh\0"u8, &LuaRunnerTrampolines.Cosh);
+ state.Register("garnet_frexp\0"u8, &LuaRunnerTrampolines.Frexp);
+ state.Register("garnet_ldexp\0"u8, &LuaRunnerTrampolines.Ldexp);
+ state.Register("garnet_log10\0"u8, &LuaRunnerTrampolines.Log10);
+ state.Register("garnet_pow\0"u8, &LuaRunnerTrampolines.Pow);
+ state.Register("garnet_sinh\0"u8, &LuaRunnerTrampolines.Sinh);
+ state.Register("garnet_tanh\0"u8, &LuaRunnerTrampolines.Tanh);
+ state.Register("garnet_maxn\0"u8, &LuaRunnerTrampolines.Maxn);
+ state.Register("garnet_loadstring\0"u8, &LuaRunnerTrampolines.LoadString);
+
+ // Things provided as Lua libraries, which we actually implement in .NET
+ state.Register("garnet_cjson_encode\0"u8, &LuaRunnerTrampolines.CJsonEncode);
+ state.Register("garnet_cjson_decode\0"u8, &LuaRunnerTrampolines.CJsonDecode);
+ state.Register("garnet_bit_tobit\0"u8, &LuaRunnerTrampolines.BitToBit);
+ state.Register("garnet_bit_tohex\0"u8, &LuaRunnerTrampolines.BitToHex);
+ // garnet_bitop implements bnot, bor, band, xor, etc. but isn't directly exposed
+ state.Register("garnet_bitop\0"u8, &LuaRunnerTrampolines.Bitop);
+ state.Register("garnet_bit_bswap\0"u8, &LuaRunnerTrampolines.BitBswap);
+ state.Register("garnet_cmsgpack_pack\0"u8, &LuaRunnerTrampolines.CMsgPackPack);
+ state.Register("garnet_cmsgpack_unpack\0"u8, &LuaRunnerTrampolines.CMsgPackUnpack);
+ state.Register("garnet_call\0"u8, garnetCall);
+ state.Register("garnet_sha1hex\0"u8, &LuaRunnerTrampolines.SHA1Hex);
+ state.Register("garnet_log\0"u8, &LuaRunnerTrampolines.Log);
+ state.Register("garnet_acl_check_cmd\0"u8, &LuaRunnerTrampolines.AclCheckCommand);
+ state.Register("garnet_setresp\0"u8, &LuaRunnerTrampolines.SetResp);
+
+ var redisVersionBytes = Encoding.UTF8.GetBytes(redisVersion);
+ state.PushBuffer(redisVersionBytes);
+ state.SetGlobal("garnet_REDIS_VERSION\0"u8);
+
+ var redisVersionParsed = Version.Parse(redisVersion);
+ var redisVersionNum =
+ ((byte)redisVersionParsed.Major << 16) |
+ ((byte)redisVersionParsed.Minor << 8) |
+ ((byte)redisVersionParsed.Build << 0);
+ state.PushInteger(redisVersionNum);
+ state.SetGlobal("garnet_REDIS_VERSION_NUM\0"u8);
+
+ var loadRes = state.LoadBuffer(PrepareLoaderBlockBytes(allowedFunctions).Span);
if (loadRes != LuaStatus.OK)
{
- throw new GarnetException("Couldn't load loader into Lua");
+ if (state.StackTop == 1 && state.CheckBuffer(1, out var buff))
+ {
+ var innerError = Encoding.UTF8.GetString(buff);
+ throw new GarnetException($"Could initialize Lua VM: {innerError}");
+ }
+
+ throw new GarnetException("Could initialize Lua VM");
}
var sandboxRes = state.PCall(0, -1);
@@ -549,25 +710,6 @@ public unsafe LuaRunner(
throw new GarnetException($"Could not initialize Lua sandbox state: {errMsg}");
}
- // Register functions provided by .NET in global namespace
- state.Register("garnet_call\0"u8, garnetCall);
- state.Register("garnet_sha1hex\0"u8, &LuaRunnerTrampolines.SHA1Hex);
- state.Register("garnet_log\0"u8, &LuaRunnerTrampolines.Log);
- state.Register("garnet_acl_check_cmd\0"u8, &LuaRunnerTrampolines.AclCheckCommand);
- state.Register("garnet_setresp\0"u8, &LuaRunnerTrampolines.SetResp);
-
- var redisVersionBytes = Encoding.UTF8.GetBytes(redisVersion);
- state.PushBuffer(redisVersionBytes);
- state.SetGlobal("garnet_REDIS_VERSION\0"u8);
-
- var redisVersionParsed = Version.Parse(redisVersion);
- var redisVersionNum =
- ((byte)redisVersionParsed.Major << 16) |
- ((byte)redisVersionParsed.Minor << 8) |
- ((byte)redisVersionParsed.Build << 0);
- state.PushInteger(redisVersionNum);
- state.SetGlobal("garnet_REDIS_VERSION_NUM\0"u8);
-
state.GetGlobal(LuaType.Table, "KEYS\0"u8);
keysTableRegistryIndex = state.Ref();
@@ -590,7 +732,7 @@ public unsafe LuaRunner(
/// Creates a new runner with the source of the script
///
public LuaRunner(LuaOptions options, string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, string redisVersion = "0.0.0.0", ILogger logger = null)
- : this(options.MemoryManagementMode, options.GetMemoryLimitBytes(), options.LogMode, Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, redisVersion, logger)
+ : this(options.MemoryManagementMode, options.GetMemoryLimitBytes(), options.LogMode, options.AllowedFunctions, Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, redisVersion, logger)
{
}
@@ -662,7 +804,7 @@ public unsafe bool CompileForSession(RespServerSession session)
return false;
}
- return true;
+ return functionRegistryIndex != -1;
}
finally
{
@@ -742,122 +884,1960 @@ private unsafe int CompileCommon(ref TResponse resp)
public void Dispose()
=> state.Dispose();
- ///
- /// Entry point for redis.sha1hex method from a Lua script.
- ///
- public int SHA1Hex(nint luaStatePtr)
- {
- state.CallFromLuaEntered(luaStatePtr);
+ ///
+ /// Entry point for redis.sha1hex method from a Lua script.
+ ///
+ public int SHA1Hex(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var argCount = state.StackTop;
+ if (argCount != 1)
+ {
+ state.PushConstantString(constStrs.ErrWrongNumberOfArgs);
+ return state.RaiseErrorFromStack();
+ }
+
+ if (!state.CheckBuffer(1, out var bytes))
+ {
+ bytes = default;
+ }
+
+ Span hashBytes = stackalloc byte[SessionScriptCache.SHA1Len / 2];
+ Span hexRes = stackalloc byte[SessionScriptCache.SHA1Len];
+
+ SessionScriptCache.GetScriptDigest(bytes, hashBytes, hexRes);
+
+ state.PushBuffer(hexRes);
+ return 1;
+ }
+
+ ///
+ /// Entry point for redis.log(...) from a Lua script.
+ ///
+ public int Log(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var argCount = state.StackTop;
+ if (argCount < 2)
+ {
+ return LuaStaticError(constStrs.ErrRedisLogRequired);
+ }
+
+ if (state.Type(1) != LuaType.Number)
+ {
+ return LuaStaticError(constStrs.ErrFirstArgMustBeNumber);
+ }
+
+ var rawLevel = state.CheckNumber(1);
+ if (rawLevel is not (0 or 1 or 2 or 3))
+ {
+ return LuaStaticError(constStrs.ErrInvalidDebugLevel);
+ }
+
+ if (logMode == LuaLoggingMode.Disable)
+ {
+ return LuaStaticError(constStrs.ErrLoggingDisabled);
+ }
+
+ // When shipped as a service, allowing arbitrary writes to logs is dangerous
+ // so we support disabling it (while not breaking existing scripts)
+ if (logMode == LuaLoggingMode.Silent)
+ {
+ return 0;
+ }
+
+ // Even if enabled, if no logger was provided we can just bail
+ if (logger == null)
+ {
+ return 0;
+ }
+
+ // Construct and log the equivalent message
+ string logMessage;
+ if (argCount == 2)
+ {
+ if (state.CheckBuffer(2, out var buff))
+ {
+ logMessage = Encoding.UTF8.GetString(buff);
+ }
+ else
+ {
+ logMessage = "";
+ }
+ }
+ else
+ {
+ var sb = new StringBuilder();
+
+ for (var argIx = 2; argIx <= argCount; argIx++)
+ {
+ if (state.CheckBuffer(argIx, out var buff))
+ {
+ if (sb.Length != 0)
+ {
+ _ = sb.Append(' ');
+ }
+
+ _ = sb.Append(Encoding.UTF8.GetString(buff));
+ }
+ }
+
+ logMessage = sb.ToString();
+ }
+
+ var logLevel =
+ rawLevel switch
+ {
+ 0 => LogLevel.Debug,
+ 1 => LogLevel.Information,
+ 2 => LogLevel.Warning,
+ // We validated this above, so really it's just 3 but the switch needs to be exhaustive
+ _ => LogLevel.Error,
+ };
+
+ logger.Log(logLevel, "redis.log: {message}", logMessage.ToString());
+
+ return 0;
+ }
+
+ ///
+ /// Entry point for math.atan2 from a Lua script.
+ ///
+ public int Atan2(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 2 || state.Type(1) != LuaType.Number || state.Type(2) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to atan2");
+ }
+
+ var x = state.CheckNumber(1);
+ var y = state.CheckNumber(2);
+
+ var res = Math.Atan2(x, y);
+ state.Pop(2);
+ state.PushNumber(res);
+ return 1;
+ }
+
+ ///
+ /// Entry point for math.cosh from a Lua script.
+ ///
+ public int Cosh(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to cosh");
+ }
+
+ var value = state.CheckNumber(1);
+
+ var res = Math.Cosh(value);
+ state.Pop(1);
+ state.PushNumber(res);
+ return 1;
+ }
+
+ ///
+ /// Entry point for math.frexp from a Lua script.
+ ///
+ public int Frexp(nint luaStatePtr)
+ {
+ // Based on: https://github.com/MachineCognitis/C.math.NET/ (MIT License)
+
+ const long DBL_EXP_MASK = 0x7FF0000000000000L;
+ const int DBL_MANT_BITS = 52;
+ const long DBL_SGN_MASK = -1 - 0x7FFFFFFFFFFFFFFFL;
+ const long DBL_MANT_MASK = 0x000FFFFFFFFFFFFFL;
+ const long DBL_EXP_CLR_MASK = DBL_SGN_MASK | DBL_MANT_MASK;
+
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to frexp");
+ }
+
+ var number = state.CheckNumber(1);
+
+ var bits = BitConverter.DoubleToInt64Bits(number);
+ var exp = (int)((bits & DBL_EXP_MASK) >> DBL_MANT_BITS);
+ var exponent = 0;
+
+ if (exp == 0x7FF || number == 0D)
+ {
+ number += number;
+ }
+ else
+ {
+ // Not zero and finite.
+ exponent = exp - 1022;
+ if (exp == 0)
+ {
+ // Subnormal, scale number so that it is in [1, 2).
+ number *= BitConverter.Int64BitsToDouble(0x4350000000000000L); // 2^54
+ bits = BitConverter.DoubleToInt64Bits(number);
+ exp = (int)((bits & DBL_EXP_MASK) >> DBL_MANT_BITS);
+ exponent = exp - 1022 - 54;
+ }
+ // Set exponent to -1 so that number is in [0.5, 1).
+ number = BitConverter.Int64BitsToDouble((bits & DBL_EXP_CLR_MASK) | 0x3FE0000000000000L);
+ }
+
+ state.ForceMinimumStackCapacity(2);
+ state.Pop(1);
+
+ var numberAsFloat = (float)number;
+
+ if ((long)numberAsFloat == numberAsFloat)
+ {
+ state.PushInteger((long)numberAsFloat);
+ }
+ else
+ {
+ state.PushNumber(number);
+ }
+
+ state.PushInteger(exponent);
+
+ return 2;
+ }
+
+ ///
+ /// Entry point for math.ldexp from a Lua script.
+ ///
+ public int Ldexp(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 2 || state.Type(1) != LuaType.Number || state.Type(2) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to ldexp");
+ }
+
+ var m = state.CheckNumber(1);
+ var e = (int)state.CheckNumber(2);
+
+ var res = m * Math.Pow(2, e);
+
+ state.Pop(2);
+
+ if ((long)res == res)
+ {
+ state.PushInteger((long)res);
+ }
+ else
+ {
+ state.PushNumber(res);
+ }
+
+ return 1;
+ }
+
+ ///
+ /// Entry point for math.log10 from a Lua script.
+ ///
+ public int Log10(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to log10");
+ }
+
+ var val = state.CheckNumber(1);
+
+ var res = Math.Log10(val);
+
+ state.Pop(1);
+ state.PushNumber(res);
+ return 1;
+ }
+
+ ///
+ /// Entry point for math.pow from a Lua script.
+ ///
+ public int Pow(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 2 || state.Type(1) != LuaType.Number || state.Type(2) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to pow");
+ }
+
+ var x = state.CheckNumber(1);
+ var y = state.CheckNumber(2);
+
+ var res = Math.Pow(x, y);
+
+ state.Pop(2);
+ state.PushNumber(res);
+ return 1;
+ }
+
+ ///
+ /// Entry point for math.sinh from a Lua script.
+ ///
+ public int Sinh(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to sinh");
+ }
+
+ var val = state.CheckNumber(1);
+
+ var res = Math.Sinh(val);
+
+ state.Pop(1);
+ state.PushNumber(res);
+ return 1;
+ }
+
+ ///
+ /// Entry point for math.sinh from a Lua script.
+ ///
+ public int Tanh(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to tanh");
+ }
+
+ var val = state.CheckNumber(1);
+
+ var res = Math.Tanh(val);
+
+ state.Pop(1);
+ state.PushNumber(res);
+ return 1;
+ }
+
+ ///
+ /// Entry point for table.maxn from a Lua script.
+ ///
+ public int Maxn(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1 || state.Type(1) != LuaType.Table)
+ {
+ return state.RaiseError("bad argument to maxn");
+ }
+
+ state.ForceMinimumStackCapacity(2);
+
+ double res = 0;
+
+ // Initial key value onto stack
+ state.PushNil();
+ while (state.Next(1) != 0)
+ {
+ // Remove value
+ state.Pop(1);
+
+ double keyVal;
+ if (state.Type(2) == LuaType.Number && (keyVal = state.CheckNumber(2)) > res)
+ {
+ res = keyVal;
+ }
+ }
+
+ // Remove table, and push largest number
+ state.Pop(1);
+ state.PushNumber(res);
+ return 1;
+ }
+
+ ///
+ /// Entry point for loadstring from a Lua script.
+ ///
+ public int LoadString(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (
+ (luaArgCount == 1 && state.Type(1) != LuaType.String) ||
+ (luaArgCount == 2 && (state.Type(1) != LuaType.String || state.Type(2) != LuaType.String)) ||
+ (luaArgCount > 2)
+ )
+ {
+ return state.RaiseError("bad argument to loadstring");
+ }
+
+ // Ignore chunk name
+ if (luaArgCount == 2)
+ {
+ state.Pop(1);
+ }
+
+ _ = state.CheckBuffer(1, out var buff);
+ if (buff.Contains((byte)0))
+ {
+ return state.RaiseError("bad argument to loadstring, interior null byte");
+ }
+
+ state.ForceMinimumStackCapacity(1);
+
+ var res = state.LoadString(buff);
+ if (res != LuaStatus.OK)
+ {
+ state.ClearStack();
+ state.PushNil();
+ state.PushBuffer("load_string encountered error"u8);
+ return 2;
+ }
+
+ return 1;
+ }
+
+ ///
+ /// Converts a Lua number (ie. a double) into the expected 32-bit integer for
+ /// bit operations.
+ ///
+ private static int LuaNumberToBitValue(double value)
+ {
+ var scaled = value + 6_755_399_441_055_744.0;
+ var asULong = BitConverter.DoubleToUInt64Bits(scaled);
+ var asUInt = (uint)asULong;
+
+ return (int)asUInt;
+ }
+
+ ///
+ /// Entry point for bit.tobit from a Lua script.
+ ///
+ public int BitToBit(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount < 1 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to tobit");
+ }
+
+ var rawValue = state.CheckNumber(1);
+
+ // Make space on the stack
+ state.Pop(1);
+
+ state.PushNumber(LuaNumberToBitValue(rawValue));
+
+ return 1;
+ }
+
+ ///
+ /// Entry point for bit.tohex from a Lua script.
+ ///
+ public int BitToHex(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount == 0 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to tohex");
+ }
+
+ var numDigits = 8;
+
+ if (luaArgCount == 2)
+ {
+ if (state.Type(2) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to tohex");
+ }
+
+ numDigits = (int)state.CheckNumber(2);
+ }
+
+ var value = LuaNumberToBitValue(state.CheckNumber(1));
+
+ ReadOnlySpan hexBytes;
+ if (numDigits == int.MinValue)
+ {
+ numDigits = 8;
+ hexBytes = "0123456789ABCDEF"u8;
+ }
+ else if (numDigits < 0)
+ {
+ numDigits = -numDigits;
+ hexBytes = "0123456789ABCDEF"u8;
+ }
+ else
+ {
+ hexBytes = "0123456789abcdef"u8;
+ }
+
+ if (numDigits > 8)
+ {
+ numDigits = 8;
+ }
+
+ Span buff = stackalloc byte[numDigits];
+ for (var i = buff.Length - 1; i >= 0; i--)
+ {
+ buff[i] = hexBytes[value & 0xF];
+ value >>= 4;
+ }
+
+ // Free up space on stack
+ state.Pop(luaArgCount);
+
+ state.PushBuffer(buff);
+ return 1;
+ }
+
+ ///
+ /// Entry point for bit.bswap from a Lua script.
+ ///
+ public int BitBswap(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to bswap");
+ }
+
+ var value = LuaNumberToBitValue(state.CheckNumber(1));
+
+ // Free up space on stack
+ state.Pop(1);
+
+ var swapped = BinaryPrimitives.ReverseEndianness(value);
+ state.PushNumber(swapped);
+ return 1;
+ }
+
+ ///
+ /// Entry point for garnet_bitop from a Lua script.
+ ///
+ /// Used to implement bit.bnot, bit.bor, bit.band, etc.
+ ///
+ public int Bitop(nint luaStatePtr)
+ {
+ const int BNot = 0;
+ const int BOr = 1;
+ const int BAnd = 2;
+ const int BXor = 3;
+ const int LShift = 4;
+ const int RShift = 5;
+ const int ARShift = 6;
+ const int Rol = 7;
+ const int Ror = 8;
+
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount == 0 || state.Type(1) != LuaType.Number)
+ {
+ return state.RaiseError("bitop was not indicated, should never happen");
+ }
+
+ var bitop = (int)state.CheckNumber(1);
+ if (bitop < BNot || bitop > Ror)
+ {
+ return state.RaiseError($"invalid bitop {bitop} was indicated, should never happen");
+ }
+
+ // Handle bnot specially
+ if (bitop == BNot)
+ {
+ if (luaArgCount < 2 || state.Type(2) != LuaType.Number)
+ {
+ return state.RaiseError("bad argument to bnot");
+ }
+
+ var val = LuaNumberToBitValue(state.CheckNumber(2));
+ var res = ~val;
+ state.Pop(2);
+
+ state.PushNumber(res);
+ return 1;
+ }
+
+ var binOpName =
+ bitop switch
+ {
+ BOr => "bor",
+ BAnd => "band",
+ BXor => "bxor",
+ LShift => "lshift",
+ RShift => "rshift",
+ ARShift => "arshift",
+ Rol => "rol",
+ _ => "ror",
+ };
+
+ if (luaArgCount < 2)
+ {
+ return state.RaiseError($"bad argument to {binOpName}");
+ }
+
+ if (bitop is BOr or BAnd or BXor)
+ {
+ var ret =
+ bitop switch
+ {
+ BOr => 0,
+ BXor => 0,
+ _ => -1,
+ };
+
+ for (var argIx = 2; argIx <= luaArgCount; argIx++)
+ {
+ if (state.Type(argIx) != LuaType.Number)
+ {
+ return state.RaiseError($"bad argument to {binOpName}");
+ }
+
+ var nextValue = LuaNumberToBitValue(state.CheckNumber(argIx));
+
+ ret =
+ bitop switch
+ {
+ BOr => ret | nextValue,
+ BXor => ret ^ nextValue,
+ _ => ret & nextValue,
+ };
+ }
+
+ state.Pop(luaArgCount);
+ state.PushNumber(ret);
+
+ return 1;
+ }
+
+ if (luaArgCount < 3 || state.Type(2) != LuaType.Number || state.Type(3) != LuaType.Number)
+ {
+ return state.RaiseError($"bad argument to {binOpName}");
+ }
+
+ var x = LuaNumberToBitValue(state.CheckNumber(2));
+ var n = ((int)state.CheckNumber(3)) & 0b1111;
+
+ var shiftRes =
+ bitop switch
+ {
+ LShift => x << n,
+ RShift => (int)((uint)x >> n),
+ ARShift => x >> n,
+ Rol => (int)BitOperations.RotateLeft((uint)x, n),
+ _ => (int)BitOperations.RotateRight((uint)x, n),
+ };
+
+ state.Pop(luaArgCount);
+ state.PushNumber(shiftRes);
+ return 1;
+ }
+
+ ///
+ /// Entry point for cjson.encode from a Lua script.
+ ///
+ public int CJsonEncode(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1)
+ {
+ return state.RaiseError("bad argument to encode");
+ }
+
+ Encode(this, 0);
+
+ // Encoding should leave nothing on the stack
+ state.ExpectLuaStackEmpty();
+
+ // Push the encoded string
+ var result = scratchBufferManager.ViewFullArgSlice().ReadOnlySpan;
+ state.PushBuffer(result);
+
+ return 1;
+
+ // Encode the unknown type on the top of the stack
+ static void Encode(LuaRunner self, int depth)
+ {
+ if (depth > 1000)
+ {
+ // Match Redis max decoding depth
+ _ = self.state.RaiseError("Cannot serialise, excessive nesting (1001)");
+ }
+
+ var argType = self.state.Type(self.state.StackTop);
+
+ switch (argType)
+ {
+ case LuaType.Boolean:
+ EncodeBool(self);
+ break;
+ case LuaType.Nil:
+ EncodeNull(self);
+ break;
+ case LuaType.Number:
+ EncodeNumber(self);
+ break;
+ case LuaType.String:
+ EncodeString(self);
+ break;
+ case LuaType.Table:
+ EncodeTable(self, depth);
+ break;
+ case LuaType.Function:
+ case LuaType.LightUserData:
+ case LuaType.None:
+ case LuaType.Thread:
+ case LuaType.UserData:
+ default:
+ _ = self.state.RaiseError($"Cannot serialise {argType} to JSON");
+ break;
+ }
+ }
+
+ // Encode the boolean on the top of the stack and remove it
+ static void EncodeBool(LuaRunner self)
+ {
+ Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Boolean, "Expected boolean on top of stack");
+
+ var data = self.state.ToBoolean(self.state.StackTop) ? "true"u8 : "false"u8;
+
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(data.Length).Span;
+ data.CopyTo(into);
+ self.scratchBufferManager.MoveOffset(data.Length);
+
+ self.state.Pop(1);
+ }
+
+ // Encode the nil on the top of the stack and remove it
+ static void EncodeNull(LuaRunner self)
+ {
+ Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Nil, "Expected nil on top of stack");
+
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(4).Span;
+ "null"u8.CopyTo(into);
+ self.scratchBufferManager.MoveOffset(4);
+
+ self.state.Pop(1);
+ }
+
+ // Encode the number on the top of the stack and remove it
+ static void EncodeNumber(LuaRunner self)
+ {
+ Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Number, "Expected number on top of stack");
+
+ var number = self.state.CheckNumber(self.state.StackTop);
+
+ Span space = stackalloc byte[64];
+
+ if (!number.TryFormat(space, out var written, "G", CultureInfo.InvariantCulture))
+ {
+ _ = self.state.RaiseError("Unable to format number");
+ }
+
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(written).Span;
+ space[..written].CopyTo(into);
+ self.scratchBufferManager.MoveOffset(written);
+
+ self.state.Pop(1);
+ }
+
+ // Encode the string on the top of the stack and remove it
+ static void EncodeString(LuaRunner self)
+ {
+ Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.String, "Expected string on top of stack");
+
+ _ = self.state.CheckBuffer(self.state.StackTop, out var buff);
+
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)'"';
+ self.scratchBufferManager.MoveOffset(1);
+
+ var escapeIx = buff.IndexOfAny((byte)'"', (byte)'\\');
+ while (escapeIx != -1)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(escapeIx + 2).Span;
+ buff[..escapeIx].CopyTo(into);
+
+ into[escapeIx] = (byte)'\\';
+
+ var toEscape = buff[escapeIx];
+ if (toEscape == (byte)'"')
+ {
+ into[escapeIx + 1] = (byte)'"';
+ }
+ else
+ {
+ into[escapeIx + 1] = (byte)'\\';
+ }
+
+ self.scratchBufferManager.MoveOffset(escapeIx + 2);
+
+ buff = buff[(escapeIx + 1)..];
+ escapeIx = buff.IndexOfAny((byte)'"', (byte)'\\');
+ }
+
+ var tailInto = self.scratchBufferManager.ViewRemainingArgSlice(buff.Length + 1).Span;
+ buff.CopyTo(tailInto);
+ tailInto[buff.Length] = (byte)'"';
+ self.scratchBufferManager.MoveOffset(buff.Length + 1);
+
+ self.state.Pop(1);
+ }
+
+ // Encode the table on the top of the stack and remove it
+ static void EncodeTable(LuaRunner self, int depth)
+ {
+ Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Table, "Expected table on top of stack");
+
+ // Space for key & value
+ self.state.ForceMinimumStackCapacity(2);
+
+ var tableIndex = self.state.StackTop;
+
+ var isArray = false;
+ var arrayLength = 0;
+
+ self.state.PushNil();
+ while (self.state.Next(tableIndex) != 0)
+ {
+ // Pop value
+ self.state.Pop(1);
+
+ double keyAsNumber;
+ if (self.state.Type(tableIndex + 1) == LuaType.Number && (keyAsNumber = self.state.CheckNumber(tableIndex + 1)) >= 1 && keyAsNumber == (int)keyAsNumber)
+ {
+ if (keyAsNumber > arrayLength)
+ {
+ // Need at least one integer key >= 1 to consider this an array
+ isArray = true;
+ arrayLength = (int)keyAsNumber;
+ }
+ }
+ else
+ {
+ // Non-integer key, or integer <= 0, so it's not an array
+ isArray = false;
+
+ // Remove key
+ self.state.Pop(1);
+
+ break;
+ }
+ }
+
+ if (isArray)
+ {
+ EncodeArray(self, arrayLength, depth);
+ }
+ else
+ {
+ EncodeObject(self, depth);
+ }
+ }
+
+ // Encode the table on the top of the stack as an array and remove it
+ static void EncodeArray(LuaRunner self, int length, int depth)
+ {
+ Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Table, "Expected table on top of stack");
+
+ // Space for value
+ self.state.ForceMinimumStackCapacity(1);
+
+ var tableIndex = self.state.StackTop;
+
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)'[';
+ self.scratchBufferManager.MoveOffset(1);
+
+ for (var ix = 1; ix <= length; ix++)
+ {
+ if (ix != 1)
+ {
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)',';
+ self.scratchBufferManager.MoveOffset(1);
+ }
+
+ _ = self.state.RawGetInteger(null, tableIndex, ix);
+ Encode(self, depth + 1);
+ }
+
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)']';
+ self.scratchBufferManager.MoveOffset(1);
+
+ // Remove table
+ self.state.Pop(1);
+ }
+
+ // Encode the table on the top of the stack as an object and remove it
+ static void EncodeObject(LuaRunner self, int depth)
+ {
+ Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Table, "Expected table on top of stack");
+
+ // Space for key and value and a copy of key
+ self.state.ForceMinimumStackCapacity(3);
+
+ var tableIndex = self.state.StackTop;
+
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)'{';
+ self.scratchBufferManager.MoveOffset(1);
+
+ var firstValue = true;
+
+ self.state.PushNil();
+ while (self.state.Next(tableIndex) != 0)
+ {
+ LuaType keyType;
+ if ((keyType = self.state.Type(tableIndex + 1)) is not (LuaType.String or LuaType.Number))
+ {
+ // Ignore non-string-ify-abile keys
+
+ // Remove value
+ self.state.Pop(1);
+
+ continue;
+ }
+
+ if (!firstValue)
+ {
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)',';
+ self.scratchBufferManager.MoveOffset(1);
+ }
+
+ // Copy key to top of stack
+ self.state.PushValue(tableIndex + 1);
+
+ // Force the _copy_ of the key to be a string
+ // if it is not already one.
+ //
+ // We don't modify the original key value, so we
+ // can continue using it with Next(...)
+ if (keyType == LuaType.Number)
+ {
+ _ = self.state.CheckBuffer(tableIndex + 3, out _);
+ }
+
+ // Encode key
+ Encode(self, depth + 1);
+
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)':';
+ self.scratchBufferManager.MoveOffset(1);
+
+ // Encode value
+ Encode(self, depth + 1);
+
+ firstValue = false;
+ }
+
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)'}';
+ self.scratchBufferManager.MoveOffset(1);
+
+ // Remove table
+ self.state.Pop(1);
+ }
+ }
+
+ ///
+ /// Entry point for cjson.decode from a Lua script.
+ ///
+ public int CJsonDecode(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var luaArgCount = state.StackTop;
+ if (luaArgCount != 1)
+ {
+ return state.RaiseError("bad argument to decode");
+ }
+
+ var argType = state.Type(1);
+ if (argType == LuaType.Number)
+ {
+ // We'd coerce this to a string, and then decode it, so just pass it back as is
+ //
+ // There are some cases where this wouldn't work, potentially, but they are super implementation
+ // specific so we can just pretend we made them work
+ return 1;
+ }
+
+ if (argType != LuaType.String)
+ {
+ return state.RaiseError("bad argument to decode");
+ }
+
+ _ = state.CheckBuffer(1, out var buff);
+
+ try
+ {
+ var parsed = JsonNode.Parse(buff, documentOptions: new JsonDocumentOptions { MaxDepth = 1000 });
+ Decode(this, parsed);
+
+ return 1;
+ }
+ catch (Exception e)
+ {
+ if (e.Message.Contains("maximum configured depth of 1000"))
+ {
+ // Maximum depth exceeded, munge to a compatible Redis error
+ return state.RaiseError("Found too many nested data structures (1001)");
+ }
+
+ // Invalid token is implied (and matches Redis error replies)
+ //
+ // Additinal error details can be gleaned from messages
+ return state.RaiseError($"Expected value but found invalid token. Inner Message = {e.Message}");
+ }
+
+ // Convert the JsonNode into a Lua value on the stack
+ static void Decode(LuaRunner self, JsonNode node)
+ {
+ if (node is JsonValue v)
+ {
+ DecodeValue(self, v);
+ }
+ else if (node is JsonArray a)
+ {
+ DecodeArray(self, a);
+ }
+ else if (node is JsonObject o)
+ {
+ DecodeObject(self, o);
+ }
+ else
+ {
+ _ = self.state.RaiseError($"Unexpected json node type: {node.GetType().Name}");
+ }
+ }
+
+ // Convert the JsonValue int to a Lua string, nil, or number on the stack
+ static void DecodeValue(LuaRunner self, JsonValue value)
+ {
+ // Reserve space for the value
+ self.state.ForceMinimumStackCapacity(1);
+
+ switch (value.GetValueKind())
+ {
+ case JsonValueKind.Null: self.state.PushNil(); break;
+ case JsonValueKind.True: self.state.PushBoolean(true); break;
+ case JsonValueKind.False: self.state.PushBoolean(false); break;
+ case JsonValueKind.Number: self.state.PushNumber(value.GetValue()); break;
+ case JsonValueKind.String:
+ var str = value.GetValue();
+
+ self.scratchBufferManager.Reset();
+ var buf = self.scratchBufferManager.UTF8EncodeString(str);
+
+ self.state.PushBuffer(buf);
+ break;
+ case JsonValueKind.Undefined:
+ case JsonValueKind.Object:
+ case JsonValueKind.Array:
+ default:
+ _ = self.state.RaiseError($"Unexpected json value kind: {value.GetValueKind()}");
+ break;
+ }
+ }
+
+ // Convert the JsonArray into a Lua table on the stack
+ static void DecodeArray(LuaRunner self, JsonArray arr)
+ {
+ // Reserve space for the table
+ self.state.ForceMinimumStackCapacity(1);
+
+ self.state.CreateTable(arr.Count, 0);
+
+ var tableIndex = self.state.StackTop;
+
+ var storeAtIx = 1;
+ foreach (var item in arr)
+ {
+ // Places item on the stack
+ Decode(self, item);
+
+ // Save into the table
+ self.state.RawSetInteger(tableIndex, storeAtIx);
+ storeAtIx++;
+ }
+ }
+
+ // Convert the JsonObject into a Lua table on the stack
+ static void DecodeObject(LuaRunner self, JsonObject obj)
+ {
+ // Reserve space for table and key
+ self.state.ForceMinimumStackCapacity(2);
+
+ self.state.CreateTable(0, obj.Count);
+
+ var tableIndex = self.state.StackTop;
+
+ foreach (var (key, value) in obj)
+ {
+ // Decode key to string
+ self.scratchBufferManager.Reset();
+ var buf = self.scratchBufferManager.UTF8EncodeString(key);
+ self.state.PushBuffer(buf);
+
+ // Decode value
+ Decode(self, value);
+
+ self.state.RawSet(tableIndex);
+ }
+ }
+ }
+
+ ///
+ /// Entry point for cmsgpack.pack from a Lua script.
+ ///
+ public int CMsgPackPack(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var numLuaArgs = state.StackTop;
+
+ if (numLuaArgs == 0)
+ {
+ return state.RaiseError("bad argument to pack");
+ }
+
+ // Redis concatenates all the message packs together if there are multiple
+ //
+ // Somewhat odd, but we match that behavior
+
+ scratchBufferManager.Reset();
+
+ for (var argIx = 1; argIx <= numLuaArgs; argIx++)
+ {
+ // Because each encode removes the encoded value
+ // we always encode the argument at position 1
+ Encode(this, 1, 0);
+ }
+
+ // After all encoding, stack should be empty
+ state.ExpectLuaStackEmpty();
+
+ var ret = scratchBufferManager.ViewFullArgSlice().ReadOnlySpan;
+ state.PushBuffer(ret);
+
+ return 1;
+
+ // Encode a single item at the top of the stack, and remove it
+ static void Encode(LuaRunner self, int stackIndex, int depth)
+ {
+ var type = self.state.Type(stackIndex);
+ switch (type)
+ {
+ case LuaType.Boolean: EncodeBool(self, stackIndex); break;
+ case LuaType.Number: EncodeNumber(self, stackIndex); break;
+ case LuaType.String: EncodeBytes(self, stackIndex); break;
+ case LuaType.Table:
+
+ if (depth == 16)
+ {
+ // Redis treats a too deeply nested table as a null
+ //
+ // This is weird, but we match it
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = 0xC0;
+ self.scratchBufferManager.MoveOffset(1);
+
+ self.state.Remove(stackIndex);
+
+ return;
+ }
+
+ EncodeTable(self, stackIndex, depth);
+ break;
+
+ // Everything else maps to null, NOT an error
+ case LuaType.Function:
+ case LuaType.LightUserData:
+ case LuaType.Nil:
+ case LuaType.None:
+ case LuaType.Thread:
+ case LuaType.UserData:
+ default: EncodeNull(self, stackIndex); break;
+ }
+ }
+
+ // Encode a null-ish value at stackIndex, and remove it
+ static void EncodeNull(LuaRunner self, int stackIndex)
+ {
+ Debug.Assert(self.state.Type(stackIndex) is not (LuaType.Boolean or LuaType.Number or LuaType.String or LuaType.Table), "Expected null-ish type");
+
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = 0xC0;
+ self.scratchBufferManager.MoveOffset(1);
+
+ self.state.Remove(stackIndex);
+ }
+
+ // Encode a boolean at stackIndex, and remove it
+ static void EncodeBool(LuaRunner self, int stackIndex)
+ {
+ Debug.Assert(self.state.Type(stackIndex) == LuaType.Boolean, "Expected boolean");
+
+ var value = (byte)(self.state.ToBoolean(stackIndex) ? 0xC3 : 0xC2);
+
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = value;
+ self.scratchBufferManager.MoveOffset(1);
+
+ self.state.Remove(stackIndex);
+ }
+
+ // Encode a number at stackIndex, and remove it
+ static void EncodeNumber(LuaRunner self, int stackIndex)
+ {
+ Debug.Assert(self.state.Type(stackIndex) == LuaType.Number, "Expected number");
+
+ var numRaw = self.state.CheckNumber(stackIndex);
+ var isInt = numRaw == (long)numRaw;
+
+ if (isInt)
+ {
+ EncodeInteger(self, (long)numRaw);
+ }
+ else
+ {
+ EncodeFloatingPoint(self, numRaw);
+ }
+
+ self.state.Remove(stackIndex);
+ }
+
+ // Encode an integer
+ static void EncodeInteger(LuaRunner self, long value)
+ {
+ // positive 7-bit fixint
+ if ((byte)(value & 0b0111_1111) == value)
+ {
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)value;
+ self.scratchBufferManager.MoveOffset(1);
+
+ return;
+ }
+
+ // negative 5-bit fixint
+ if ((sbyte)(value | 0b1110_0000) == value)
+ {
+ self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)value;
+ self.scratchBufferManager.MoveOffset(1);
+ return;
+ }
+
+ // 8-bit int
+ if (value is >= sbyte.MinValue and <= sbyte.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(2).Span;
+
+ into[0] = 0xD0;
+ into[1] = (byte)value;
+ self.scratchBufferManager.MoveOffset(2);
+ return;
+ }
+
+ // 8-bit uint
+ if (value is >= byte.MinValue and <= byte.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(2).Span;
+
+ into[0] = 0xCC;
+ into[1] = (byte)value;
+ self.scratchBufferManager.MoveOffset(2);
+ return;
+ }
+
+ // 16-bit int
+ if (value is >= short.MinValue and <= short.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(3).Span;
+
+ into[0] = 0xD1;
+ BinaryPrimitives.WriteInt16BigEndian(into[1..], (short)value);
+ self.scratchBufferManager.MoveOffset(3);
+ return;
+ }
+
+ // 16-bit uint
+ if (value is >= ushort.MinValue and <= ushort.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(3).Span;
+
+ into[0] = 0xCD;
+ BinaryPrimitives.WriteUInt16BigEndian(into[1..], (ushort)value);
+ self.scratchBufferManager.MoveOffset(3);
+ return;
+ }
+
+ // 32-bit int
+ if (value is >= int.MinValue and <= int.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(5).Span;
+
+ into[0] = 0xD2;
+ BinaryPrimitives.WriteInt32BigEndian(into[1..], (int)value);
+ self.scratchBufferManager.MoveOffset(5);
+ return;
+ }
+
+ // 32-bit uint
+ if (value is >= uint.MinValue and <= uint.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(5).Span;
+
+ into[0] = 0xCE;
+ BinaryPrimitives.WriteUInt32BigEndian(into[1..], (uint)value);
+ self.scratchBufferManager.MoveOffset(5);
+ return;
+ }
+
+ // 64-bit uint
+ if (value > uint.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(9).Span;
+
+ into[0] = 0xCF;
+ BinaryPrimitives.WriteUInt64BigEndian(into[1..], (ulong)value);
+ self.scratchBufferManager.MoveOffset(9);
+ return;
+ }
+
+ // 64-bit int
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(9).Span;
+
+ into[0] = 0xD3;
+ BinaryPrimitives.WriteInt64BigEndian(into[1..], value);
+ self.scratchBufferManager.MoveOffset(9);
+ }
+ }
+
+ // Encode a floating point value
+ static void EncodeFloatingPoint(LuaRunner self, double value)
+ {
+ // While Redis has code that attempts to pack doubles into floats
+ // it doesn't appear to do anything, so we just always write a double
+
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(9).Span;
+
+ into[0] = 0xCB;
+ BinaryPrimitives.WriteDoubleBigEndian(into[1..], value);
+ self.scratchBufferManager.MoveOffset(9);
+ }
+
+ // Encodes a string as at stackIndex, and remove it
+ static void EncodeBytes(LuaRunner self, int stackIndex)
+ {
+ Debug.Assert(self.state.Type(stackIndex) == LuaType.String, "Expected string");
+
+ _ = self.state.CheckBuffer(stackIndex, out var data);
+
+ if (data.Length < 32)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(1 + data.Length).Span;
+
+ into[0] = (byte)(0xA0 | data.Length);
+ data.CopyTo(into[1..]);
+ self.scratchBufferManager.MoveOffset(1 + data.Length);
+ }
+ else if (data.Length <= byte.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(2 + data.Length).Span;
+
+ into[0] = 0xD9;
+ into[1] = (byte)data.Length;
+ data.CopyTo(into[2..]);
+ self.scratchBufferManager.MoveOffset(2 + data.Length);
+ }
+ else if (data.Length <= ushort.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(3 + data.Length).Span;
+
+ into[0] = 0xDA;
+ BinaryPrimitives.WriteUInt16BigEndian(into[1..], (ushort)data.Length);
+ data.CopyTo(into[3..]);
+ self.scratchBufferManager.MoveOffset(3 + data.Length);
+ }
+ else
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(5 + data.Length).Span;
+
+ into[0] = 0xDB;
+ BinaryPrimitives.WriteUInt32BigEndian(into[1..], (uint)data.Length);
+ data.CopyTo(into[5..]);
+ self.scratchBufferManager.MoveOffset(5 + data.Length);
+ }
+
+ self.state.Remove(stackIndex);
+ }
+
+ // Encode a table at stackIndex, and remove it
+ static void EncodeTable(LuaRunner self, int stackIndex, int depth)
+ {
+ Debug.Assert(self.state.Type(stackIndex) == LuaType.Table, "Expected table");
+
+ // Space for key and value
+ self.state.ForceMinimumStackCapacity(2);
+
+ var tableIndex = stackIndex;
+
+ // A zero-length table is serialized as an array
+ var isArray = true;
+ var count = 0;
+ var max = 0;
+
+ var keyIndex = self.state.StackTop + 1;
+
+ // Measure the table and figure out if we're creating a map or an array
+ self.state.PushNil();
+ while (self.state.Next(tableIndex) != 0)
+ {
+ count++;
+
+ // Remove value
+ self.state.Pop(1);
+
+ double keyAsNum;
+ if (self.state.Type(keyIndex) != LuaType.Number || (keyAsNum = self.state.CheckNumber(keyIndex)) <= 0 || keyAsNum != (int)keyAsNum)
+ {
+ isArray = false;
+ }
+ else
+ {
+ if (keyAsNum > max)
+ {
+ max = (int)keyAsNum;
+ }
+ }
+ }
+
+ if (isArray && count == max)
+ {
+ EncodeArray(self, stackIndex, depth, count);
+ }
+ else
+ {
+ EncodeMap(self, stackIndex, depth, count);
+ }
+ }
+
+ // Encode a table at stackIndex into an array, and remove it
+ static void EncodeArray(LuaRunner self, int stackIndex, int depth, int count)
+ {
+ Debug.Assert(self.state.Type(stackIndex) == LuaType.Table, "Expected table");
+ Debug.Assert(count >= 0, "Array should have positive length");
+
+ // Reserve space for value
+ self.state.ForceMinimumStackCapacity(1);
+
+ var tableIndex = stackIndex;
+ var valueIndex = tableIndex + 1;
+
+ // Encode length
+ if (count <= 15)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(1).Span;
+ into[0] = (byte)(0b1001_0000 | count);
+ self.scratchBufferManager.MoveOffset(1);
+ }
+ else if (count <= ushort.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(3).Span;
+
+ into[0] = 0xDC;
+ BinaryPrimitives.WriteUInt16BigEndian(into[1..], (ushort)count);
+ self.scratchBufferManager.MoveOffset(3);
+ }
+ else
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(5).Span;
+
+ into[0] = 0xDD;
+ BinaryPrimitives.WriteUInt32BigEndian(into[1..], (uint)count);
+ self.scratchBufferManager.MoveOffset(5);
+ }
+
+ // Write each element out
+ for (var ix = 1; ix <= count; ix++)
+ {
+ _ = self.state.RawGetInteger(null, tableIndex, ix);
+ Encode(self, valueIndex, depth + 1);
+ }
+
+ self.state.Remove(tableIndex);
+ }
+
+ // Encode a table at stackIndex into a map, and remove it
+ static void EncodeMap(LuaRunner self, int stackIndex, int depth, int count)
+ {
+ Debug.Assert(self.state.Type(stackIndex) == LuaType.Table, "Expected table");
+ Debug.Assert(count >= 0, "Map should have positive length");
+
+ // Reserve space for key, value, and copy of key
+ self.state.ForceMinimumStackCapacity(2);
+
+ var tableIndex = stackIndex;
+ var keyIndex = self.state.StackTop + 1;
+ var valueIndex = keyIndex + 1;
+ var keyCopyIndex = valueIndex + 1;
+
+ // Encode length
+ if (count <= 15)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(1).Span;
+
+ into[0] = (byte)(0b1000_0000 | count);
+ self.scratchBufferManager.MoveOffset(1);
+ }
+ else if (count <= ushort.MaxValue)
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(3).Span;
+
+ into[0] = 0xDE;
+ BinaryPrimitives.WriteUInt16BigEndian(into[1..], (ushort)count);
+ self.scratchBufferManager.MoveOffset(3);
+ }
+ else
+ {
+ var into = self.scratchBufferManager.ViewRemainingArgSlice(5).Span;
+
+ into[0] = 0xDF;
+ BinaryPrimitives.WriteUInt32BigEndian(into[1..], (uint)count);
+ self.scratchBufferManager.MoveOffset(5);
+ }
+
+ self.state.PushNil();
+ while (self.state.Next(tableIndex) != 0)
+ {
+ // Make a copy of the key
+ self.state.PushValue(keyIndex);
+
+ // Write the key
+ Encode(self, keyCopyIndex, depth + 1);
+
+ // Write the value
+ Encode(self, valueIndex, depth + 1);
+ }
+
+ self.state.Remove(tableIndex);
+ }
+ }
+
+ ///
+ /// Entry point for cmsgpack.unpack from a Lua script.
+ ///
+ public int CMsgPackUnpack(nint luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ var numLuaArgs = state.StackTop;
+
+ if (numLuaArgs == 0)
+ {
+ return state.RaiseError("bad argument to unpack");
+ }
+
+ _ = state.CheckBuffer(1, out var data);
+
+ var decodedCount = 0;
+ while (!data.IsEmpty)
+ {
+ // Reserve space for the result
+ state.ForceMinimumStackCapacity(1);
+
+ try
+ {
+ Decode(ref data, ref state);
+ decodedCount++;
+ }
+ catch (Exception e)
+ {
+ // Best effort at matching Redis behavior
+ return state.RaiseError($"Missing bytes in input. {e.Message}");
+ }
+ }
+
+ return decodedCount;
+
+ // Decode a msg pack
+ static void Decode(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ var sigil = data[0];
+ data = data[1..];
+
+ switch (sigil)
+ {
+ case 0xC0: DecodeNull(ref data, ref state); return;
+ case 0xC2: DecodeBoolean(false, ref data, ref state); return;
+ case 0xC3: DecodeBoolean(true, ref data, ref state); return;
+ // 7-bit positive integers handled below
+ // 5-bit negative integers handled below
+ case 0xCC: DecodeUInt8(ref data, ref state); return;
+ case 0xCD: DecodeUInt16(ref data, ref state); return;
+ case 0xCE: DecodeUInt32(ref data, ref state); return;
+ case 0xCF: DecodeUInt64(ref data, ref state); return;
+ case 0xD0: DecodeInt8(ref data, ref state); return;
+ case 0xD1: DecodeInt16(ref data, ref state); return;
+ case 0xD2: DecodeInt32(ref data, ref state); return;
+ case 0xD3: DecodeInt64(ref data, ref state); return;
+ case 0xCA: DecodeSingle(ref data, ref state); return;
+ case 0xCB: DecodeDouble(ref data, ref state); return;
+ // <= 31 byte strings handled below
+ case 0xD9: DecodeSmallString(ref data, ref state); return;
+ case 0xDA: DecodeMidString(ref data, ref state); return;
+ case 0xDB: DecodeLargeString(ref data, ref state); return;
+ // We treat bins as strings
+ case 0xC4: goto case 0xD9;
+ case 0xC5: goto case 0xDA;
+ case 0xC6: goto case 0xDB;
+ // <= 15 element arrays are handled below
+ case 0xDC: DecodeMidArray(ref data, ref state); return;
+ case 0xDD: DecodeLargeArray(ref data, ref state); return;
+ // <= 15 pair maps are handled below
+ case 0xDE: DecodeMidMap(ref data, ref state); return;
+ case 0xDF: DecodeLargeMap(ref data, ref state); return;
+
+ default:
+ if ((sigil & 0b1000_0000) == 0)
+ {
+ DecodeTinyUInt(sigil, ref state);
+ return;
+ }
+ else if ((sigil & 0b1110_0000) == 0b1110_0000)
+ {
+ DecodeTinyInt(sigil, ref state);
+ return;
+ }
+ else if ((sigil & 0b1110_0000) == 0b1010_0000)
+ {
+ DecodeTinyString(sigil, ref data, ref state);
+ return;
+ }
+ else if ((sigil & 0b1111_0000) == 0b1001_0000)
+ {
+ DecodeSmallArray(sigil, ref data, ref state);
+ return;
+ }
+ else if ((sigil & 0b1111_0000) == 0b1000_0000)
+ {
+ DecodeSmallMap(sigil, ref data, ref state);
+ return;
+ }
+
+ _ = state.RaiseError($"Unexpected MsgPack sigil {sigil}/x{sigil:X2}/b{sigil:B8}");
+ return;
+ }
+ }
+
+ // Decode a null push it to the stack
+ static void DecodeNull(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNil();
+ }
+
+ // Decode a boolean and push it to the stack
+ static void DecodeBoolean(bool b, ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushBoolean(b);
+ }
+
+ // Decode a byte, moving past it in data and pushing it to the stack
+ static void DecodeUInt8(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNumber(data[0]);
+ data = data[1..];
+ }
+
+ // Decode a positive 7-bit value, pushing it to the stack
+ static void DecodeTinyUInt(byte sigil, ref LuaStateWrapper state)
+ {
+ state.PushNumber(sigil);
+ }
+
+ // Decode a ushort, moving past it in data and pushing it to the stack
+ static void DecodeUInt16(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNumber(BinaryPrimitives.ReadUInt16BigEndian(data));
+ data = data[2..];
+ }
+
+ // Decode a uint, moving past it in data and pushing it to the stack
+ static void DecodeUInt32(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNumber(BinaryPrimitives.ReadUInt32BigEndian(data));
+ data = data[4..];
+ }
+
+ // Decode a ulong, moving past it in data and pushing it to the stack
+ static void DecodeUInt64(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNumber(BinaryPrimitives.ReadUInt64BigEndian(data));
+ data = data[8..];
+ }
+
+ // Decode a negative 5-bit value, pushing it to the stack
+ static void DecodeTinyInt(byte sigil, ref LuaStateWrapper state)
+ {
+ var signExtended = (int)(0xFFFF_FF00 | sigil);
+ state.PushNumber(signExtended);
+ }
- var argCount = state.StackTop;
- if (argCount != 1)
+ // Decode a sbyte, moving past it in data and pushing it to the stack
+ static void DecodeInt8(ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- state.PushConstantString(constStrs.ErrWrongNumberOfArgs);
- return state.RaiseErrorFromStack();
+ state.PushNumber((sbyte)data[0]);
+ data = data[1..];
}
- if (!state.CheckBuffer(1, out var bytes))
+ // Decode a short, moving past it in data and pushing it to the stack
+ static void DecodeInt16(ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- bytes = default;
+ state.PushNumber(BinaryPrimitives.ReadInt16BigEndian(data));
+ data = data[2..];
}
- Span hashBytes = stackalloc byte[SessionScriptCache.SHA1Len / 2];
- Span hexRes = stackalloc byte[SessionScriptCache.SHA1Len];
+ // Decode a int, moving past it in data and pushing it to the stack
+ static void DecodeInt32(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNumber(BinaryPrimitives.ReadInt32BigEndian(data));
+ data = data[4..];
+ }
- SessionScriptCache.GetScriptDigest(bytes, hashBytes, hexRes);
+ // Decode a long, moving past it in data and pushing it to the stack
+ static void DecodeInt64(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNumber(BinaryPrimitives.ReadInt64BigEndian(data));
+ data = data[8..];
+ }
- state.PushBuffer(hexRes);
- return 1;
- }
+ // Decode a float, moving past it in data and pushing it to the stack
+ static void DecodeSingle(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNumber(BinaryPrimitives.ReadSingleBigEndian(data));
+ data = data[4..];
+ }
- ///
- /// Entry point for redis.log(...) from a Lua script.
- ///
- public int Log(nint luaStatePtr)
- {
- state.CallFromLuaEntered(luaStatePtr);
+ // Decode a double, moving past it in data and pushing it to the stack
+ static void DecodeDouble(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ state.PushNumber(BinaryPrimitives.ReadDoubleBigEndian(data));
+ data = data[8..];
+ }
- var argCount = state.StackTop;
- if (argCount < 2)
+ // Decode a string size <= 31, moving past it in data and pushing it to the stack
+ static void DecodeTinyString(byte sigil, ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- return LuaStaticError(constStrs.ErrRedisLogRequired);
+ var len = sigil & 0b0001_1111;
+ var str = data[..len];
+
+ state.PushBuffer(str);
+ data = data[len..];
}
- if (state.Type(1) != LuaType.Number)
+ // Decode a string size <= 255, moving past it in data and pushing it to the stack
+ static void DecodeSmallString(ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- return LuaStaticError(constStrs.ErrFirstArgMustBeNumber);
+ var len = data[0];
+ data = data[1..];
+
+ var str = data[..len];
+
+ state.PushBuffer(str);
+
+ data = data[str.Length..];
}
- var rawLevel = state.CheckNumber(1);
- if (rawLevel is not (0 or 1 or 2 or 3))
+ // Decode a string size <= 65,535, moving past it in data and pushing it to the stack
+ static void DecodeMidString(ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- return LuaStaticError(constStrs.ErrInvalidDebugLevel);
+ var len = BinaryPrimitives.ReadUInt16BigEndian(data);
+ data = data[2..];
+
+ var str = data[..(int)len];
+
+ state.PushBuffer(str);
+
+ data = data[str.Length..];
}
- if (logMode == LuaLoggingMode.Disable)
+ // Decode a string size <= 4,294,967,295, moving past it in data and pushing it to the stack
+ static void DecodeLargeString(ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- return LuaStaticError(constStrs.ErrLoggingDisabled);
+ var len = BinaryPrimitives.ReadUInt32BigEndian(data);
+ data = data[4..];
+
+ if ((int)len < 0)
+ {
+ _ = state.RaiseError($"String length is too long: {len}");
+ return;
+ }
+
+ var str = data[..(int)len];
+
+ state.PushBuffer(str);
+
+ data = data[str.Length..];
}
- // When shipped as a service, allowing arbitrary writes to logs is dangerous
- // so we support disabling it (while not breaking existing scripts)
- if (logMode == LuaLoggingMode.Silent)
+ // Decode an array with <= 15 items, moving past it in data and pushing it to the stack
+ static void DecodeSmallArray(byte sigil, ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- return 0;
+ // Reserve extra space for the temporary item
+ state.ForceMinimumStackCapacity(1);
+
+ var len = sigil & 0b0000_1111;
+
+ state.CreateTable(len, 0);
+ var arrayIndex = state.StackTop;
+
+ for (var i = 1; i <= len; i++)
+ {
+ // Push the element onto the stack
+ Decode(ref data, ref state);
+ state.RawSetInteger(arrayIndex, i);
+ }
}
- // Even if enabled, if no logger was provided we can just bail
- if (logger == null)
+ // Decode an array with <= 65,535 items, moving past it in data and pushing it to the stack
+ static void DecodeMidArray(ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- return 0;
+ // Reserve extra space for the temporary item
+ state.ForceMinimumStackCapacity(1);
+
+ var len = BinaryPrimitives.ReadUInt16BigEndian(data);
+ data = data[2..];
+
+ state.CreateTable(len, 0);
+ var arrayIndex = state.StackTop;
+
+ for (var i = 1; i <= len; i++)
+ {
+ // Push the element onto the stack
+ Decode(ref data, ref state);
+ state.RawSetInteger(arrayIndex, i);
+ }
}
- // Construct and log the equivalent message
- string logMessage;
- if (argCount == 2)
+ // Decode an array with <= 4,294,967,295 items, moving past it in data and pushing it to the stack
+ static void DecodeLargeArray(ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- if (state.CheckBuffer(2, out var buff))
+ // Reserve extra space for the temporary item
+ state.ForceMinimumStackCapacity(1);
+
+ var len = BinaryPrimitives.ReadUInt32BigEndian(data);
+ data = data[4..];
+
+ if ((int)len < 0)
{
- logMessage = Encoding.UTF8.GetString(buff);
+ _ = state.RaiseError($"Array length is too long: {len}");
+ return;
}
- else
+
+ state.CreateTable((int)len, 0);
+ var arrayIndex = state.StackTop;
+
+ for (var i = 1; i <= len; i++)
{
- logMessage = "";
+ // Push the element onto the stack
+ Decode(ref data, ref state);
+ state.RawSetInteger(arrayIndex, i);
}
}
- else
+
+ // Decode an map with <= 15 key-value pairs, moving past it in data and pushing it to the stack
+ static void DecodeSmallMap(byte sigil, ref ReadOnlySpan data, ref LuaStateWrapper state)
{
- var sb = new StringBuilder();
+ // Reserve extra space for the temporary key & value
+ state.ForceMinimumStackCapacity(2);
- for (var argIx = 2; argIx <= argCount; argIx++)
+ var len = sigil & 0b0000_1111;
+
+ state.CreateTable(0, len);
+ var mapIndex = state.StackTop;
+
+ for (var i = 1; i <= len; i++)
{
- if (state.CheckBuffer(argIx, out var buff))
- {
- if (sb.Length != 0)
- {
- _ = sb.Append(' ');
- }
+ // Push the key onto the stack
+ Decode(ref data, ref state);
- _ = sb.Append(Encoding.UTF8.GetString(buff));
- }
+ // Push the value onto the stack
+ Decode(ref data, ref state);
+
+ state.RawSet(mapIndex);
}
+ }
- logMessage = sb.ToString();
+ // Decode a map with <= 65,535 key-value pairs, moving past it in data and pushing it to the stack
+ static void DecodeMidMap(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ // Reserve extra space for the temporary key & value
+ state.ForceMinimumStackCapacity(2);
+
+ var len = BinaryPrimitives.ReadUInt16BigEndian(data);
+ data = data[2..];
+
+ state.CreateTable(0, len);
+ var mapIndex = state.StackTop;
+
+ for (var i = 1; i <= len; i++)
+ {
+ // Push the key onto the stack
+ Decode(ref data, ref state);
+
+ // Push the value onto the stack
+ Decode(ref data, ref state);
+
+ state.RawSet(mapIndex);
+ }
}
- var logLevel =
- rawLevel switch
+ // Decode a map with <= 4,294,967,295 key-value pairs, moving past it in data and pushing it to the stack
+ static void DecodeLargeMap(ref ReadOnlySpan data, ref LuaStateWrapper state)
+ {
+ // Reserve extra space for the temporary key & value
+ state.ForceMinimumStackCapacity(2);
+
+ var len = BinaryPrimitives.ReadUInt32BigEndian(data);
+ data = data[4..];
+
+ if ((int)len < 0)
{
- 0 => LogLevel.Debug,
- 1 => LogLevel.Information,
- 2 => LogLevel.Warning,
- // We validated this above, so really it's just 3 but the switch needs to be exhaustive
- _ => LogLevel.Error,
- };
+ _ = state.RaiseError($"Map length is too long: {len}");
+ return;
+ }
- logger.Log(logLevel, "redis.log: {message}", logMessage.ToString());
+ state.CreateTable(0, (int)len);
+ var mapIndex = state.StackTop;
- return 0;
+ for (var i = 1; i <= len; i++)
+ {
+ // Push the key onto the stack
+ Decode(ref data, ref state);
+
+ // Push the value onto the stack
+ Decode(ref data, ref state);
+
+ state.RawSet(mapIndex);
+ }
+ }
}
///
@@ -1328,7 +3308,7 @@ private unsafe int ProcessRespResponse(byte respProtocolVersion, byte* respPtr,
{
var respEnd = respPtr + respLen;
- var ret = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd);
+ var ret = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd);
if (respPtr != respEnd)
{
@@ -1338,7 +3318,7 @@ private unsafe int ProcessRespResponse(byte respProtocolVersion, byte* respPtr,
return ret;
}
- private unsafe int ProcessSingleResp3Term(byte respProtocolVersion, ref byte* respPtr, byte* respEnd)
+ private unsafe int ProcessSingleRespTerm(byte respProtocolVersion, ref byte* respPtr, byte* respEnd)
{
var indicator = (char)*respPtr;
@@ -1440,7 +3420,7 @@ private unsafe int ProcessSingleResp3Term(byte respProtocolVersion, ref byte* re
for (var itemIx = 0; itemIx < arrayItemCount; itemIx++)
{
// Pushes the item to the top of the stack
- _ = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd);
+ _ = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd);
// Store the item into the table
state.RawSetInteger(curTop + 1, itemIx + 1);
@@ -1470,10 +3450,10 @@ private unsafe int ProcessSingleResp3Term(byte respProtocolVersion, ref byte* re
for (var pair = 0; pair < mapPairCount; pair++)
{
// Read key
- _ = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd);
+ _ = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd);
// Read value
- _ = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd);
+ _ = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd);
// Set t[k] = v
state.RawSet(curTop + 3);
@@ -1526,7 +3506,7 @@ private unsafe int ProcessSingleResp3Term(byte respProtocolVersion, ref byte* re
for (var pair = 0; pair < setItemCount; pair++)
{
// Read value, which we use as a key
- _ = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd);
+ _ = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd);
// Unconditionally the value under the key is true
state.PushBoolean(true);
@@ -2471,7 +4451,34 @@ static unsafe void WriteString(LuaRunner runner, ref TResponse resp)
{
runner.state.KnownStringToBuffer(runner.state.StackTop, out var buf);
- while (!RespWriteUtils.TryWriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd))
+ // Strings can be veeeerrrrry large, so we can't use the short helpers
+ // Thus we write the full string directly
+ while (!RespWriteUtils.TryWriteBulkStringLength(buf, ref resp.BufferCur, resp.BufferEnd))
+ resp.SendAndReset();
+
+ // Repeat while we have bytes left to write
+ while (!buf.IsEmpty)
+ {
+ // Copy bytes over
+ var destSpace = resp.BufferEnd - resp.BufferCur;
+ var copyLen = (int)(destSpace < buf.Length ? destSpace : buf.Length);
+ buf.Slice(0, copyLen).CopyTo(new Span(resp.BufferCur, copyLen));
+
+ // Advance
+ resp.BufferCur += copyLen;
+
+ // Flush if we filled the buffer
+ if (destSpace == copyLen)
+ {
+ resp.SendAndReset();
+ }
+
+ // Move past the data we wrote out
+ buf = buf.Slice(copyLen);
+ }
+
+ // End the string
+ while (!RespWriteUtils.TryWriteNewLine(ref resp.BufferCur, resp.BufferEnd))
resp.SendAndReset();
runner.state.Pop(1);
@@ -2811,6 +4818,77 @@ private static (int Start, ulong[] Bitmap) InitializeNoScriptDetails()
return (start, bitmap);
}
+
+ ///
+ /// Modifies to account for , and converts to bytes.
+ ///
+ /// Provided as an optimization, as often this can be memoized.
+ ///
+ private static ReadOnlyMemory PrepareLoaderBlockBytes(HashSet allowedFunctions)
+ {
+ // If nothing is explicitly allowed, fallback to our defaults
+ if (allowedFunctions.Count == 0)
+ {
+ allowedFunctions = DefaultAllowedFunctions;
+ }
+
+ // Most of the time this list never changes, so reuse the work
+ var cache = CachedLoaderBlock;
+ if (cache != null && ReferenceEquals(cache.AllowedFunctions, allowedFunctions))
+ {
+ return cache.LoaderBlockBytes;
+ }
+
+ // Build the subset of a Lua table where we export all these functions
+ var wholeIncludes = new HashSet(StringComparer.OrdinalIgnoreCase);
+ var replacement = new StringBuilder();
+ foreach (var wholeRef in allowedFunctions.Where(static x => !x.Contains('.')))
+ {
+ if (!DefaultAllowedFunctions.Contains(wholeRef))
+ {
+ // Skip functions not intentionally exported
+ continue;
+ }
+
+ _ = replacement.AppendLine($" {wholeRef}={wholeRef};");
+ _ = wholeIncludes.Add(wholeRef);
+ }
+
+ // Partial includes (ie. os.clock) need special handling
+ var partialIncludes = allowedFunctions.Where(static x => x.Contains('.')).Select(static x => (Leading: x[..x.IndexOf('.')], Trailing: x[(x.IndexOf('.') + 1)..]));
+ foreach (var grouped in partialIncludes.GroupBy(static t => t.Leading, StringComparer.OrdinalIgnoreCase))
+ {
+ if (wholeIncludes.Contains(grouped.Key))
+ {
+ // Including a subset of something included in whole doesn't affect things
+ continue;
+ }
+
+ if (!DefaultAllowedFunctions.Contains(grouped.Key))
+ {
+ // Skip functions not intentionally exported
+ continue;
+ }
+
+ _ = replacement.AppendLine($" {grouped.Key}={{");
+ foreach (var part in grouped.Select(static t => t.Trailing).Distinct().OrderBy(static t => t))
+ {
+ _ = replacement.AppendLine($" {part}={grouped.Key}.{part};");
+ }
+ _ = replacement.AppendLine(" };");
+ }
+
+ var decl = replacement.ToString();
+ var finalLoaderBlock = LoaderBlock.Replace("!!SANDBOX_ENV REPLACEMENT TARGET!!", decl);
+
+ // Save off for next caller
+ //
+ // Inherently race-y, but that's fine - worst case we do a little extra work
+ var newCache = new LoaderBlockCache(allowedFunctions, Encoding.UTF8.GetBytes(finalLoaderBlock));
+ CachedLoaderBlock = newCache;
+
+ return newCache.LoaderBlockBytes;
+ }
}
///
@@ -2950,5 +5028,132 @@ internal static int AclCheckCommand(nint luaState)
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
internal static int SetResp(nint luaState)
=> CallbackContext.SetResp(luaState);
+
+ ///
+ /// Entry point for calls to math.atan2.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Atan2(nint luaState)
+ => CallbackContext.Atan2(luaState);
+
+ ///
+ /// Entry point for calls to math.cosh.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Cosh(nint luaState)
+ => CallbackContext.Cosh(luaState);
+
+ ///
+ /// Entry point for calls to math.frexp.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Frexp(nint luaState)
+ => CallbackContext.Frexp(luaState);
+
+ ///
+ /// Entry point for calls to math.ldexp.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Ldexp(nint luaState)
+ => CallbackContext.Ldexp(luaState);
+
+ ///
+ /// Entry point for calls to math.log10.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Log10(nint luaState)
+ => CallbackContext.Log10(luaState);
+
+ ///
+ /// Entry point for calls to math.pow.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Pow(nint luaState)
+ => CallbackContext.Pow(luaState);
+
+ ///
+ /// Entry point for calls to math.sinh.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Sinh(nint luaState)
+ => CallbackContext.Sinh(luaState);
+
+ ///
+ /// Entry point for calls to math.tanh.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Tanh(nint luaState)
+ => CallbackContext.Tanh(luaState);
+
+ ///
+ /// Entry point for calls to table.maxn.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Maxn(nint luaState)
+ => CallbackContext.Maxn(luaState);
+
+ ///
+ /// Entry point for calls to loadstring.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int LoadString(nint luaState)
+ => CallbackContext.LoadString(luaState);
+
+ ///
+ /// Entry point for calls to cjson.encode.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int CJsonEncode(nint luaState)
+ => CallbackContext.CJsonEncode(luaState);
+
+ ///
+ /// Entry point for calls to cjson.decode.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int CJsonDecode(nint luaState)
+ => CallbackContext.CJsonDecode(luaState);
+
+ ///
+ /// Entry point for calls to bit.tobit.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int BitToBit(nint luaState)
+ => CallbackContext.BitToBit(luaState);
+
+ ///
+ /// Entry point for calls to bit.tohex.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int BitToHex(nint luaState)
+ => CallbackContext.BitToHex(luaState);
+
+ ///
+ /// Entry point for calls to garnet_bitop, which backs
+ /// bit.bnot, bit.bor, etc.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int Bitop(nint luaState)
+ => CallbackContext.Bitop(luaState);
+
+ ///
+ /// Entry point for calls to bit.bswap.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int BitBswap(nint luaState)
+ => CallbackContext.BitBswap(luaState);
+
+ ///
+ /// Entry point for calls to cmsgpack.pack.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int CMsgPackPack(nint luaState)
+ => CallbackContext.CMsgPackPack(luaState);
+
+ ///
+ /// Entry point for calls to cmsgpack.unpack.
+ ///
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ internal static int CMsgPackUnpack(nint luaState)
+ => CallbackContext.CMsgPackUnpack(luaState);
}
}
\ No newline at end of file
diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs
index 7404e1dce73..a3cf6e32f8d 100644
--- a/libs/server/Lua/LuaStateWrapper.cs
+++ b/libs/server/Lua/LuaStateWrapper.cs
@@ -3,6 +3,7 @@
using System;
using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
@@ -443,18 +444,54 @@ internal void SetGlobal(ReadOnlySpan nullTerminatedGlobalName)
///
/// This should be used for all LoadBuffers into Lua.
///
- /// Note that this is different from pushing a buffer, as the loaded buffer is compiled.
+ /// Note that this is different from pushing a buffer, as the loaded buffer is compiled and executed.
///
/// Maintains and to minimize p/invoke calls.
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal LuaStatus LoadBuffer(ReadOnlySpan buffer)
{
- AssertLuaStackNotFull();
+ AssertLuaStackNotFull(2);
var ret = NativeMethods.LoadBuffer(state, buffer);
- UpdateStackTop(1);
+ if (ret != LuaStatus.OK)
+ {
+ StackTop = NativeMethods.GetTop(state);
+ }
+ else
+ {
+ UpdateStackTop(1);
+ }
+
+ AssertLuaStackExpected();
+
+ return ret;
+ }
+
+ ///
+ /// This should be used for all LoadStrings into Lua.
+ ///
+ /// Note that this is different from pushing or loading buffer, as the loaded buffer is compiled but NOT executed.
+ ///
+ /// Maintains and to minimize p/invoke calls.
+ ///
+ internal LuaStatus LoadString(ReadOnlySpan buffer)
+ {
+ AssertLuaStackNotFull(2);
+
+ var ret = NativeMethods.LoadString(state, buffer);
+
+ if (ret != LuaStatus.OK)
+ {
+ StackTop = NativeMethods.GetTop(state);
+ }
+ else
+ {
+ UpdateStackTop(1);
+ }
+
+ AssertLuaStackExpected();
return ret;
}
@@ -581,6 +618,20 @@ internal void PushValue(int stackIndex)
UpdateStackTop(1);
}
+ ///
+ /// This should be used for all Removes into Lua.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ internal void Remove(int stackIndex)
+ {
+ AssertLuaStackIndexInBounds(stackIndex);
+
+ NativeMethods.Rotate(state, stackIndex, -1);
+ NativeMethods.Pop(state, 1);
+
+ UpdateStackTop(-1);
+ }
+
// Rarely used
///
@@ -624,6 +675,7 @@ internal void ClearStack()
///
/// Clear the stack and raise an error with the given message.
///
+ [DoesNotReturn]
internal int RaiseError(string msg)
{
ClearStack();
@@ -636,6 +688,7 @@ internal int RaiseError(string msg)
///
/// Raise an error, where the top of the stack is the error message.
///
+ [DoesNotReturn]
internal readonly int RaiseErrorFromStack()
{
Debug.Assert(StackTop != 0, "Expected error message on the stack");
diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs
index 5035ff78fa3..6a784d9c060 100644
--- a/libs/server/Lua/NativeMethods.cs
+++ b/libs/server/Lua/NativeMethods.cs
@@ -46,6 +46,13 @@ internal static partial class NativeMethods
[UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
private static partial LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode);
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#luaL_loadstring
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial LuaStatus luaL_loadstring(lua_State lua_State, charptr_t buff);
+
///
/// see: https://www.lua.org/manual/5.4/manual.html#luaL_newstate
///
@@ -186,6 +193,13 @@ internal static partial class NativeMethods
[UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
private static partial int lua_next(lua_State luaState, int tableIndex);
+ ///
+ /// see: https://www.lua.org/manual/5.4/manual.html#lua_rotate
+ ///
+ [LibraryImport(LuaLibraryName)]
+ [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])]
+ private static partial void lua_rotate(lua_State luaState, int stackIndex, int n);
+
// GC Transition suppressed - only do this after auditing the Lua method and confirming constant-ish, fast, runtime w/o allocations
///
@@ -383,7 +397,7 @@ internal static unsafe ref byte PushBuffer(lua_State luaState, ReadOnlySpan
- /// Push given span to stack, and compiles it.
+ /// Push given span to stack, compiles it, and executes it.
///
/// Provided data is copied, and can be reused once this call returns.
///
@@ -395,6 +409,19 @@ internal static unsafe LuaStatus LoadBuffer(lua_State luaState, ReadOnlySpan
+ /// Push given span to stack, and compiles it.
+ ///
+ /// Provided data is copied, and can be reused once this call returns.
+ ///
+ internal static unsafe LuaStatus LoadString(lua_State luaState, ReadOnlySpan str)
+ {
+ fixed (byte* ptr = str)
+ {
+ return luaL_loadstring(luaState, (charptr_t)ptr);
+ }
+ }
+
///
/// Get the top index on the stack.
///
@@ -644,6 +671,15 @@ internal static void PushValue(lua_State luaState, int stackIndex)
internal static void SetTop(lua_State lua_State, int top)
=> lua_settop(lua_State, top);
+ ///
+ /// Rotates elements above (and including) in steps in
+ /// the direction of the top of the stack.
+ ///
+ /// can be negative.
+ ///
+ internal static void Rotate(lua_State luaState, int stackIndex, int n)
+ => lua_rotate(luaState, stackIndex, n);
+
///
/// Raise an error, using the top of the stack as an error item.
///
diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs
index 226329e6c7d..cae4c8ef9ca 100644
--- a/libs/server/Lua/SessionScriptCache.cs
+++ b/libs/server/Lua/SessionScriptCache.cs
@@ -31,6 +31,7 @@ internal sealed class SessionScriptCache : IDisposable
readonly int? memoryLimitBytes;
readonly LuaTimeoutManager timeoutManager;
readonly LuaLoggingMode logMode;
+ readonly HashSet allowedFunctions;
LuaRunner timeoutRunningScript;
LuaTimeoutManager.Registration timeoutRegistration;
@@ -57,6 +58,7 @@ public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authen
memoryManagementMode = storeWrapper.serverOptions.LuaOptions.MemoryManagementMode;
memoryLimitBytes = storeWrapper.serverOptions.LuaOptions.GetMemoryLimitBytes();
logMode = storeWrapper.serverOptions.LuaOptions.LogMode;
+ allowedFunctions = storeWrapper.serverOptions.LuaOptions.AllowedFunctions;
}
public void Dispose()
@@ -138,7 +140,7 @@ internal bool TryLoad(RespServerSession session, ReadOnlySpan source, Scri
{
var sourceOnHeap = source.ToArray();
- runner = new LuaRunner(memoryManagementMode, memoryLimitBytes, logMode, sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, storeWrapper.redisProtocolVersion, logger);
+ runner = new LuaRunner(memoryManagementMode, memoryLimitBytes, logMode, allowedFunctions, sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, storeWrapper.redisProtocolVersion, logger);
// If compilation fails, an error is written out
if (runner.CompileForSession(session))
diff --git a/test/Garnet.test/GarnetServerConfigTests.cs b/test/Garnet.test/GarnetServerConfigTests.cs
index 500bf3ddba7..6a49537fbef 100644
--- a/test/Garnet.test/GarnetServerConfigTests.cs
+++ b/test/Garnet.test/GarnetServerConfigTests.cs
@@ -520,7 +520,7 @@ public void LuaLoggingOptions()
}
}
- // Command line args
+ // JSON args
{
// No value is accepted
{
@@ -567,6 +567,92 @@ public void LuaLoggingOptions()
}
}
+ [Test]
+ public void LuaAllowedFunctions()
+ {
+ // Command line args
+ {
+ // No value is accepted
+ {
+ var args = new[] { "--lua" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(0, options.LuaAllowedFunctions.Count());
+ }
+
+ // One option works
+ {
+ var args = new[] { "--lua", "--lua-allowed-functions", "os" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(1, options.LuaAllowedFunctions.Count());
+ ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("os"));
+ }
+
+ // Multiple option works
+ {
+ var args = new[] { "--lua", "--lua-allowed-functions", "os,assert,rawget" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(3, options.LuaAllowedFunctions.Count());
+ ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("os"));
+ ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("assert"));
+ ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("rawget"));
+ }
+
+ // Invalid rejected
+ {
+ var args = new[] { "--lua", "--lua-allowed-functions" };
+ var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsFalse(parseSuccessful);
+ }
+ }
+
+ // JSON args
+ {
+ // No value is accepted
+ {
+ const string JSON = @"{ ""EnableLua"": true }";
+ var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(0, options.LuaAllowedFunctions.Count());
+ }
+
+ // One option works
+ {
+ const string JSON = @"{ ""EnableLua"": true, ""LuaAllowedFunctions"": [""os""] }";
+ var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(1, options.LuaAllowedFunctions.Count());
+ ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("os"));
+ }
+
+ // Multiple option works
+ {
+ const string JSON = @"{ ""EnableLua"": true, ""LuaAllowedFunctions"": [""os"", ""assert"", ""rawget""] }";
+ var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsTrue(parseSuccessful);
+ ClassicAssert.IsTrue(options.EnableLua);
+ ClassicAssert.AreEqual(3, options.LuaAllowedFunctions.Count());
+ ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("os"));
+ ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("assert"));
+ ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("rawget"));
+ }
+
+ // Invalid rejected
+ {
+ const string JSON = @"{ ""EnableLua"": true, ""LuaAllowedFunctions"": { } }";
+ var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+ ClassicAssert.IsFalse(parseSuccessful);
+ }
+ }
+ }
+
///
/// Import a garnet.conf file with the given contents
///
diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs
index 8f5cb39354c..fdf2ab9a6a9 100644
--- a/test/Garnet.test/LuaScriptRunnerTests.cs
+++ b/test/Garnet.test/LuaScriptRunnerTests.cs
@@ -59,7 +59,7 @@ public void CannotRunUnsafeScript()
{
runner.CompileForRunner();
var ex = Assert.Throws(() => runner.RunForRunner());
- ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message);
+ ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"os.exit();\"]:1: attempt to call a nil value (field 'exit')", ex.Message);
}
// Try to include a new .net library
@@ -518,7 +518,7 @@ public void RedisLogDisabled()
#endif
// Just because it's hard to test in LuaScriptTests, doing this here
- using var runner = new LuaRunner(new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Disable), "redis.log(redis.LOG_WARNING, 'foo')");
+ using var runner = new LuaRunner(new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Disable, []), "redis.log(redis.LOG_WARNING, 'foo')");
runner.CompileForRunner();
@@ -531,7 +531,7 @@ public void RedisLogSilent()
{
// Just because it's hard to test in LuaScriptTests, doing this here
using var logger = new FakeLogger();
- using var runner = new LuaRunner(new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent), "redis.log(redis.LOG_WARNING, 'foo')", logger: logger);
+ using var runner = new LuaRunner(new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, []), "redis.log(redis.LOG_WARNING, 'foo')", logger: logger);
runner.CompileForRunner();
_ = runner.RunForRunner();
@@ -539,6 +539,210 @@ public void RedisLogSilent()
ClassicAssert.IsFalse(logger.Logs.Any(static x => x.Contains("redis.log")));
}
+ [Test]
+ public void AllowedFunctions()
+ {
+ List globalFuncs = ["xpcall", "tostring", "setmetatable", "next", "assert", "tonumber", "rawequal", "collectgarbage", "getmetatable", "rawset", "pcall", "coroutine", "type", "_G", "select", "unpack", "gcinfo", "pairs", "rawget", "loadstring", "ipairs", "_VERSION", "load", "error"];
+ var exportedFuncs =
+ new Dictionary>
+ {
+ ["bit"] = ["tobit", "tohex", "bnot", "bor", "band", "bxor", "lshift", "rshift", "arshift", "rol", "ror", "bswap"],
+ ["cjson"] = ["encode", "decode"],
+ ["cmsgpack"] = ["pack", "unpack"],
+ ["math"] = ["abs", "acos", "asin", "atan", "atan2", "ceil", "cos", "cosh", "deg", "exp", "floor", "fmod", "frexp", "huge", "ldexp", "log", "log10", "max", "min", "modf", "pi", "pow", "rad", "random", "randomseed", "sin", "sinh", "sqrt", "tan", "tanh"],
+ ["os"] = ["clock"],
+ ["redis"] = ["call", "pcall", "error_reply", "status_reply", "sha1hex", "log", "LOG_DEBUG", "LOG_VERBOSE", "LOG_NOTICE", "LOG_WARNING", "setresp", "set_repl", "REPL_ALL", "REPL_AOF", "REPL_REPLICA", "REPL_SLAVE", "REPL_NONE", "replicate_commands", "breakpoint", "debug", "acl_check_cmd", "REDIS_VERSION", "REDIS_VERSION_NUM"],
+ ["string"] = ["byte", "char", "dump", "find", "format", "gmatch", "gsub", "len", "lower", "match", "rep", "reverse", "sub", "upper"],
+ ["struct"] = ["pack", "unpack", "size"],
+ ["table"] = ["concat", "insert", "maxn", "remove", "sort"],
+ };
+
+ // Check the supported globals
+ {
+ using var allRunner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, []),
+ @$"local ret = {{ }}
+ for k, v in pairs(_G) do
+ table.insert(ret, k)
+ end
+ return ret"
+ );
+
+ allRunner.CompileForRunner();
+ // __readonly is special and fine to leak, so ignore it
+ var allDefined = ((object[])allRunner.RunForRunner()).Select(static x => (string)x).Except(["__readonly"]).ToList();
+
+ var expected = globalFuncs.Concat(exportedFuncs.Keys);
+
+ var missing = expected.Except(allDefined).ToList();
+
+ // ARGV, KEYS, and redis are always available
+ var extra = allDefined.Except(expected).Except(["ARGV", "KEYS", "redis"]).ToList();
+
+ ClassicAssert.AreEqual(0, missing.Count, $"Missing globals: {string.Join(", ", missing)}");
+ ClassicAssert.AreEqual(0, extra.Count, $"Extra globals: {string.Join(", ", extra)}");
+
+ foreach (var globalFunc in globalFuncs)
+ {
+ // These are special, just ignore them
+ if (globalFunc is "type" or "_G" or "_VERSION")
+ {
+ continue;
+ }
+
+ var everythingExceptGlobalFunc = expected.Except([globalFunc]);
+
+ using var withoutRunner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, everythingExceptGlobalFunc),
+ $"return type({globalFunc})"
+ );
+
+ withoutRunner.CompileForRunner();
+ var withoutDefined = (string)withoutRunner.RunForRunner();
+
+ ClassicAssert.AreEqual("nil", withoutDefined, $"Global {globalFunc} available when it shouldn't have been");
+
+ using var withRunner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, ["type", globalFunc]),
+ $"return type({globalFunc})"
+ );
+
+ withRunner.CompileForRunner();
+ var withDefined = (string)withRunner.RunForRunner();
+
+ ClassicAssert.AreNotEqual("nil", withDefined, $"Global {globalFunc} not available when it should have been");
+ }
+ }
+
+ // Check for the supported Lua functions which are under names in globals
+
+ foreach (var (funcGroup, funcs) in exportedFuncs)
+ {
+ // Get all the keys under the funcGroup
+ {
+ using var runner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [funcGroup, "pairs", "table.insert"]),
+ @$"local ret = {{ }}
+ for k, v in pairs({funcGroup}) do
+ table.insert(ret, k)
+ end
+ return ret"
+ );
+
+ runner.CompileForRunner();
+ // __readonly is special and fine to leak, so ignore it
+ var defined = ((object[])runner.RunForRunner()).Select(static x => (string)x).Except(["__readonly"]).ToList();
+
+ var missing = funcs.Except(defined).ToList();
+ var extra = defined.Except(funcs).ToList();
+
+ ClassicAssert.AreEqual(0, missing.Count, $"Missing funcs in {funcGroup}: {string.Join(", ", missing)}");
+ ClassicAssert.AreEqual(0, extra.Count, $"Extra funcs in {funcGroup}: {string.Join(", ", extra)}");
+ }
+
+ // Check all expected funcs are defined
+ {
+ var allTypes = string.Join(", ", funcs.Select(x => $"type({funcGroup}.{x})"));
+
+ using var runner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [funcGroup, "type"]),
+ $"return {{ {allTypes} }}"
+ );
+
+ runner.CompileForRunner();
+ var defined = (object[])runner.RunForRunner();
+
+ for (var i = 0; i < funcs.Count; i++)
+ {
+ var forFunc = funcs[i];
+ var funcType = (string)defined[i];
+
+ ClassicAssert.AreNotEqual("nil", funcType, $"{funcGroup}.{forFunc} is not defined when it should be");
+ }
+ }
+
+ // Check NOT including top level group causes all functions to be unavailable
+ {
+ var otherGroups = exportedFuncs.Keys.Except(["_G", funcGroup]);
+
+ using var runner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [.. otherGroups, "type"]),
+ $"return type({funcGroup})"
+ );
+
+ runner.CompileForRunner();
+ var defined = (string)runner.RunForRunner();
+
+ ClassicAssert.AreEqual("nil", defined, $"{funcGroup} is defined when it should not be");
+ }
+
+ // Check allowing just the one func
+ foreach (var func in funcs)
+ {
+ using var runner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [$"{funcGroup}.{func}", "type"]),
+ $"return type({funcGroup}.{func})"
+ );
+
+ runner.CompileForRunner();
+ var defined = (string)runner.RunForRunner();
+
+ ClassicAssert.AreNotEqual("nil", defined, $"{funcGroup}.{func} is not defined when it should be");
+ }
+
+ // Check that disallowing just the one func in the group works
+ foreach (var func in funcs)
+ {
+ var others = funcs.Except([func]).Select(x => $"{funcGroup}.{x}");
+
+ string typeStatement;
+ if (others.Any())
+ {
+ typeStatement = $"type({funcGroup}.{func})";
+ }
+ else
+ {
+ // If a group of function is completely removed, the table is nulled out
+ typeStatement = $"type({funcGroup})";
+ }
+
+ using var runner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [.. others, "type"]),
+ $"return {typeStatement}"
+ );
+
+ runner.CompileForRunner();
+ var defined = (string)runner.RunForRunner();
+
+ ClassicAssert.AreEqual("nil", defined, $"{funcGroup}.{func} is defined when it should not be");
+ }
+ }
+ }
+
+ [Test]
+ public void InternalFunctionsIgnoredInAllowedFunctions()
+ {
+ // Check if an internal implementation detail (garnet_call in this case) can be allowed
+ using var allRunner =
+ new LuaRunner(
+ new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, ["tostring", "garnet_call"]),
+ @$"return tostring(garnet_call)"
+ );
+
+ allRunner.CompileForRunner();
+
+ var res = (string)allRunner.RunForRunner();
+ ClassicAssert.AreEqual("nil", res);
+ }
+
private sealed class FakeLogger : ILogger, IDisposable
{
private readonly List logs = new();
diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs
index d78bf02373d..02cf2719703 100644
--- a/test/Garnet.test/LuaScriptTests.cs
+++ b/test/Garnet.test/LuaScriptTests.cs
@@ -2,10 +2,13 @@
// Licensed under the MIT license.
using System;
+using System.Buffers.Binary;
+using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
+using System.Numerics;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
@@ -24,7 +27,7 @@ namespace Garnet.test
[TestFixture(LuaMemoryManagementMode.Tracked, "", "")]
[TestFixture(LuaMemoryManagementMode.Tracked, "13m", "")]
[TestFixture(LuaMemoryManagementMode.Managed, "", "")]
- [TestFixture(LuaMemoryManagementMode.Managed, "16m", "")]
+ [TestFixture(LuaMemoryManagementMode.Managed, "17m", "")]
public class LuaScriptTests
{
///
@@ -95,6 +98,13 @@ public override string ToString()
private string aclFile;
private GarnetServer server;
+ ///
+ /// Temporarily disable errors raised from Lua until longjmp work is completed.
+ ///
+ /// TODO: Delete all of this
+ ///
+ private static bool CanTestLuaErrors { get; } = OperatingSystem.IsWindows() && Environment.Version.Major <= 8;
+
public LuaScriptTests(LuaMemoryManagementMode allocMode, string limitBytes, string limitTimeout)
{
this.allocMode = allocMode;
@@ -512,9 +522,10 @@ public void RedisPCall()
{
// This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
// Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
+ if (!CanTestLuaErrors)
+ {
+ Assert.Ignore($"Ignoring test when running in .NET9.");
+ }
using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
var db = redis.GetDatabase(0);
@@ -533,9 +544,10 @@ public void RedisSha1Hex()
{
// This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
// Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
+ if (!CanTestLuaErrors)
+ {
+ Assert.Ignore($"Ignoring test when running in .NET9.");
+ }
using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
var db = redis.GetDatabase(0);
@@ -566,9 +578,10 @@ public void RedisLog()
{
// This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
// Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
+ if (!CanTestLuaErrors)
+ {
+ Assert.Ignore($"Ignoring test when running in .NET9.");
+ }
using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
var db = redis.GetDatabase(0);
@@ -649,12 +662,6 @@ public void RedisDebugAndBreakpoint()
[Test]
public void RedisAclCheckCmd()
{
- // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
- // Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
-
// Note this path is more heavily exercised in ACL tests
using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
@@ -663,17 +670,20 @@ public void RedisAclCheckCmd()
using var denyRedis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(authUsername: "deny"));
var denyDB = denyRedis.GetDatabase(0);
- var noArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd()"));
- ClassicAssert.IsTrue(noArgs.Message.StartsWith("ERR Please specify at least one argument for this redis lib call"));
+ if (CanTestLuaErrors)
+ {
+ var noArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd()"));
+ ClassicAssert.IsTrue(noArgs.Message.StartsWith("ERR Please specify at least one argument for this redis lib call"));
- var invalidCmdArgType = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd({123})"));
- ClassicAssert.IsTrue(invalidCmdArgType.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers"));
+ var invalidCmdArgType = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd({123})"));
+ ClassicAssert.IsTrue(invalidCmdArgType.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers"));
- var invalidCmd = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd('nope')"));
- ClassicAssert.IsTrue(invalidCmd.Message.StartsWith("ERR Invalid command passed to redis.acl_check_cmd()"));
+ var invalidCmd = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd('nope')"));
+ ClassicAssert.IsTrue(invalidCmd.Message.StartsWith("ERR Invalid command passed to redis.acl_check_cmd()"));
- var invalidArgType = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd('GET', {123})"));
- ClassicAssert.IsTrue(invalidArgType.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers"));
+ var invalidArgType = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd('GET', {123})"));
+ ClassicAssert.IsTrue(invalidArgType.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers"));
+ }
var canRun = (bool)db.ScriptEvaluate("return redis.acl_check_cmd('GET')");
ClassicAssert.IsTrue(canRun);
@@ -709,23 +719,20 @@ public void RedisAclCheckCmd()
[Test]
public void RedisSetResp()
{
- // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
- // Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
-
using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
var db = redis.GetDatabase(0);
- var noArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp()"));
- ClassicAssert.IsTrue(noArgs.Message.StartsWith("ERR redis.setresp() requires one argument."));
+ if (CanTestLuaErrors)
+ {
+ var noArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp()"));
+ ClassicAssert.IsTrue(noArgs.Message.StartsWith("ERR redis.setresp() requires one argument."));
- var tooManyArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp(1, 2)"));
- ClassicAssert.IsTrue(tooManyArgs.Message.StartsWith("ERR redis.setresp() requires one argument."));
+ var tooManyArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp(1, 2)"));
+ ClassicAssert.IsTrue(tooManyArgs.Message.StartsWith("ERR redis.setresp() requires one argument."));
- var badArg = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp({123})"));
- ClassicAssert.IsTrue(badArg.Message.StartsWith("ERR RESP version must be 2 or 3."));
+ var badArg = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp({123})"));
+ ClassicAssert.IsTrue(badArg.Message.StartsWith("ERR RESP version must be 2 or 3."));
+ }
var resp2 = db.ScriptEvaluate("redis.setresp(2)");
ClassicAssert.IsTrue(resp2.IsNull);
@@ -733,8 +740,11 @@ public void RedisSetResp()
var resp3 = db.ScriptEvaluate("redis.setresp(3)");
ClassicAssert.IsTrue(resp3.IsNull);
- var badRespVersion = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp(1)"));
- ClassicAssert.IsTrue(badRespVersion.Message.StartsWith("ERR RESP version must be 2 or 3."));
+ if (CanTestLuaErrors)
+ {
+ var badRespVersion = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp(1)"));
+ ClassicAssert.IsTrue(badRespVersion.Message.StartsWith("ERR RESP version must be 2 or 3."));
+ }
}
[Test]
@@ -965,9 +975,10 @@ public void RedisCallErrors()
{
// This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
// Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
+ if (!CanTestLuaErrors)
+ {
+ Assert.Ignore($"Ignoring test when running in .NET9.");
+ }
// Testing that our error replies for redis.call match Redis behavior
//
@@ -1215,6 +1226,13 @@ public void IntentionalOOM()
return;
}
+ // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
+ // Once the issue is resolved the #if can be removed permanently.
+ if (!CanTestLuaErrors)
+ {
+ Assert.Ignore($"Ignoring test when running in .NET9.");
+ }
+
const string ScriptOOMText = @"
local foo = 'abcdefghijklmnopqrstuvwxyz'
if @Ctrl == 'OOM' then
@@ -1242,11 +1260,6 @@ public void IntentionalOOM()
[Test]
public void Issue939()
{
- // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
- // Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
// See: https://github.com/microsoft/garnet/issues/939
const string Script = @"
@@ -1326,9 +1339,35 @@ public void Issue939()
}
}
- // Finally, check that nil is an illegal argument
- var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('GET', nil)"));
- ClassicAssert.True(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers"));
+ if (CanTestLuaErrors)
+ {
+ // Finally, check that nil is an illegal argument
+ var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('GET', nil)"));
+ ClassicAssert.True(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers"));
+ }
+ }
+
+ [Test]
+ public void Issue1079()
+ {
+ // Repeated submission of invalid Lua scripts shouldn't be cached, and thus should produce the same compilation error each time
+ //
+ // They also shouldn't break the session for future executions
+
+ const string BrokenScript = "return \"hello lua";
+ const string FixedScript = "return \"hello lua\"";
+
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ var brokenExc1 = ClassicAssert.Throws(() => db.Execute("EVAL", BrokenScript, 0));
+ ClassicAssert.True(brokenExc1.Message.StartsWith("Compilation error: "));
+
+ var brokenExc2 = ClassicAssert.Throws(() => db.Execute("EVAL", BrokenScript, 0));
+ ClassicAssert.AreEqual(brokenExc1.Message, brokenExc2.Message);
+
+ var success = (string)db.Execute("EVAL", FixedScript, 0);
+ ClassicAssert.AreEqual("hello lua", success);
}
[TestCase(2)]
@@ -1639,9 +1678,10 @@ public void NoScriptCommandsForbidden()
{
// This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
// Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
+ if (!CanTestLuaErrors)
+ {
+ Assert.Ignore($"Ignoring test when running in .NET9.");
+ }
ClassicAssert.True(RespCommandsInfo.TryGetRespCommandsInfo(out var allCommands, externalOnly: true));
@@ -1660,9 +1700,10 @@ public void IntentionalTimeout()
{
// This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242
// Once the issue is resolved the #if can be removed permanently.
-#if NET9_0_OR_GREATER
- Assert.Ignore($"Ignoring test when running in .NET9.");
-#endif
+ if (!CanTestLuaErrors)
+ {
+ Assert.Ignore($"Ignoring test when running in .NET9.");
+ }
const string TimeoutScript = @"
local count = 0
@@ -1791,11 +1832,933 @@ public void Resp3ToLuaConversions(RedisProtocol connectionProtocol)
}
}
+ [Test]
+ public void Bit()
+ {
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ // tobit
+ {
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tobit()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("tobit"));
+
+ // Extra arguments are legal, but ignored
+
+ var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tobit({})"));
+ ClassicAssert.True(badTypeExc.Message.Contains("bad argument") && badTypeExc.Message.Contains("tobit"));
+ }
+
+ // Rules are suprisingly subtle, so test a bunch of tricky values
+ (string Value, string Expected)[] expectedValues = [
+ ("0", "0"),
+ ("1", "1"),
+ ("1.1", "1"),
+ ("1.5", "2"),
+ ("1.9", "2"),
+ ("0.1", "0"),
+ ("0.5", "0"),
+ ("0.9", "1"),
+ ("-1.1", "-1"),
+ ("-1.5", "-2"),
+ ("-1.9", "-2"),
+ ("-0.1", "0"),
+ ("-0.5", "0"),
+ ("-0.9", "-1"),
+ (int.MinValue.ToString(), int.MinValue.ToString()),
+ (int.MaxValue.ToString(), int.MaxValue.ToString()),
+ ((1L + int.MaxValue).ToString(), int.MinValue.ToString()),
+ ((-1L + int.MinValue).ToString(), int.MaxValue.ToString()),
+ (double.MaxValue.ToString(), "-1"),
+ (double.MinValue.ToString(), "-1"),
+ (float.MaxValue.ToString(), "-447893512"),
+ (float.MinValue.ToString(), "-447893512"),
+ ];
+ foreach (var (value, expected) in expectedValues)
+ {
+ var actual = (string)db.ScriptEvaluate($"return bit.tobit({value})");
+ ClassicAssert.AreEqual(expected, actual, $"bit.tobit conversion for {value} was incorrect");
+ }
+ }
+
+ // tohex
+ {
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tohex()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("tohex"));
+
+ // Extra arguments are legal, but ignored
+
+ var badType1Exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tohex({})"));
+ ClassicAssert.True(badType1Exc.Message.Contains("bad argument") && badType1Exc.Message.Contains("tohex"));
+
+ var badType2Exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tohex(1, {})"));
+ ClassicAssert.True(badType2Exc.Message.Contains("bad argument") && badType2Exc.Message.Contains("tohex"));
+ }
+
+ // Make sure casing is handled correctly
+ for (var h = 0; h < 16; h++)
+ {
+ var lower = (string)db.ScriptEvaluate($"return bit.tohex({h}, 1)");
+ var upper = (string)db.ScriptEvaluate($"return bit.tohex({h}, -1)");
+
+ ClassicAssert.AreEqual(h.ToString("x1"), lower);
+ ClassicAssert.AreEqual(h.ToString("X1"), upper);
+ }
+
+ // Run through some weird values
+ (string Value, int? N, string Expected)[] expectedValues = [
+ ("0", null, "00000000"),
+ ("0", 16, "00000000"),
+ ("0", -8, "00000000"),
+ ("123456", null, "0001e240"),
+ ("123456", 5, "1e240"),
+ ("123456", -5, "1E240"),
+ (int.MinValue.ToString(), null, "80000000"),
+ (int.MaxValue.ToString(), null, "7fffffff"),
+ ((1L + int.MaxValue).ToString(), null, "80000000"),
+ ((-1L + int.MinValue).ToString(), null, "7fffffff"),
+ (double.MaxValue.ToString(), 1, "f"),
+ (double.MinValue.ToString(), -1, "F"),
+ (float.MaxValue.ToString(), null, "e54daff8"),
+ (float.MinValue.ToString(), null, "e54daff8"),
+ ];
+ foreach (var (value, length, expected) in expectedValues)
+ {
+ var actual = length != null ?
+ (string)db.ScriptEvaluate($"return bit.tohex({value},{length})") :
+ (string)db.ScriptEvaluate($"return bit.tohex({value})");
+
+ ClassicAssert.AreEqual(expected, actual, $"bit.tohex result for ({value},{length}) was incorrect");
+ }
+ }
+
+ // bswap
+ {
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.bswap()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("bswap"));
+
+ // Extra arguments are legal, but ignored
+
+ var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.bswap({})"));
+ ClassicAssert.True(badTypeExc.Message.Contains("bad argument") && badTypeExc.Message.Contains("bswap"));
+ }
+
+ // Just brute force a bunch of trial values
+ foreach (var a in new[] { 0, 1, 2, 4 })
+ {
+ foreach (var b in new[] { 8, 16, 32, 128 })
+ {
+ foreach (var c in new[] { 0, 2, 8, 32, })
+ {
+ foreach (var d in new[] { 1, 4, 16, 64 })
+ {
+ var input = a | (b << 8) | (c << 16) | (d << 32);
+ var expected = BinaryPrimitives.ReverseEndianness(input);
+
+ var actual = (int)db.ScriptEvaluate($"return bit.bswap({input})");
+ ClassicAssert.AreEqual(expected, actual);
+ }
+ }
+ }
+ }
+ }
+
+ // bnot
+ {
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.bnot()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("bnot"));
+
+ // Extra arguments are legal, but ignored
+
+ var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.bnot({})"));
+ ClassicAssert.True(badTypeExc.Message.Contains("bad argument") && badTypeExc.Message.Contains("bnot"));
+ }
+
+ foreach (var input in new int[] { 0, 1, 2, 4, 8, 32, 64, 128, 256, 0x70F0_F0F0, 0x6BCD_EF01, int.MinValue, int.MaxValue, -1 })
+ {
+ var expected = ~input;
+
+ var actual = (int)db.ScriptEvaluate($"return bit.bnot({input})");
+ ClassicAssert.AreEqual(expected, actual);
+ }
+ }
+
+ // band, bor, bxor
+ {
+ (int Base, string Name, Func Op)[] ops = [
+ (0, "bor", static (a, b) => a | b),
+ (-1, "band", static (a, b) => a & b),
+ (0, "bxor", static (a, b) => a ^ b),
+ ];
+
+ foreach (var op in ops)
+ {
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains(op.Name));
+
+ var badType1Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}({{}})"));
+ ClassicAssert.True(badType1Exc.Message.Contains("bad argument") && badType1Exc.Message.Contains(op.Name));
+
+ var badType2Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}(1, {{}})"));
+ ClassicAssert.True(badType2Exc.Message.Contains("bad argument") && badType2Exc.Message.Contains(op.Name));
+
+ var badType3Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}(1, 2, {{}})"));
+ ClassicAssert.True(badType3Exc.Message.Contains("bad argument") && badType3Exc.Message.Contains(op.Name));
+ }
+
+ // Gin up some unusual values and test them in different combinations
+ var nextArg = 0x0102_0304;
+ for (var numArgs = 1; numArgs <= 4; numArgs++)
+ {
+ var args = new List();
+ while (args.Count < numArgs)
+ {
+ args.Add(nextArg);
+ nextArg *= 2;
+ nextArg += args.Count;
+ }
+
+ var expected = op.Base;
+ foreach (var arg in args)
+ {
+ expected = op.Op(expected, arg);
+ }
+
+ var actual = (int)db.ScriptEvaluate($"return bit.{op.Name}({string.Join(", ", args)})");
+ ClassicAssert.AreEqual(expected, actual);
+ }
+ }
+ }
+
+ // lshift, rshift, arshift, rol, ror
+ {
+ (string Name, Func Op)[] ops = [
+ ("lshift", static (x, n) => x << n),
+ ("rshift", static (x, n) => (int)((uint)x >> n)),
+ ("arshift", static (x, n) => x >> n),
+ ("rol", static (x, n) => (int)BitOperations.RotateLeft((uint)x, n)),
+ ("ror", static (x, n) => (int)BitOperations.RotateRight((uint)x, n)),
+ ];
+
+ foreach (var op in ops)
+ {
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains(op.Name));
+
+ var badType1Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}({{}})"));
+ ClassicAssert.True(badType1Exc.Message.Contains("bad argument") && badType1Exc.Message.Contains(op.Name));
+
+ var badType2Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}(1, {{}})"));
+ ClassicAssert.True(badType2Exc.Message.Contains("bad argument") && badType2Exc.Message.Contains(op.Name));
+ }
+
+ // Extra args are allowed, but ignored
+
+ for (var shift = 0; shift < 16; shift++)
+ {
+ const int Value = 0x1234_5678;
+
+ var expected = op.Op(Value, shift);
+ var actual = (int)db.ScriptEvaluate($"return bit.{op.Name}({Value}, {shift})");
+
+ ClassicAssert.AreEqual(expected, actual, $"Incorrect value for bit.{op.Name}({Value}, {shift})");
+ }
+ }
+ }
+ }
+
+ [Test]
+ public void CJson()
+ {
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ // Encoding
+ {
+ // TODO: Once refactored to avoid longjmp issues, restore on Linux
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.encode()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("encode"));
+
+ var twoArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.encode(1, 2)"));
+ ClassicAssert.True(twoArgExc.Message.Contains("bad argument") && twoArgExc.Message.Contains("encode"));
+
+ var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.encode((function() end))"));
+ ClassicAssert.True(badTypeExc.Message.Contains("Cannot serialise"));
+ }
+
+ var nilResp = (string)db.ScriptEvaluate("return cjson.encode(nil)");
+ ClassicAssert.AreEqual("null", nilResp);
+
+ var boolResp = (string)db.ScriptEvaluate("return cjson.encode(true)");
+ ClassicAssert.AreEqual("true", boolResp);
+
+ var doubleResp = (string)db.ScriptEvaluate("return cjson.encode(1.23)");
+ ClassicAssert.AreEqual("1.23", doubleResp);
+
+ var simpleStrResp = (string)db.ScriptEvaluate("return cjson.encode('hello')");
+ ClassicAssert.AreEqual("\"hello\"", simpleStrResp);
+
+ var encodedStrResp = (string)db.ScriptEvaluate("return cjson.encode('\"foo\" \\\\bar\\\\')");
+ ClassicAssert.AreEqual("\"\\\"foo\\\" \\\\bar\\\\\"", encodedStrResp);
+
+ var emptyTableResp = (string)db.ScriptEvaluate("return cjson.encode({})");
+ ClassicAssert.AreEqual("{}", emptyTableResp);
+
+ var keyedTableResp = (string)db.ScriptEvaluate("return cjson.encode({key=123})");
+ ClassicAssert.AreEqual("{\"key\":123}", keyedTableResp);
+
+ var indexedTableResp = (string)db.ScriptEvaluate("return cjson.encode({123, 'foo'})");
+ ClassicAssert.AreEqual("[123,\"foo\"]", indexedTableResp);
+
+ var mixedTableResp = (string)db.ScriptEvaluate("local ret = {123}; ret.bar = 'foo'; return cjson.encode(ret)");
+ ClassicAssert.AreEqual("{\"1\":123,\"bar\":\"foo\"}", mixedTableResp);
+
+ // Ordering here is undefined, just doing the brute force approach for ease of implementation
+ var nestedTableResp = (string)db.ScriptEvaluate("return cjson.encode({num=1,str='hello',arr={1,2,3,4},obj={foo='bar'}})");
+ string[] nestedTableRespParts = [
+ "\"arr\":[1,2,3,4]",
+ "\"num\":1",
+ "\"str\":\"hello\"",
+ "\"obj\":{\"foo\":\"bar\"}",
+ ];
+ var possibleNestedTableResps = new List();
+ for (var a = 0; a < nestedTableRespParts.Length; a++)
+ {
+ for (var b = 0; b < nestedTableRespParts.Length; b++)
+ {
+ for (var c = 0; c < nestedTableRespParts.Length; c++)
+ {
+ for (var d = 0; d < nestedTableRespParts.Length; d++)
+ {
+ if (a == b || a == c || a == d || b == c || b == d || c == d)
+ {
+ continue;
+ }
+
+ possibleNestedTableResps.Add($"{{{nestedTableRespParts[a]},{nestedTableRespParts[b]},{nestedTableRespParts[c]},{nestedTableRespParts[d]}}}");
+ }
+ }
+ }
+ }
+ ClassicAssert.True(possibleNestedTableResps.Contains(nestedTableResp));
+
+ var nestArrayResp = (string)db.ScriptEvaluate("return cjson.encode({1,'hello',{1,2,3,4},{foo='bar'}})");
+ ClassicAssert.AreEqual("[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]", nestArrayResp);
+
+ var deeplyNestedButLegal =
+ (string)db.ScriptEvaluate(
+@"local nested = 1
+for x = 1, 1000 do
+ local newNested = {}
+ newNested[1] = nested;
+ nested = newNested
+end
+
+return cjson.encode(nested)");
+ ClassicAssert.AreEqual(new string('[', 1000) + 1 + new string(']', 1000), deeplyNestedButLegal);
+
+ // TODO: Once refactored to avoid longjmp issues, restore on Linux
+ if (CanTestLuaErrors)
+ {
+
+ var deeplyNestedExc =
+ ClassicAssert.Throws(
+ () => db.ScriptEvaluate(
+@"local nested = 1
+for x = 1, 1001 do
+ local newNested = {}
+ newNested[1] = nested;
+ nested = newNested
+end
+
+return cjson.encode(nested)"));
+ ClassicAssert.True(deeplyNestedExc.Message.Contains("Cannot serialise, excessive nesting (1001)"));
+ }
+ }
+
+ // Decoding
+ {
+ // TODO: Once refactored to avoid longjmp issues, restore on Linux
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.decode()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("decode"));
+
+ var twoArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.decode(1, 2)"));
+ ClassicAssert.True(twoArgExc.Message.Contains("bad argument") && twoArgExc.Message.Contains("decode"));
+
+ var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.decode({})"));
+ ClassicAssert.True(badTypeExc.Message.Contains("bad argument") && badTypeExc.Message.Contains("decode"));
+
+ var badFormatExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.decode('hello world')"));
+ ClassicAssert.True(badFormatExc.Message.Contains("Expected value but found invalid token"));
+ }
+
+ var numberDecode = (string)db.ScriptEvaluate("return cjson.decode(123)");
+ ClassicAssert.AreEqual("123", numberDecode);
+
+ var boolDecode = (string)db.ScriptEvaluate("return type(cjson.decode('true'))");
+ ClassicAssert.AreEqual("boolean", boolDecode);
+
+ var stringDecode = (string)db.ScriptEvaluate("return cjson.decode('\"hello world\"')");
+ ClassicAssert.AreEqual("hello world", stringDecode);
+
+ var mapDecode = (string)db.ScriptEvaluate("return cjson.decode('{\"hello\":\"world\"}').hello");
+ ClassicAssert.AreEqual("world", mapDecode);
+
+ var arrayDecode = (string)db.ScriptEvaluate("return cjson.decode('[123]')[1]");
+ ClassicAssert.AreEqual("123", arrayDecode);
+
+ var complexMapDecodeArr = (string[])db.ScriptEvaluate("return cjson.decode('{\"arr\":[1,2,3,4],\"num\":1,\"str\":\"hello\",\"obj\":{\"foo\":\"bar\"}}').arr");
+ ClassicAssert.True(complexMapDecodeArr.SequenceEqual(["1", "2", "3", "4"]));
+ var complexMapDecodeNum = (string)db.ScriptEvaluate("return cjson.decode('{\"arr\":[1,2,3,4],\"num\":1,\"str\":\"hello\",\"obj\":{\"foo\":\"bar\"}}').num");
+ ClassicAssert.AreEqual("1", complexMapDecodeNum);
+ var complexMapDecodeStr = (string)db.ScriptEvaluate("return cjson.decode('{\"arr\":[1,2,3,4],\"num\":1,\"str\":\"hello\",\"obj\":{\"foo\":\"bar\"}}').str");
+ ClassicAssert.AreEqual("hello", complexMapDecodeStr);
+ var complexMapDecodeObj = (string)db.ScriptEvaluate("return cjson.decode('{\"arr\":[1,2,3,4],\"num\":1,\"str\":\"hello\",\"obj\":{\"foo\":\"bar\"}}').obj.foo");
+ ClassicAssert.AreEqual("bar", complexMapDecodeObj);
+
+ var complexArrDecodeNum = (string)db.ScriptEvaluate("return cjson.decode('[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]')[1]");
+ ClassicAssert.AreEqual("1", complexArrDecodeNum);
+ var complexArrDecodeStr = (string)db.ScriptEvaluate("return cjson.decode('[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]')[2]");
+ ClassicAssert.AreEqual("hello", complexArrDecodeStr);
+ var complexArrDecodeArr = (string[])db.ScriptEvaluate("return cjson.decode('[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]')[3]");
+ ClassicAssert.True(complexArrDecodeArr.SequenceEqual(["1", "2", "3", "4"]));
+ var complexArrDecodeObj = (string)db.ScriptEvaluate("return cjson.decode('[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]')[4].foo");
+ ClassicAssert.AreEqual("bar", complexArrDecodeObj);
+
+ // Redis cuts us off at 1000 levels of recursion, so check that we're matching that
+ var deeplyNestedButLegal = (RedisResult[])db.ScriptEvaluate($"return cjson.decode('{new string('[', 1000)}{new string(']', 1000)}')");
+ var deeplyNestedButLegalCur = deeplyNestedButLegal;
+ for (var i = 1; i < 1000; i++)
+ {
+ ClassicAssert.AreEqual(1, deeplyNestedButLegalCur.Length);
+ deeplyNestedButLegalCur = (RedisResult[])deeplyNestedButLegalCur[0];
+ }
+ ClassicAssert.AreEqual(0, deeplyNestedButLegalCur.Length);
+
+ if (CanTestLuaErrors)
+ {
+ var deeplyNestedExc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return cjson.decode('{new string('[', 1001)}{new string(']', 1001)}')"));
+ ClassicAssert.True(deeplyNestedExc.Message.Contains("Found too many nested data structures"));
+ }
+ }
+ }
+
+ [Test]
+ public void CMsgPackPack()
+ {
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ // TODO: Once refactored to avoid longjmp issues, restore on Linux
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cmsgpack.pack()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("pack"));
+ }
+
+ // Multiple args are legal, and concat
+
+ var nullResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(nil)");
+ ClassicAssert.True(nullResp.SequenceEqual(new byte[] { 0xC0 }));
+
+ var trueResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(true)");
+ ClassicAssert.True(trueResp.SequenceEqual(new byte[] { 0xC3 }));
+
+ var falseResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(false)");
+ ClassicAssert.True(falseResp.SequenceEqual(new byte[] { 0xC2 }));
+
+ var tinyUInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(0)");
+ ClassicAssert.True(tinyUInt1Resp.SequenceEqual(new byte[] { 0x00 }));
+
+ var tinyUInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(127)");
+ ClassicAssert.True(tinyUInt2Resp.SequenceEqual(new byte[] { 0x7F }));
+
+ var tinyInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-1)");
+ ClassicAssert.True(tinyInt1Resp.SequenceEqual(new byte[] { 0xFF }));
+
+ var tinyInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-32)");
+ ClassicAssert.True(tinyInt2Resp.SequenceEqual(new byte[] { 0xE0 }));
+
+ var smallUInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(128)");
+ ClassicAssert.True(smallUInt1Resp.SequenceEqual(new byte[] { 0xCC, 0x80 }));
+
+ var smallUInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(255)");
+ ClassicAssert.True(smallUInt2Resp.SequenceEqual(new byte[] { 0xCC, 0xFF }));
+
+ var smallInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-33)");
+ ClassicAssert.True(smallInt1Resp.SequenceEqual(new byte[] { 0xD0, 0xDF }));
+
+ var smallInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-128)");
+ ClassicAssert.True(smallInt2Resp.SequenceEqual(new byte[] { 0xD0, 0x80 }));
+
+ var midUInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(32768)");
+ ClassicAssert.True(midUInt1Resp.SequenceEqual(new byte[] { 0xCD, 0x80, 0x00 }));
+
+ var midUInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(65535)");
+ ClassicAssert.True(midUInt2Resp.SequenceEqual(new byte[] { 0xCD, 0xFF, 0xFF }));
+
+ var midInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-129)");
+ ClassicAssert.True(midInt1Resp.SequenceEqual(new byte[] { 0xD1, 0xFF, 0x7F }));
+
+ var midInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-32768)");
+ ClassicAssert.True(midInt2Resp.SequenceEqual(new byte[] { 0xD1, 0x80, 0x00 }));
+
+ var uint1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(2147483648)");
+ ClassicAssert.True(uint1Resp.SequenceEqual(new byte[] { 0xCE, 0x80, 0x00, 0x00, 0x00 }));
+
+ var uint2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(4294967295)");
+ ClassicAssert.True(uint2Resp.SequenceEqual(new byte[] { 0xCE, 0xFF, 0xFF, 0xFF, 0xFF }));
+
+ var int1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-32769)");
+ ClassicAssert.True(int1Resp.SequenceEqual(new byte[] { 0xD2, 0xFF, 0xFF, 0x7F, 0xFF }));
+
+ var int2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-2147483648)");
+ ClassicAssert.True(int2Resp.SequenceEqual(new byte[] { 0xD2, 0x80, 0x00, 0x00, 0x00 }));
+
+ var bigUIntResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(4294967296)");
+ ClassicAssert.True(bigUIntResp.SequenceEqual(new byte[] { 0xCF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00 }));
+
+ var bigIntResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-2147483649)");
+ ClassicAssert.True(bigIntResp.SequenceEqual(new byte[] { 0xD3, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, 0xFF, 0xFF, 0xFF }));
+
+ var floatResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(0.1)");
+ ClassicAssert.True(floatResp.SequenceEqual(new byte[] { 0xCB, 0x3F, 0xB9, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9A }));
+
+ var tinyString1Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('')");
+ ClassicAssert.True(tinyString1Resp.SequenceEqual(new byte[] { 0xA0, }));
+
+ var tinyString2Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('0', 31)}')");
+ ClassicAssert.True(tinyString2Resp.SequenceEqual(new byte[] { 0xBF }.Concat(Enumerable.Repeat((byte)'0', 31))));
+
+ var shortString1Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('a', 32)}')");
+ ClassicAssert.True(shortString1Resp.SequenceEqual(new byte[] { 0xD9, 0x20 }.Concat(Enumerable.Repeat((byte)'a', 32))));
+
+ var shortString2Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('a', 255)}')");
+ ClassicAssert.True(shortString2Resp.SequenceEqual(new byte[] { 0xD9, 0xFF }.Concat(Enumerable.Repeat((byte)'a', 255))));
+
+ var midString1Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('b', 256)}')");
+ ClassicAssert.True(midString1Resp.SequenceEqual(new byte[] { 0xDA, 0x01, 0x00 }.Concat(Enumerable.Repeat((byte)'b', 256))));
+
+ var midString2Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('b', 65535)}')");
+ ClassicAssert.True(midString2Resp.SequenceEqual(new byte[] { 0xDA, 0xFF, 0xFF }.Concat(Enumerable.Repeat((byte)'b', 65535))));
+
+ var longStringResp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('c', 65536)}')");
+ ClassicAssert.True(longStringResp.SequenceEqual(new byte[] { 0xDB, 0x00, 0x01, 0x00, 0x00 }.Concat(Enumerable.Repeat((byte)'c', 65536))));
+
+ var emptyTableResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({})");
+ ClassicAssert.True(emptyTableResp.SequenceEqual(new byte[] { 0x90 }));
+
+ var smallArray1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({1})");
+ ClassicAssert.True(smallArray1Resp.SequenceEqual(new byte[] { 0x91, 0x01 }));
+
+ var smallArray2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})");
+ ClassicAssert.True(smallArray2Resp.SequenceEqual(new byte[] { 0x9F, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F }));
+
+ var midArray1Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{{string.Join(", ", Enumerable.Repeat(1, 16))}}})");
+ ClassicAssert.True(midArray1Resp.SequenceEqual(new byte[] { 0xDC, 0x00, 0x10 }.Concat(Enumerable.Repeat((byte)0x01, 16))));
+
+ var midArray2Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{{string.Join(", ", Enumerable.Repeat(2, ushort.MaxValue))}}})");
+ ClassicAssert.True(midArray2Resp.SequenceEqual(new byte[] { 0xDC, 0xFF, 0xFF }.Concat(Enumerable.Repeat((byte)0x02, ushort.MaxValue))));
+
+ var bigArrayResp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{{string.Join(", ", Enumerable.Repeat(3, ushort.MaxValue + 1))}}})");
+ ClassicAssert.True(bigArrayResp.SequenceEqual(new byte[] { 0xDD, 0x00, 0x01, 0x00, 0x00 }.Concat(Enumerable.Repeat((byte)0x03, ushort.MaxValue + 1))));
+
+ var smallMap1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({a=1})");
+ ClassicAssert.True(smallMap1Resp.SequenceEqual(new byte[] { 0x81, 0xA1, 0x61, 0x01 }));
+
+ var smallMap2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({a=1,b=2,c=3,d=4,e=5,f=6,g=7,h=8,i=9,j=10,k=11,l=12,m=13,n=14,o=15})");
+ ClassicAssert.AreEqual(46, smallMap2Resp.Length);
+ ClassicAssert.AreEqual(0x8F, smallMap2Resp[0]);
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x61, 0x01 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x62, 0x02 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x63, 0x03 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x64, 0x04 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x65, 0x05 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x66, 0x06 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x67, 0x07 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x68, 0x08 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x69, 0x09 }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6A, 0x0A }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6B, 0x0B }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6C, 0x0C }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6D, 0x0D }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6E, 0x0E }));
+ ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6F, 0x0F }));
+
+ var midKeys = string.Join(", ", Enumerable.Range(0, 16).Select(static x => $"m_{(char)('A' + x)}={x}"));
+ var midMapResp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{ {midKeys} }})");
+ ClassicAssert.AreEqual(83, midMapResp.Length);
+ ClassicAssert.AreEqual(0xDE, midMapResp[0]);
+ ClassicAssert.AreEqual(0, midMapResp[1]);
+ ClassicAssert.AreEqual(16, midMapResp[2]);
+ for (var val = 0; val < 16; val++)
+ {
+ var keyPart = (byte)('A' + val);
+ var expected = new byte[] { 0xA3, (byte)'m', (byte)'_', keyPart, (byte)val };
+ ClassicAssert.AreNotEqual(-1, midMapResp.AsSpan().IndexOf(expected));
+ }
+
+ var bigKeys = string.Join(", ", Enumerable.Range(0, ushort.MaxValue + 1).Select(static x => $"f_{x:X4}=4"));
+ var bigMapResp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{ {bigKeys} }})");
+ ClassicAssert.AreEqual(524_293, bigMapResp.Length);
+ ClassicAssert.AreEqual(0xDF, bigMapResp[0]);
+ ClassicAssert.AreEqual(0, bigMapResp[1]);
+ ClassicAssert.AreEqual(1, bigMapResp[2]);
+ ClassicAssert.AreEqual(0, bigMapResp[3]);
+ ClassicAssert.AreEqual(0, bigMapResp[4]);
+ for (var val = 0; val <= ushort.MaxValue; val++)
+ {
+ var keyStr = val.ToString("X4");
+
+ var expected = new byte[] { 0xA6, (byte)'f', (byte)'_', (byte)keyStr[0], (byte)keyStr[1], (byte)keyStr[2], (byte)keyStr[3], 4 };
+ ClassicAssert.AreNotEqual(-1, bigMapResp.AsSpan().IndexOf(expected));
+ }
+
+ var complexResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({4, { key='value', arr={5, 6} }, true, 1.23})");
+ ClassicAssert.AreEqual(30, complexResp.Length);
+ ClassicAssert.AreEqual(0x94, complexResp[0]);
+ ClassicAssert.AreEqual(0x04, complexResp[1]);
+ var complexNestedMap = complexResp.AsSpan().Slice(2, 18);
+ ClassicAssert.AreEqual(0x82, complexNestedMap[0]);
+ complexNestedMap = complexNestedMap[1..];
+ ClassicAssert.AreNotEqual(-1, complexNestedMap.IndexOf(new byte[] { 0b1010_0011, (byte)'k', (byte)'e', (byte)'y', 0b1010_0101, (byte)'v', (byte)'a', (byte)'l', (byte)'u', (byte)'e' }));
+ ClassicAssert.AreNotEqual(-1, complexNestedMap.IndexOf(new byte[] { 0b1010_0011, (byte)'a', (byte)'r', (byte)'r', 0b1001_0010, 0x05, 0x06 }));
+ ClassicAssert.AreEqual(0xC3, complexResp[20]);
+ ClassicAssert.AreEqual(0xCB, complexResp[21]);
+ ClassicAssert.AreEqual(1.23, BinaryPrimitives.ReadDoubleBigEndian(complexResp.AsSpan()[22..]));
+
+ var concatedResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(1, 2, 3, 4, {5})");
+ ClassicAssert.True(concatedResp.SequenceEqual(new byte[] { 0x01, 0x02, 0x03, 0x04, 0b1001_0001, 0x05 }));
+
+ // Rather than an error, Redis converts a too deeply nested object into a null (very strange)
+ //
+ // We match that behavior
+
+ var infiniteNestMap = (byte[])db.ScriptEvaluate("local a = {}; a.ref = a; return cmsgpack.pack(a)");
+ ClassicAssert.True(infiniteNestMap.SequenceEqual(new byte[] { 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0xC0 }));
+
+ var infiniteNestArr = (byte[])db.ScriptEvaluate("local a = {}; a[1] = a; return cmsgpack.pack(a)");
+ ClassicAssert.True(infiniteNestArr.SequenceEqual(new byte[] { 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0xC0 }));
+ }
+
+ [Test]
+ public void CMsgPackUnpack()
+ {
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ // TODO: Once refactored to avoid longjmp issues, restore on Linux
+ if (CanTestLuaErrors)
+ {
+ var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cmsgpack.unpack()"));
+ ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("unpack"));
+
+ // Multiple arguments are allowed, but ignored
+
+ // Table ends before it should
+ var badDataExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cmsgpack.unpack('\\220\\0\\96')"));
+ ClassicAssert.True(badDataExc.Message.Contains("Missing bytes in input"));
+ }
+
+ var nullResp = (string)db.ScriptEvaluate($"return type(cmsgpack.unpack({ToLuaString(0xC0)}))");
+ ClassicAssert.AreEqual("nil", nullResp);
+
+ var trueResp = (string)db.ScriptEvaluate($"return type(cmsgpack.unpack({ToLuaString(0xC3)}))");
+ ClassicAssert.AreEqual("boolean", trueResp);
+
+ var falseResp = (string)db.ScriptEvaluate($"return type(cmsgpack.unpack({ToLuaString(0xC2)}))");
+ ClassicAssert.AreEqual("boolean", falseResp);
+
+ var tinyUInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x00)})");
+ ClassicAssert.AreEqual(0, tinyUInt1Resp);
+
+ var tinyUInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x7F)})");
+ ClassicAssert.AreEqual(127, tinyUInt2Resp);
+
+ var tinyInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xFF)})");
+ ClassicAssert.AreEqual(-1, tinyInt1Resp);
+
+ var tinyInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xE0)})");
+ ClassicAssert.AreEqual(-32, tinyInt2Resp);
+
+ var smallUInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCC, 0x80)})");
+ ClassicAssert.AreEqual(128, smallUInt1Resp);
+
+ var smallUInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCC, 0xFF)})");
+ ClassicAssert.AreEqual(255, smallUInt2Resp);
+
+ var smallInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD0, 0xDF)})");
+ ClassicAssert.AreEqual(-33, smallInt1Resp);
+
+ var smallInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD0, 0x80)})");
+ ClassicAssert.AreEqual(-128, smallInt2Resp);
+
+ var midUInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCD, 0x80, 0x00)})");
+ ClassicAssert.AreEqual(32768, midUInt1Resp);
+
+ var midUInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCD, 0xFF, 0xFF)})");
+ ClassicAssert.AreEqual(65535, midUInt2Resp);
+
+ var midInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD1, 0xFF, 0x7F)})");
+ ClassicAssert.AreEqual(-129, midInt1Resp);
+
+ var midInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD1, 0x80, 0x00)})");
+ ClassicAssert.AreEqual(-32768, midInt2Resp);
+
+ var uint1Resp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCE, 0x80, 0x00, 0x00, 0x00)})");
+ ClassicAssert.AreEqual(2147483648, uint1Resp);
+
+ var uint2Resp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCE, 0xFF, 0xFF, 0xFF, 0xFF)})");
+ ClassicAssert.AreEqual(4294967295L, uint2Resp);
+
+ var int1Resp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD2, 0xFF, 0xFF, 0x7F, 0xFF)})");
+ ClassicAssert.AreEqual(-32769, int1Resp);
+
+ var int2Resp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD2, 0x80, 0x00, 0x00, 0x00)})");
+ ClassicAssert.AreEqual(-2147483648, int2Resp);
+
+ var bigUIntResp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00)})");
+ ClassicAssert.AreEqual(4294967296L, bigUIntResp);
+
+ var bigIntResp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD3, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, 0xFF, 0xFF, 0xFF)})");
+ ClassicAssert.AreEqual(-2147483649L, bigIntResp);
+
+ var floatResp = (string)db.ScriptEvaluate($"return tostring(cmsgpack.unpack({ToLuaString(0xCB, 0x3F, 0xB9, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9A)}))");
+ ClassicAssert.AreEqual("0.1", floatResp);
+
+ var tinyString1Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xA0)})");
+ ClassicAssert.AreEqual("", tinyString1Resp);
+
+ var tinyString2Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xBF }.Concat(Enumerable.Repeat((byte)'0', 31)).ToArray())})");
+ ClassicAssert.AreEqual(new string('0', 31), tinyString2Resp);
+
+ var shortString1Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xD9, 0x20 }.Concat(Enumerable.Repeat((byte)'a', 32)).ToArray())})");
+ ClassicAssert.AreEqual(new string('a', 32), shortString1Resp);
+
+ var shortString2Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xD9, 0xFF }.Concat(Enumerable.Repeat((byte)'a', 255)).ToArray())})");
+ ClassicAssert.AreEqual(new string('a', 255), shortString2Resp);
+
+ var midString1Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDA, 0x01, 0x00 }.Concat(Enumerable.Repeat((byte)'b', 256)).ToArray())})");
+ ClassicAssert.AreEqual(new string('b', 256), midString1Resp);
+
+ var midString2Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDA, 0xFF, 0xFF }.Concat(Enumerable.Repeat((byte)'b', 65535)).ToArray())})");
+ ClassicAssert.AreEqual(new string('b', 65535), midString2Resp);
+
+ var longStringResp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDB, 0x00, 0x01, 0x00, 0x00 }.Concat(Enumerable.Repeat((byte)'c', 65536)).ToArray())})");
+ ClassicAssert.AreEqual(new string('c', 65536), longStringResp);
+
+ var emptyTableResp = (int)db.ScriptEvaluate($"return #cmsgpack.unpack({ToLuaString(0x90)})");
+ ClassicAssert.AreEqual(0, emptyTableResp);
+
+ var smallArray1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x91, 0x01)})[1]");
+ ClassicAssert.AreEqual(1, smallArray1Resp);
+
+ var smallArray2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x9F, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F)})[14]");
+ ClassicAssert.AreEqual(14, smallArray2Resp);
+
+ var midArray1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDC, 0x00, 0x10 }.Concat(Enumerable.Repeat((byte)0x01, 16)).ToArray())})[16]");
+ ClassicAssert.AreEqual(1, midArray1Resp);
+
+ var midArray2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDC, 0xFF, 0xFF }.Concat(Enumerable.Repeat((byte)0x02, ushort.MaxValue)).ToArray())})[1000]");
+ ClassicAssert.AreEqual(2, midArray2Resp);
+
+ var bigArrayResp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDD, 0x00, 0x01, 0x00, 0x00 }.Concat(Enumerable.Repeat((byte)0x03, ushort.MaxValue + 1)).ToArray())})[20000]");
+ ClassicAssert.AreEqual(3, bigArrayResp);
+
+ var smallMap1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x81, 0xA1, 0x61, 0x01)}).a");
+ ClassicAssert.AreEqual(1, smallMap1Resp);
+
+ var smallMap2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x8F, 0xA1, 0x61, 0x01, 0xA1, 0x62, 0x02, 0xA1, 0x63, 0x03, 0xA1, 0x64, 0x04, 0xA1, 0x65, 0x05, 0xA1, 0x66, 0x06, 0xA1, 0x67, 0x07, 0xA1, 0x68, 0x08, 0xA1, 0x69, 0x09, 0xA1, 0x6A, 0x0A, 0xA1, 0x6B, 0x0B, 0xA1, 0x6C, 0x0C, 0xA1, 0x6D, 0x0D, 0xA1, 0x6E, 0x0E, 0xA1, 0x6F, 0x0F)}).o");
+ ClassicAssert.AreEqual(15, smallMap2Resp);
+
+ var midMapBytes = new List { 0xDE, 0x00, 0x10 };
+ for (var val = 0; val < 16; val++)
+ {
+ var keyPart = (byte)('A' + val);
+ midMapBytes.AddRange([0xA3, (byte)'m', (byte)'_', keyPart, (byte)val]);
+ }
+ var midMapResp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(midMapBytes.ToArray())}).m_F");
+ ClassicAssert.AreEqual(5, midMapResp);
+
+ var bigMapBytes = new List { 0xDF, 0x00, 0x01, 0x00, 0x00 };
+ for (var val = 0; val <= ushort.MaxValue; val++)
+ {
+ var keyStr = val.ToString("X4");
+
+ bigMapBytes.AddRange([0xA6, (byte)'f', (byte)'_', (byte)keyStr[0], (byte)keyStr[1], (byte)keyStr[2], (byte)keyStr[3], 4]);
+ }
+ var bigMapRes = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(bigMapBytes.ToArray())}).f_0123");
+ ClassicAssert.AreEqual(4, bigMapRes);
+
+ var nestedResp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x94, 0x04, 0x82, 0b1010_0011, (byte)'k', (byte)'e', (byte)'y', 0b1010_0101, (byte)'v', (byte)'a', (byte)'l', (byte)'u', (byte)'e', 0b1010_0011, (byte)'a', (byte)'r', (byte)'r', 0b1001_0010, 0x05, 0x06, 0xC3, 0xCB, 0x3F, 0xF3, 0xAE, 0x14, 0x7A, 0xE1, 0x47, 0xAE)})[2].arr[2]");
+ ClassicAssert.AreEqual(6, nestedResp);
+
+ var multiResp = (int[])db.ScriptEvaluate($"local a, b, c, d, e = cmsgpack.unpack({ToLuaString(0x01, 0x02, 0x03, 0x04, 0b1001_0001, 0x05)}); return {{e[1], d, c, b, a}}");
+ ClassicAssert.True(multiResp.SequenceEqual([5, 4, 3, 2, 1]));
+
+ // Helper for encoding a byte array into something that can be passed to Lua
+ static string ToLuaString(params byte[] data)
+ {
+ var ret = new StringBuilder();
+ _ = ret.Append('\'');
+
+ foreach (var b in data)
+ {
+ _ = ret.Append($"\\{b}");
+ }
+
+ _ = ret.Append('\'');
+
+ return ret.ToString();
+ }
+ }
+
+ [Test]
+ public void Struct()
+ {
+ // Redis struct.pack/unpack/size is a subset of Lua's string.pack/unpack/packsize; so just testing for basic functionality
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ var packRes = (byte[])db.ScriptEvaluate("return struct.pack('HH', 1, 2)");
+ ClassicAssert.True(packRes.SequenceEqual(new byte[] { 0x01, 0x00, 0x02, 0x00 }));
+
+ var unpackRes = (int[])db.ScriptEvaluate("return { struct.unpack('HH', '\\01\\00\\02\\00') }");
+ ClassicAssert.True(unpackRes.SequenceEqual([1, 2, 5]));
+
+ var sizeRes = (int)db.ScriptEvaluate("return struct.size('HH')");
+ ClassicAssert.AreEqual(4, sizeRes);
+ }
+
+ [Test]
+ public void MathFunctions()
+ {
+ // There are a number of "weird" math functions Redis supports that don't have direct .NET equivalents
+ //
+ // Doing some basic testing on these implementations
+ //
+ // We don't actually guarantee bit-for-bit or char-for-char equivalence, but "close" is worth attempting
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ // Frexp
+ {
+ var a = (string)db.ScriptEvaluate("local a,b = math.frexp(0); return tostring(a)..'|'..tostring(b)");
+ ClassicAssert.AreEqual("0|0", a);
+
+ var b = (string)db.ScriptEvaluate("local a,b = math.frexp(1); return tostring(a)..'|'..tostring(b)");
+ ClassicAssert.AreEqual("0.5|1", b);
+
+ var c = (string)db.ScriptEvaluate($"local a,b = math.frexp({double.MaxValue}); return tostring(a)..'|'..tostring(b)");
+ ClassicAssert.AreEqual("1|1024", c);
+
+ var d = (string)db.ScriptEvaluate($"local a,b = math.frexp({double.MinValue}); return tostring(a)..'|'..tostring(b)");
+ ClassicAssert.AreEqual("-1|1024", d);
+
+ var e = (string)db.ScriptEvaluate("local a,b = math.frexp(1234.56); return tostring(a)..'|'..tostring(b)");
+ ClassicAssert.AreEqual("0.6028125|11", e);
+
+ var f = (string)db.ScriptEvaluate("local a,b = math.frexp(-7890.12); return tostring(a)..'|'..tostring(b)");
+ ClassicAssert.AreEqual("-0.9631494140625|13", f);
+ }
+
+ // Ldexp
+ {
+ var a = (string)db.ScriptEvaluate("return tostring(math.ldexp(0, 0))");
+ ClassicAssert.AreEqual("0", a);
+
+ var b = (string)db.ScriptEvaluate("return tostring(math.ldexp(1, 1))");
+ ClassicAssert.AreEqual("2", b);
+
+ var c = (string)db.ScriptEvaluate("return tostring(math.ldexp(0, 1))");
+ ClassicAssert.AreEqual("0", c);
+
+ var d = (string)db.ScriptEvaluate("return tostring(math.ldexp(1, 0))");
+ ClassicAssert.AreEqual("1", d);
+
+ var e = (string)db.ScriptEvaluate($"return tostring(math.ldexp({double.MaxValue}, 0))");
+ ClassicAssert.AreEqual("1.7976931348623e+308", e);
+
+ var f = (string)db.ScriptEvaluate($"return tostring(math.ldexp({double.MaxValue}, 1))");
+ ClassicAssert.AreEqual("inf", f);
+
+ var g = (string)db.ScriptEvaluate($"return tostring(math.ldexp({double.MinValue}, 0))");
+ ClassicAssert.AreEqual("-1.7976931348623e+308", g);
+
+ var h = (string)db.ScriptEvaluate($"return tostring(math.ldexp({double.MinValue}, 1))");
+ ClassicAssert.AreEqual("-inf", h);
+
+ var i = (string)db.ScriptEvaluate($"return tostring(math.ldexp(1.234, 1.234))");
+ ClassicAssert.AreEqual("2.468", i);
+
+ var j = (string)db.ScriptEvaluate($"return tostring(math.ldexp(-5.6798, 9.0123))");
+ ClassicAssert.AreEqual("-2908.0576", j);
+ }
+ }
+
+ [Test]
+ public void Maxn()
+ {
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ var empty = (int)db.ScriptEvaluate("return table.maxn({})");
+ ClassicAssert.AreEqual(0, empty);
+
+ var single = (int)db.ScriptEvaluate("return table.maxn({4})");
+ ClassicAssert.AreEqual(1, single);
+
+ var multiple = (int)db.ScriptEvaluate("return table.maxn({-1, 1, 2, 5})");
+ ClassicAssert.AreEqual(4, multiple);
+
+ var keyed = (int)db.ScriptEvaluate("return table.maxn({foo='bar',fizz='buzz',hello='world'})");
+ ClassicAssert.AreEqual(0, keyed);
+
+ var mixed = (int)db.ScriptEvaluate("return table.maxn({-1, 1, foo='bar', 3})");
+ ClassicAssert.AreEqual(3, mixed);
+
+ var allNegative = (int)db.ScriptEvaluate("local x = {}; x[-1] = 1; x[-2]=2; return table.maxn(x)");
+ ClassicAssert.AreEqual(0, allNegative);
+ }
+
+ [Test]
+ public void LoadString()
+ {
+ using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+ var db = redis.GetDatabase();
+
+ var basic = (int)db.ScriptEvaluate("local x = loadstring('return 123'); return x()");
+ ClassicAssert.AreEqual(123, basic);
+
+ // TODO: Once refactored to avoid longjmp issues, restore on Linux
+ if (CanTestLuaErrors)
+ {
+ var rejectNullExc = ClassicAssert.Throws(() => db.ScriptEvaluate("local x = loadstring('return \"\\0\"'); return x()"));
+ ClassicAssert.True(rejectNullExc.Message.Contains("bad argument to loadstring, interior null byte"));
+ }
+ }
+
[Test]
[Ignore("Long running, disabled by default")]
public void StressTimeouts()
{
- // Psuedo-repeatably random
+ // Pseudo-repeatably random
const int SEED = 2025_01_30_00;
const int DurationMS = 60 * 60 * 1_000;
diff --git a/test/Garnet.test/TestUtils.cs b/test/Garnet.test/TestUtils.cs
index d2a392cae98..09c004723a2 100644
--- a/test/Garnet.test/TestUtils.cs
+++ b/test/Garnet.test/TestUtils.cs
@@ -228,6 +228,7 @@ public static GarnetServer CreateGarnetServer(
string luaMemoryLimit = "",
TimeSpan? luaTimeout = null,
LuaLoggingMode luaLoggingMode = LuaLoggingMode.Enable,
+ IEnumerable luaAllowedFunctions = null,
string unixSocketPath = null,
UnixFileMode unixSocketPermission = default,
int slowLogThreshold = 0,
@@ -312,7 +313,7 @@ public static GarnetServer CreateGarnetServer(
EnableReadCache = enableReadCache,
EnableObjectStoreReadCache = enableObjectStoreReadCache,
ReplicationOffsetMaxLag = asyncReplay ? -1 : 0,
- LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, luaTimeout ?? Timeout.InfiniteTimeSpan, luaLoggingMode, logger) : null,
+ LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, luaTimeout ?? Timeout.InfiniteTimeSpan, luaLoggingMode, luaAllowedFunctions ?? [], logger) : null,
UnixSocketPath = unixSocketPath,
UnixSocketPermission = unixSocketPermission,
SlowLogThreshold = slowLogThreshold
@@ -528,6 +529,7 @@ public static GarnetServerOptions GetGarnetServerOptions(
string luaMemoryLimit = "",
TimeSpan? luaTimeout = null,
LuaLoggingMode luaLoggingMode = LuaLoggingMode.Enable,
+ IEnumerable luaAllowedFunctions = null,
string unixSocketPath = null)
{
if (useAzureStorage)
@@ -631,7 +633,7 @@ public static GarnetServerOptions GetGarnetServerOptions(
ClusterPassword = authPassword,
EnableLua = enableLua,
ReplicationOffsetMaxLag = asyncReplay ? -1 : 0,
- LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, luaTimeout ?? Timeout.InfiniteTimeSpan, luaLoggingMode, logger) : null,
+ LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, luaTimeout ?? Timeout.InfiniteTimeSpan, luaLoggingMode, luaAllowedFunctions ?? [], logger) : null,
UnixSocketPath = unixSocketPath,
ReplicaDisklessSync = enableDisklessSync,
ReplicaDisklessSyncDelay = 1