diff --git a/benchmark/BDN.benchmark/Lua/LuaParams.cs b/benchmark/BDN.benchmark/Lua/LuaParams.cs index bd9058f2680..40497074c19 100644 --- a/benchmark/BDN.benchmark/Lua/LuaParams.cs +++ b/benchmark/BDN.benchmark/Lua/LuaParams.cs @@ -29,7 +29,7 @@ public LuaParams(LuaMemoryManagementMode mode, bool memoryLimit, TimeSpan? timeo /// Get the equivalent . /// public LuaOptions CreateOptions() - => new(Mode, MemoryLimit ? "2m" : "", Timeout ?? System.Threading.Timeout.InfiniteTimeSpan, LuaLoggingMode.Enable); + => new(Mode, MemoryLimit ? "2m" : "", Timeout ?? System.Threading.Timeout.InfiniteTimeSpan, LuaLoggingMode.Enable, []); /// /// String representation diff --git a/benchmark/BDN.benchmark/Operations/OperationsBase.cs b/benchmark/BDN.benchmark/Operations/OperationsBase.cs index f9858260b9b..f57caaa032b 100644 --- a/benchmark/BDN.benchmark/Operations/OperationsBase.cs +++ b/benchmark/BDN.benchmark/Operations/OperationsBase.cs @@ -56,7 +56,7 @@ public virtual void GlobalSetup() QuietMode = true, EnableLua = true, DisablePubSub = true, - LuaOptions = new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Enable), + LuaOptions = new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Enable, []), }; if (Params.useAof) diff --git a/libs/host/Configuration/Options.cs b/libs/host/Configuration/Options.cs index ba85044b825..f002f31f30a 100644 --- a/libs/host/Configuration/Options.cs +++ b/libs/host/Configuration/Options.cs @@ -560,6 +560,27 @@ internal sealed class Options [Option("lua-logging-mode", Required = false, HelpText = "Behavior of redis.log(...) when called from Lua scripts. Defaults to Enable.")] public LuaLoggingMode LuaLoggingMode { get; set; } + // Parsing is a tad tricky here as JSON wants to set to empty at certain points + // + // A bespoke union-on-set gets the desired semantics. + private readonly HashSet luaAllowedFunctions = []; + + [OptionValidation] + [Option("lua-allowed-functions", Separator = ',', Required = false, HelpText = "If set, restricts the functions available in Lua scripts to given list.")] + public IEnumerable LuaAllowedFunctions + { + get => luaAllowedFunctions; + set + { + if (value == null) + { + return; + } + + luaAllowedFunctions.UnionWith(value); + } + } + [FilePathValidation(false, true, false)] [Option("unixsocket", Required = false, HelpText = "Unix socket address path to bind server to")] public string UnixSocketPath { get; set; } @@ -811,7 +832,7 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null) LoadModuleCS = LoadModuleCS, FailOnRecoveryError = FailOnRecoveryError.GetValueOrDefault(), SkipRDBRestoreChecksumValidation = SkipRDBRestoreChecksumValidation.GetValueOrDefault(), - LuaOptions = EnableLua.GetValueOrDefault() ? new LuaOptions(LuaMemoryManagementMode, LuaScriptMemoryLimit, LuaScriptTimeoutMs == 0 ? Timeout.InfiniteTimeSpan : TimeSpan.FromMilliseconds(LuaScriptTimeoutMs), LuaLoggingMode, logger) : null, + LuaOptions = EnableLua.GetValueOrDefault() ? new LuaOptions(LuaMemoryManagementMode, LuaScriptMemoryLimit, LuaScriptTimeoutMs == 0 ? Timeout.InfiniteTimeSpan : TimeSpan.FromMilliseconds(LuaScriptTimeoutMs), LuaLoggingMode, LuaAllowedFunctions, logger) : null, UnixSocketPath = UnixSocketPath, UnixSocketPermission = unixSocketPermissions }; diff --git a/libs/host/defaults.conf b/libs/host/defaults.conf index 4741f65e63e..cd1e1f4c95a 100644 --- a/libs/host/defaults.conf +++ b/libs/host/defaults.conf @@ -380,6 +380,9 @@ /* Allow redis.log(...) to write to the Garnet logs */ "LuaLoggingMode": "Enable", + /* Allow all built in and redis.* functions by default */ + "LuaAllowedFunctions": null, + /* Unix socket address path to bind the server to */ "UnixSocketPath": null, diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index 1a892820b9b..3e64a42a6f9 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -114,6 +114,27 @@ public ArgSlice CreateArgSlice(string str) return retVal; } + public ReadOnlySpan UTF8EncodeString(string str) + { + // We'll always need AT LEAST this many bytes + ExpandScratchBufferIfNeeded(str.Length); + + var space = FullBuffer()[scratchBufferOffset..]; + + // Attempt to fit in the existing buffer first + if (!Encoding.UTF8.TryGetBytes(str, space, out var written)) + { + // If that fails, figure out exactly how much space we need + var neededBytes = Encoding.UTF8.GetByteCount(str); + ExpandScratchBufferIfNeeded(neededBytes); + + space = FullBuffer()[scratchBufferOffset..]; + written = Encoding.UTF8.GetBytes(str, space); + } + + return space[..written]; + } + /// /// Create an ArgSlice that includes a header of specified size, followed by RESP Bulk-String formatted versions of the specified ArgSlice values (arg1 and arg2) /// diff --git a/libs/server/Lua/LuaOptions.cs b/libs/server/Lua/LuaOptions.cs index 24e99e51159..f78e77b65c8 100644 --- a/libs/server/Lua/LuaOptions.cs +++ b/libs/server/Lua/LuaOptions.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Collections.Generic; using System.Runtime.InteropServices; using Microsoft.Extensions.Logging; @@ -18,6 +19,7 @@ public sealed class LuaOptions public string MemoryLimit = ""; public TimeSpan Timeout = System.Threading.Timeout.InfiniteTimeSpan; public LuaLoggingMode LogMode = LuaLoggingMode.Silent; + public HashSet AllowedFunctions = []; /// /// Construct options with default options. @@ -30,12 +32,13 @@ public LuaOptions(ILogger logger = null) /// /// Construct options with specific settings. /// - public LuaOptions(LuaMemoryManagementMode memoryMode, string memoryLimit, TimeSpan timeout, LuaLoggingMode logMode, ILogger logger = null) : this(logger) + public LuaOptions(LuaMemoryManagementMode memoryMode, string memoryLimit, TimeSpan timeout, LuaLoggingMode logMode, IEnumerable allowedFunctions, ILogger logger = null) : this(logger) { MemoryManagementMode = memoryMode; MemoryLimit = memoryLimit; Timeout = timeout; LogMode = logMode; + AllowedFunctions = new HashSet(allowedFunctions, StringComparer.OrdinalIgnoreCase); } /// diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 95375225759..a5d59ca47f9 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -3,12 +3,18 @@ using System; using System.Buffers; +using System.Buffers.Binary; +using System.Collections.Generic; +using System.Data; using System.Diagnostics; using System.Globalization; using System.Linq; +using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; using Garnet.common; using KeraLua; using Microsoft.Extensions.Logging; @@ -235,44 +241,197 @@ private int ConstantStringToRegistry(ref LuaStateWrapper state, ReadOnlySpan + /// Simple cache of allowed functions to loader block. + /// + private sealed record LoaderBlockCache(HashSet AllowedFunctions, ReadOnlyMemory LoaderBlockBytes); + + private const string LoaderBlock = @" +-- globals to fill in on each invocation KEYS = {} ARGV = {} + +-- disable for sandboxing purposes +import = function () end + +-- cutdown os for sandboxing purposes +local osClockRef = os.clock +os = { + clock = osClockRef +} + +-- define cjson for (optional) inclusion into sandbox_env +local cjson = { + encode = garnet_cjson_encode; + decode = garnet_cjson_decode; +} + +-- define bit for (optional) inclusion into sandbox_env +local bit = { + tobit = garnet_bit_tobit; + tohex = garnet_bit_tohex; + bnot = function(...) return garnet_bitop(0, ...); end; + bor = function(...) return garnet_bitop(1, ...); end; + band = function(...) return garnet_bitop(2, ...); end; + bxor = function(...) return garnet_bitop(3, ...); end; + lshift = function(...) return garnet_bitop(4, ...); end; + rshift = function(...) return garnet_bitop(5, ...); end; + arshift = function(...) return garnet_bitop(6, ...); end; + rol = function(...) return garnet_bitop(7, ...); end; + ror = function(...) return garnet_bitop(8, ...); end; + bswap = garnet_bit_bswap; +} + +-- define cmsgpack for (optional) inclusion into sandbox_env +local cmsgpack = { + pack = garnet_cmsgpack_pack; + unpack = garnet_cmsgpack_unpack; +} + +-- define struct for (optional) inclusion into sandbox_env +local struct = { + pack = string.pack; + unpack = string.unpack; + size = string.packsize; +} + +-- define redis for (optional, but almost always) inclusion into sandbox_env +local garnetCallRef = garnet_call +local pCallRef = pcall +local redis = { + status_reply = function(text) + return text + end, + + error_reply = function(text) + return { err = 'ERR ' .. text } + end, + + call = garnetCallRef, + + pcall = function(...) + local success, errOrRes = pCallRef(garnetCallRef, ...) + if success then + return errOrRes + end + + return { err = errOrRes } + end, + + sha1hex = garnet_sha1hex, + + LOG_DEBUG = 0, + LOG_VERBOSE = 1, + LOG_NOTICE = 2, + LOG_WARNING = 3, + + log = garnet_log, + + REPL_ALL = 3, + REPL_AOF = 1, + REPL_REPLICA = 2, + REPL_SLAVE = 2, + REPL_NONE = 0, + + set_repl = function(...) + -- this is a giant footgun, straight up not implementing it + error('ERR redis.set_repl is not supported in Garnet', 0) + end, + + replicate_commands = function(...) + return true + end, + + breakpoint = function(...) + -- this is giant and weird, not implementing + error('ERR redis.breakpoint is not supported in Garnet', 0) + end, + + debug = function(...) + -- this is giant and weird, not implementing + error('ERR redis.debug is not supported in Garnet', 0) + end, + + acl_check_cmd = garnet_acl_check_cmd, + setresp = garnet_setresp, + + REDIS_VERSION = garnet_REDIS_VERSION, + REDIS_VERSION_NUM = garnet_REDIS_VERSION_NUM +} + +-- unpack moved after Lua 5.1, this provides Redis compat +local unpack = table.unpack + +-- added after Lua 5.1, removing to maintain Redis compat +string.pack = nil +string.unpack = nil +string.packsize = nil +math.maxinteger = nil +math.type = nil +math.mininteger = nil +math.tointeger = nil +math.ult = nil +table.pack = nil +table.unpack = nil +table.move = nil + +-- in Lua 5.1 but not 5.4, so implemented on the .NET side +local loadstring = garnet_loadstring +math.atan2 = garnet_atan2 +math.cosh = garnet_cosh +math.frexp = garnet_frexp +math.ldexp = garnet_ldexp +math.log10 = garnet_log10 +math.pow = garnet_pow +math.sinh = garnet_sinh +math.tanh = garnet_tanh +table.maxn = garnet_maxn + +local collectgarbageRef = collectgarbage +local setMetatableRef = setmetatable +local rawsetRef = rawset + +-- prevent modification to metatables for readonly tables +-- Redis accomplishes this by patching Lua, we'd rather ship +-- vanilla Lua and do it in code +local setmetatable = function(table, metatable) + if table and table.__readonly then + error('Attempt to modify a readonly table', 0) + end + + return setMetatableRef(table, metatable) +end + +-- prevent bypassing metatables to update readonly tables +-- as above, Redis prevents this with a patch to Lua +local rawset = function(table, key, value) + if table and table.__readonly then + error('Attempt to modify a readonly table', 0) + end + + return rawsetRef(table, key, value) +end + +-- technically deprecated in 5.1, but available in Redis +-- this is only 'sort of' correct as 5.4 doesn't expose the same +-- gc primitives +local gcinfo = function() + return collectgarbageRef('count'), 0 +end + +-- global object used for the sandbox environment +-- +-- replacements are performed before VM initialization +-- to allow configuring available functions sandbox_env = { _VERSION = _VERSION; - assert = assert; - collectgarbage = collectgarbage; - coroutine = coroutine; - error = error; - gcinfo = gcinfo; - -- explicitly not allowing getfenv - getmetatable = getmetatable; - ipairs = ipairs; - load = load; - loadstring = loadstring; - math = math; - next = next; - pairs = pairs; - pcall = pcall; - rawequal = rawequal; - rawget = rawget; - -- rawset is proxied to implement readonly tables - select = select; - -- explicitly not allowing setfenv - -- setmetatable is proxied to implement readonly tables - string = string; - table = table; - tonumber = tonumber; - tostring = tostring; - type = type; - unpack = table.unpack; - xpcall = xpcall; - KEYS = KEYS; ARGV = ARGV; + +!!SANDBOX_ENV REPLACEMENT TARGET!! } + -- no reference to outermost set of globals (_G) should survive sandboxing sandbox_env._G = sandbox_env -- lock down a table, recursively doing the same to all table members @@ -298,7 +457,7 @@ function recursively_readonly_table(table) end end - setmetatable(table, readonly_metatable) + setMetatableRef(table, readonly_metatable) end -- do resets in the Lua side to minimize pinvokes function reset_keys_and_argv(fromKey, fromArgv) @@ -316,98 +475,6 @@ function reset_keys_and_argv(fromKey, fromArgv) end -- responsible for sandboxing user provided code function load_sandboxed(source) - -- move into a local to avoid global lookup - local garnetCallRef = garnet_call - local pCallRef = pcall - local sha1hexRef = garnet_sha1hex - local logRef = garnet_log - local aclCheckCmdRef = garnet_acl_check_cmd - local setRespRef = garnet_setresp - local setMetatableRef = setmetatable - local rawsetRaw = rawset - - sandbox_env.redis = { - status_reply = function(text) - return text - end, - - error_reply = function(text) - return { err = 'ERR ' .. text } - end, - - call = garnetCallRef, - - pcall = function(...) - local success, errOrRes = pCallRef(garnetCallRef, ...) - if success then - return errOrRes - end - - return { err = errOrRes } - end, - - sha1hex = sha1hexRef, - - LOG_DEBUG = 0, - LOG_VERBOSE = 1, - LOG_NOTICE = 2, - LOG_WARNING = 3, - - log = logRef, - - REPL_ALL = 3, - REPL_AOF = 1, - REPL_REPLICA = 2, - REPL_SLAVE = 2, - REPL_NONE = 0, - - set_repl = function(...) - -- this is a giant footgun, straight up not implementing it - error('ERR redis.set_repl is not supported in Garnet', 0) - end, - - replicate_commands = function(...) - return true - end, - - breakpoint = function(...) - -- this is giant and weird, not implementing - error('ERR redis.breakpoint is not supported in Garnet', 0) - end, - - debug = function(...) - -- this is giant and weird, not implementing - error('ERR redis.debug is not supported in Garnet', 0) - end, - - acl_check_cmd = aclCheckCmdRef, - setresp = setRespRef, - - REDIS_VERSION = garnet_REDIS_VERSION, - REDIS_VERSION_NUM = garnet_REDIS_VERSION_NUM - } - - -- prevent modification to metatables for readonly tables - -- Redis accomplishes this by patching Lua, we'd rather ship - -- vanilla Lua and do it in code - sandbox_env.setmetatable = function(table, metatable) - if table and table.__readonly then - error('Attempt to modify a readonly table', 0) - end - - return setMetatableRef(table, metatable) - end - - -- prevent bypassing metatables to update readonly tables - -- as above, Redis prevents this with a patch to Lua - sandbox_env.rawset = function(table, key, value) - if table and table.__readonly then - error('Attempt to modify a readonly table', 0) - end - - return rawsetRef(table, key, value) - end - recursively_readonly_table(sandbox_env) local rawFunc, err = load(source, nil, nil, sandbox_env) @@ -415,8 +482,53 @@ function load_sandboxed(source) return err, rawFunc end "; - - private static readonly ReadOnlyMemory LoaderBlockBytes = Encoding.UTF8.GetBytes(LoaderBlock); + private static readonly HashSet DefaultAllowedFunctions = [ + // Built ins + "assert", + "collectgarbage", + "coroutine", + "error", + "gcinfo", + // Intentionally not supporting getfenv, as it's too weird to backport to Lua 5.4 + "getmetatable", + "ipairs", + "load", + "loadstring", + "math", + "next", + "pairs", + "pcall", + "rawequal", + "rawget", + // Note rawset is proxied to implement readonly tables + "rawset", + "select", + // Intentionally not supporting setfenv, as it's too weird to backport to Lua 5.4 + // Note setmetatable is proxied to implement readonly tables + "setmetatable", + "string", + "table", + "tonumber", + "tostring", + "type", + // Note unpack is actually table.unpack, and defined in the loader block + "unpack", + "xpcall", + + // Runtime libs + "bit", + "cjson", + "cmsgpack", + // Note os only contains clock due to definition in the loader block + "os", + // Note struct is actually implemented by Lua 5.4's string.pack/unpack/size + "struct", + + // Interface force communicating back with Garnet + "redis", + ]; + + private static LoaderBlockCache CachedLoaderBlock; private static (int Start, ulong[] ByteMask) NoScriptDetails = InitializeNoScriptDetails(); @@ -432,6 +544,7 @@ function load_sandboxed(source) readonly ConstantStringRegistryIndexes constStrs; readonly LuaLoggingMode logMode; + readonly HashSet allowedFunctions; readonly ReadOnlyMemory source; readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly RespServerSession respServerSession; @@ -467,6 +580,7 @@ public unsafe LuaRunner( LuaMemoryManagementMode memMode, int? memLimitBytes, LuaLoggingMode logMode, + HashSet allowedFunctions, ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, @@ -480,6 +594,7 @@ public unsafe LuaRunner( this.respServerSession = respServerSession; this.scratchBufferNetworkSender = scratchBufferNetworkSender; this.logMode = logMode; + this.allowedFunctions = allowedFunctions; this.logger = logger; scratchBufferManager = respServerSession?.scratchBufferManager ?? new(); @@ -518,10 +633,56 @@ public unsafe LuaRunner( garnetCall = &LuaRunnerTrampolines.GarnetCallNoSession; } - var loadRes = state.LoadBuffer(LoaderBlockBytes.Span); + // Lua 5.4 does not provide these functions, but 5.1 does - so implement tehm + state.Register("garnet_atan2\0"u8, &LuaRunnerTrampolines.Atan2); + state.Register("garnet_cosh\0"u8, &LuaRunnerTrampolines.Cosh); + state.Register("garnet_frexp\0"u8, &LuaRunnerTrampolines.Frexp); + state.Register("garnet_ldexp\0"u8, &LuaRunnerTrampolines.Ldexp); + state.Register("garnet_log10\0"u8, &LuaRunnerTrampolines.Log10); + state.Register("garnet_pow\0"u8, &LuaRunnerTrampolines.Pow); + state.Register("garnet_sinh\0"u8, &LuaRunnerTrampolines.Sinh); + state.Register("garnet_tanh\0"u8, &LuaRunnerTrampolines.Tanh); + state.Register("garnet_maxn\0"u8, &LuaRunnerTrampolines.Maxn); + state.Register("garnet_loadstring\0"u8, &LuaRunnerTrampolines.LoadString); + + // Things provided as Lua libraries, which we actually implement in .NET + state.Register("garnet_cjson_encode\0"u8, &LuaRunnerTrampolines.CJsonEncode); + state.Register("garnet_cjson_decode\0"u8, &LuaRunnerTrampolines.CJsonDecode); + state.Register("garnet_bit_tobit\0"u8, &LuaRunnerTrampolines.BitToBit); + state.Register("garnet_bit_tohex\0"u8, &LuaRunnerTrampolines.BitToHex); + // garnet_bitop implements bnot, bor, band, xor, etc. but isn't directly exposed + state.Register("garnet_bitop\0"u8, &LuaRunnerTrampolines.Bitop); + state.Register("garnet_bit_bswap\0"u8, &LuaRunnerTrampolines.BitBswap); + state.Register("garnet_cmsgpack_pack\0"u8, &LuaRunnerTrampolines.CMsgPackPack); + state.Register("garnet_cmsgpack_unpack\0"u8, &LuaRunnerTrampolines.CMsgPackUnpack); + state.Register("garnet_call\0"u8, garnetCall); + state.Register("garnet_sha1hex\0"u8, &LuaRunnerTrampolines.SHA1Hex); + state.Register("garnet_log\0"u8, &LuaRunnerTrampolines.Log); + state.Register("garnet_acl_check_cmd\0"u8, &LuaRunnerTrampolines.AclCheckCommand); + state.Register("garnet_setresp\0"u8, &LuaRunnerTrampolines.SetResp); + + var redisVersionBytes = Encoding.UTF8.GetBytes(redisVersion); + state.PushBuffer(redisVersionBytes); + state.SetGlobal("garnet_REDIS_VERSION\0"u8); + + var redisVersionParsed = Version.Parse(redisVersion); + var redisVersionNum = + ((byte)redisVersionParsed.Major << 16) | + ((byte)redisVersionParsed.Minor << 8) | + ((byte)redisVersionParsed.Build << 0); + state.PushInteger(redisVersionNum); + state.SetGlobal("garnet_REDIS_VERSION_NUM\0"u8); + + var loadRes = state.LoadBuffer(PrepareLoaderBlockBytes(allowedFunctions).Span); if (loadRes != LuaStatus.OK) { - throw new GarnetException("Couldn't load loader into Lua"); + if (state.StackTop == 1 && state.CheckBuffer(1, out var buff)) + { + var innerError = Encoding.UTF8.GetString(buff); + throw new GarnetException($"Could initialize Lua VM: {innerError}"); + } + + throw new GarnetException("Could initialize Lua VM"); } var sandboxRes = state.PCall(0, -1); @@ -549,25 +710,6 @@ public unsafe LuaRunner( throw new GarnetException($"Could not initialize Lua sandbox state: {errMsg}"); } - // Register functions provided by .NET in global namespace - state.Register("garnet_call\0"u8, garnetCall); - state.Register("garnet_sha1hex\0"u8, &LuaRunnerTrampolines.SHA1Hex); - state.Register("garnet_log\0"u8, &LuaRunnerTrampolines.Log); - state.Register("garnet_acl_check_cmd\0"u8, &LuaRunnerTrampolines.AclCheckCommand); - state.Register("garnet_setresp\0"u8, &LuaRunnerTrampolines.SetResp); - - var redisVersionBytes = Encoding.UTF8.GetBytes(redisVersion); - state.PushBuffer(redisVersionBytes); - state.SetGlobal("garnet_REDIS_VERSION\0"u8); - - var redisVersionParsed = Version.Parse(redisVersion); - var redisVersionNum = - ((byte)redisVersionParsed.Major << 16) | - ((byte)redisVersionParsed.Minor << 8) | - ((byte)redisVersionParsed.Build << 0); - state.PushInteger(redisVersionNum); - state.SetGlobal("garnet_REDIS_VERSION_NUM\0"u8); - state.GetGlobal(LuaType.Table, "KEYS\0"u8); keysTableRegistryIndex = state.Ref(); @@ -590,7 +732,7 @@ public unsafe LuaRunner( /// Creates a new runner with the source of the script /// public LuaRunner(LuaOptions options, string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, string redisVersion = "0.0.0.0", ILogger logger = null) - : this(options.MemoryManagementMode, options.GetMemoryLimitBytes(), options.LogMode, Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, redisVersion, logger) + : this(options.MemoryManagementMode, options.GetMemoryLimitBytes(), options.LogMode, options.AllowedFunctions, Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, redisVersion, logger) { } @@ -662,7 +804,7 @@ public unsafe bool CompileForSession(RespServerSession session) return false; } - return true; + return functionRegistryIndex != -1; } finally { @@ -742,122 +884,1960 @@ private unsafe int CompileCommon(ref TResponse resp) public void Dispose() => state.Dispose(); - /// - /// Entry point for redis.sha1hex method from a Lua script. - /// - public int SHA1Hex(nint luaStatePtr) - { - state.CallFromLuaEntered(luaStatePtr); + /// + /// Entry point for redis.sha1hex method from a Lua script. + /// + public int SHA1Hex(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var argCount = state.StackTop; + if (argCount != 1) + { + state.PushConstantString(constStrs.ErrWrongNumberOfArgs); + return state.RaiseErrorFromStack(); + } + + if (!state.CheckBuffer(1, out var bytes)) + { + bytes = default; + } + + Span hashBytes = stackalloc byte[SessionScriptCache.SHA1Len / 2]; + Span hexRes = stackalloc byte[SessionScriptCache.SHA1Len]; + + SessionScriptCache.GetScriptDigest(bytes, hashBytes, hexRes); + + state.PushBuffer(hexRes); + return 1; + } + + /// + /// Entry point for redis.log(...) from a Lua script. + /// + public int Log(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var argCount = state.StackTop; + if (argCount < 2) + { + return LuaStaticError(constStrs.ErrRedisLogRequired); + } + + if (state.Type(1) != LuaType.Number) + { + return LuaStaticError(constStrs.ErrFirstArgMustBeNumber); + } + + var rawLevel = state.CheckNumber(1); + if (rawLevel is not (0 or 1 or 2 or 3)) + { + return LuaStaticError(constStrs.ErrInvalidDebugLevel); + } + + if (logMode == LuaLoggingMode.Disable) + { + return LuaStaticError(constStrs.ErrLoggingDisabled); + } + + // When shipped as a service, allowing arbitrary writes to logs is dangerous + // so we support disabling it (while not breaking existing scripts) + if (logMode == LuaLoggingMode.Silent) + { + return 0; + } + + // Even if enabled, if no logger was provided we can just bail + if (logger == null) + { + return 0; + } + + // Construct and log the equivalent message + string logMessage; + if (argCount == 2) + { + if (state.CheckBuffer(2, out var buff)) + { + logMessage = Encoding.UTF8.GetString(buff); + } + else + { + logMessage = ""; + } + } + else + { + var sb = new StringBuilder(); + + for (var argIx = 2; argIx <= argCount; argIx++) + { + if (state.CheckBuffer(argIx, out var buff)) + { + if (sb.Length != 0) + { + _ = sb.Append(' '); + } + + _ = sb.Append(Encoding.UTF8.GetString(buff)); + } + } + + logMessage = sb.ToString(); + } + + var logLevel = + rawLevel switch + { + 0 => LogLevel.Debug, + 1 => LogLevel.Information, + 2 => LogLevel.Warning, + // We validated this above, so really it's just 3 but the switch needs to be exhaustive + _ => LogLevel.Error, + }; + + logger.Log(logLevel, "redis.log: {message}", logMessage.ToString()); + + return 0; + } + + /// + /// Entry point for math.atan2 from a Lua script. + /// + public int Atan2(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 2 || state.Type(1) != LuaType.Number || state.Type(2) != LuaType.Number) + { + return state.RaiseError("bad argument to atan2"); + } + + var x = state.CheckNumber(1); + var y = state.CheckNumber(2); + + var res = Math.Atan2(x, y); + state.Pop(2); + state.PushNumber(res); + return 1; + } + + /// + /// Entry point for math.cosh from a Lua script. + /// + public int Cosh(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bad argument to cosh"); + } + + var value = state.CheckNumber(1); + + var res = Math.Cosh(value); + state.Pop(1); + state.PushNumber(res); + return 1; + } + + /// + /// Entry point for math.frexp from a Lua script. + /// + public int Frexp(nint luaStatePtr) + { + // Based on: https://github.com/MachineCognitis/C.math.NET/ (MIT License) + + const long DBL_EXP_MASK = 0x7FF0000000000000L; + const int DBL_MANT_BITS = 52; + const long DBL_SGN_MASK = -1 - 0x7FFFFFFFFFFFFFFFL; + const long DBL_MANT_MASK = 0x000FFFFFFFFFFFFFL; + const long DBL_EXP_CLR_MASK = DBL_SGN_MASK | DBL_MANT_MASK; + + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bad argument to frexp"); + } + + var number = state.CheckNumber(1); + + var bits = BitConverter.DoubleToInt64Bits(number); + var exp = (int)((bits & DBL_EXP_MASK) >> DBL_MANT_BITS); + var exponent = 0; + + if (exp == 0x7FF || number == 0D) + { + number += number; + } + else + { + // Not zero and finite. + exponent = exp - 1022; + if (exp == 0) + { + // Subnormal, scale number so that it is in [1, 2). + number *= BitConverter.Int64BitsToDouble(0x4350000000000000L); // 2^54 + bits = BitConverter.DoubleToInt64Bits(number); + exp = (int)((bits & DBL_EXP_MASK) >> DBL_MANT_BITS); + exponent = exp - 1022 - 54; + } + // Set exponent to -1 so that number is in [0.5, 1). + number = BitConverter.Int64BitsToDouble((bits & DBL_EXP_CLR_MASK) | 0x3FE0000000000000L); + } + + state.ForceMinimumStackCapacity(2); + state.Pop(1); + + var numberAsFloat = (float)number; + + if ((long)numberAsFloat == numberAsFloat) + { + state.PushInteger((long)numberAsFloat); + } + else + { + state.PushNumber(number); + } + + state.PushInteger(exponent); + + return 2; + } + + /// + /// Entry point for math.ldexp from a Lua script. + /// + public int Ldexp(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 2 || state.Type(1) != LuaType.Number || state.Type(2) != LuaType.Number) + { + return state.RaiseError("bad argument to ldexp"); + } + + var m = state.CheckNumber(1); + var e = (int)state.CheckNumber(2); + + var res = m * Math.Pow(2, e); + + state.Pop(2); + + if ((long)res == res) + { + state.PushInteger((long)res); + } + else + { + state.PushNumber(res); + } + + return 1; + } + + /// + /// Entry point for math.log10 from a Lua script. + /// + public int Log10(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bad argument to log10"); + } + + var val = state.CheckNumber(1); + + var res = Math.Log10(val); + + state.Pop(1); + state.PushNumber(res); + return 1; + } + + /// + /// Entry point for math.pow from a Lua script. + /// + public int Pow(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 2 || state.Type(1) != LuaType.Number || state.Type(2) != LuaType.Number) + { + return state.RaiseError("bad argument to pow"); + } + + var x = state.CheckNumber(1); + var y = state.CheckNumber(2); + + var res = Math.Pow(x, y); + + state.Pop(2); + state.PushNumber(res); + return 1; + } + + /// + /// Entry point for math.sinh from a Lua script. + /// + public int Sinh(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bad argument to sinh"); + } + + var val = state.CheckNumber(1); + + var res = Math.Sinh(val); + + state.Pop(1); + state.PushNumber(res); + return 1; + } + + /// + /// Entry point for math.sinh from a Lua script. + /// + public int Tanh(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bad argument to tanh"); + } + + var val = state.CheckNumber(1); + + var res = Math.Tanh(val); + + state.Pop(1); + state.PushNumber(res); + return 1; + } + + /// + /// Entry point for table.maxn from a Lua script. + /// + public int Maxn(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1 || state.Type(1) != LuaType.Table) + { + return state.RaiseError("bad argument to maxn"); + } + + state.ForceMinimumStackCapacity(2); + + double res = 0; + + // Initial key value onto stack + state.PushNil(); + while (state.Next(1) != 0) + { + // Remove value + state.Pop(1); + + double keyVal; + if (state.Type(2) == LuaType.Number && (keyVal = state.CheckNumber(2)) > res) + { + res = keyVal; + } + } + + // Remove table, and push largest number + state.Pop(1); + state.PushNumber(res); + return 1; + } + + /// + /// Entry point for loadstring from a Lua script. + /// + public int LoadString(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if ( + (luaArgCount == 1 && state.Type(1) != LuaType.String) || + (luaArgCount == 2 && (state.Type(1) != LuaType.String || state.Type(2) != LuaType.String)) || + (luaArgCount > 2) + ) + { + return state.RaiseError("bad argument to loadstring"); + } + + // Ignore chunk name + if (luaArgCount == 2) + { + state.Pop(1); + } + + _ = state.CheckBuffer(1, out var buff); + if (buff.Contains((byte)0)) + { + return state.RaiseError("bad argument to loadstring, interior null byte"); + } + + state.ForceMinimumStackCapacity(1); + + var res = state.LoadString(buff); + if (res != LuaStatus.OK) + { + state.ClearStack(); + state.PushNil(); + state.PushBuffer("load_string encountered error"u8); + return 2; + } + + return 1; + } + + /// + /// Converts a Lua number (ie. a double) into the expected 32-bit integer for + /// bit operations. + /// + private static int LuaNumberToBitValue(double value) + { + var scaled = value + 6_755_399_441_055_744.0; + var asULong = BitConverter.DoubleToUInt64Bits(scaled); + var asUInt = (uint)asULong; + + return (int)asUInt; + } + + /// + /// Entry point for bit.tobit from a Lua script. + /// + public int BitToBit(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount < 1 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bad argument to tobit"); + } + + var rawValue = state.CheckNumber(1); + + // Make space on the stack + state.Pop(1); + + state.PushNumber(LuaNumberToBitValue(rawValue)); + + return 1; + } + + /// + /// Entry point for bit.tohex from a Lua script. + /// + public int BitToHex(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount == 0 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bad argument to tohex"); + } + + var numDigits = 8; + + if (luaArgCount == 2) + { + if (state.Type(2) != LuaType.Number) + { + return state.RaiseError("bad argument to tohex"); + } + + numDigits = (int)state.CheckNumber(2); + } + + var value = LuaNumberToBitValue(state.CheckNumber(1)); + + ReadOnlySpan hexBytes; + if (numDigits == int.MinValue) + { + numDigits = 8; + hexBytes = "0123456789ABCDEF"u8; + } + else if (numDigits < 0) + { + numDigits = -numDigits; + hexBytes = "0123456789ABCDEF"u8; + } + else + { + hexBytes = "0123456789abcdef"u8; + } + + if (numDigits > 8) + { + numDigits = 8; + } + + Span buff = stackalloc byte[numDigits]; + for (var i = buff.Length - 1; i >= 0; i--) + { + buff[i] = hexBytes[value & 0xF]; + value >>= 4; + } + + // Free up space on stack + state.Pop(luaArgCount); + + state.PushBuffer(buff); + return 1; + } + + /// + /// Entry point for bit.bswap from a Lua script. + /// + public int BitBswap(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bad argument to bswap"); + } + + var value = LuaNumberToBitValue(state.CheckNumber(1)); + + // Free up space on stack + state.Pop(1); + + var swapped = BinaryPrimitives.ReverseEndianness(value); + state.PushNumber(swapped); + return 1; + } + + /// + /// Entry point for garnet_bitop from a Lua script. + /// + /// Used to implement bit.bnot, bit.bor, bit.band, etc. + /// + public int Bitop(nint luaStatePtr) + { + const int BNot = 0; + const int BOr = 1; + const int BAnd = 2; + const int BXor = 3; + const int LShift = 4; + const int RShift = 5; + const int ARShift = 6; + const int Rol = 7; + const int Ror = 8; + + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount == 0 || state.Type(1) != LuaType.Number) + { + return state.RaiseError("bitop was not indicated, should never happen"); + } + + var bitop = (int)state.CheckNumber(1); + if (bitop < BNot || bitop > Ror) + { + return state.RaiseError($"invalid bitop {bitop} was indicated, should never happen"); + } + + // Handle bnot specially + if (bitop == BNot) + { + if (luaArgCount < 2 || state.Type(2) != LuaType.Number) + { + return state.RaiseError("bad argument to bnot"); + } + + var val = LuaNumberToBitValue(state.CheckNumber(2)); + var res = ~val; + state.Pop(2); + + state.PushNumber(res); + return 1; + } + + var binOpName = + bitop switch + { + BOr => "bor", + BAnd => "band", + BXor => "bxor", + LShift => "lshift", + RShift => "rshift", + ARShift => "arshift", + Rol => "rol", + _ => "ror", + }; + + if (luaArgCount < 2) + { + return state.RaiseError($"bad argument to {binOpName}"); + } + + if (bitop is BOr or BAnd or BXor) + { + var ret = + bitop switch + { + BOr => 0, + BXor => 0, + _ => -1, + }; + + for (var argIx = 2; argIx <= luaArgCount; argIx++) + { + if (state.Type(argIx) != LuaType.Number) + { + return state.RaiseError($"bad argument to {binOpName}"); + } + + var nextValue = LuaNumberToBitValue(state.CheckNumber(argIx)); + + ret = + bitop switch + { + BOr => ret | nextValue, + BXor => ret ^ nextValue, + _ => ret & nextValue, + }; + } + + state.Pop(luaArgCount); + state.PushNumber(ret); + + return 1; + } + + if (luaArgCount < 3 || state.Type(2) != LuaType.Number || state.Type(3) != LuaType.Number) + { + return state.RaiseError($"bad argument to {binOpName}"); + } + + var x = LuaNumberToBitValue(state.CheckNumber(2)); + var n = ((int)state.CheckNumber(3)) & 0b1111; + + var shiftRes = + bitop switch + { + LShift => x << n, + RShift => (int)((uint)x >> n), + ARShift => x >> n, + Rol => (int)BitOperations.RotateLeft((uint)x, n), + _ => (int)BitOperations.RotateRight((uint)x, n), + }; + + state.Pop(luaArgCount); + state.PushNumber(shiftRes); + return 1; + } + + /// + /// Entry point for cjson.encode from a Lua script. + /// + public int CJsonEncode(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1) + { + return state.RaiseError("bad argument to encode"); + } + + Encode(this, 0); + + // Encoding should leave nothing on the stack + state.ExpectLuaStackEmpty(); + + // Push the encoded string + var result = scratchBufferManager.ViewFullArgSlice().ReadOnlySpan; + state.PushBuffer(result); + + return 1; + + // Encode the unknown type on the top of the stack + static void Encode(LuaRunner self, int depth) + { + if (depth > 1000) + { + // Match Redis max decoding depth + _ = self.state.RaiseError("Cannot serialise, excessive nesting (1001)"); + } + + var argType = self.state.Type(self.state.StackTop); + + switch (argType) + { + case LuaType.Boolean: + EncodeBool(self); + break; + case LuaType.Nil: + EncodeNull(self); + break; + case LuaType.Number: + EncodeNumber(self); + break; + case LuaType.String: + EncodeString(self); + break; + case LuaType.Table: + EncodeTable(self, depth); + break; + case LuaType.Function: + case LuaType.LightUserData: + case LuaType.None: + case LuaType.Thread: + case LuaType.UserData: + default: + _ = self.state.RaiseError($"Cannot serialise {argType} to JSON"); + break; + } + } + + // Encode the boolean on the top of the stack and remove it + static void EncodeBool(LuaRunner self) + { + Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Boolean, "Expected boolean on top of stack"); + + var data = self.state.ToBoolean(self.state.StackTop) ? "true"u8 : "false"u8; + + var into = self.scratchBufferManager.ViewRemainingArgSlice(data.Length).Span; + data.CopyTo(into); + self.scratchBufferManager.MoveOffset(data.Length); + + self.state.Pop(1); + } + + // Encode the nil on the top of the stack and remove it + static void EncodeNull(LuaRunner self) + { + Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Nil, "Expected nil on top of stack"); + + var into = self.scratchBufferManager.ViewRemainingArgSlice(4).Span; + "null"u8.CopyTo(into); + self.scratchBufferManager.MoveOffset(4); + + self.state.Pop(1); + } + + // Encode the number on the top of the stack and remove it + static void EncodeNumber(LuaRunner self) + { + Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Number, "Expected number on top of stack"); + + var number = self.state.CheckNumber(self.state.StackTop); + + Span space = stackalloc byte[64]; + + if (!number.TryFormat(space, out var written, "G", CultureInfo.InvariantCulture)) + { + _ = self.state.RaiseError("Unable to format number"); + } + + var into = self.scratchBufferManager.ViewRemainingArgSlice(written).Span; + space[..written].CopyTo(into); + self.scratchBufferManager.MoveOffset(written); + + self.state.Pop(1); + } + + // Encode the string on the top of the stack and remove it + static void EncodeString(LuaRunner self) + { + Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.String, "Expected string on top of stack"); + + _ = self.state.CheckBuffer(self.state.StackTop, out var buff); + + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)'"'; + self.scratchBufferManager.MoveOffset(1); + + var escapeIx = buff.IndexOfAny((byte)'"', (byte)'\\'); + while (escapeIx != -1) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(escapeIx + 2).Span; + buff[..escapeIx].CopyTo(into); + + into[escapeIx] = (byte)'\\'; + + var toEscape = buff[escapeIx]; + if (toEscape == (byte)'"') + { + into[escapeIx + 1] = (byte)'"'; + } + else + { + into[escapeIx + 1] = (byte)'\\'; + } + + self.scratchBufferManager.MoveOffset(escapeIx + 2); + + buff = buff[(escapeIx + 1)..]; + escapeIx = buff.IndexOfAny((byte)'"', (byte)'\\'); + } + + var tailInto = self.scratchBufferManager.ViewRemainingArgSlice(buff.Length + 1).Span; + buff.CopyTo(tailInto); + tailInto[buff.Length] = (byte)'"'; + self.scratchBufferManager.MoveOffset(buff.Length + 1); + + self.state.Pop(1); + } + + // Encode the table on the top of the stack and remove it + static void EncodeTable(LuaRunner self, int depth) + { + Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Table, "Expected table on top of stack"); + + // Space for key & value + self.state.ForceMinimumStackCapacity(2); + + var tableIndex = self.state.StackTop; + + var isArray = false; + var arrayLength = 0; + + self.state.PushNil(); + while (self.state.Next(tableIndex) != 0) + { + // Pop value + self.state.Pop(1); + + double keyAsNumber; + if (self.state.Type(tableIndex + 1) == LuaType.Number && (keyAsNumber = self.state.CheckNumber(tableIndex + 1)) >= 1 && keyAsNumber == (int)keyAsNumber) + { + if (keyAsNumber > arrayLength) + { + // Need at least one integer key >= 1 to consider this an array + isArray = true; + arrayLength = (int)keyAsNumber; + } + } + else + { + // Non-integer key, or integer <= 0, so it's not an array + isArray = false; + + // Remove key + self.state.Pop(1); + + break; + } + } + + if (isArray) + { + EncodeArray(self, arrayLength, depth); + } + else + { + EncodeObject(self, depth); + } + } + + // Encode the table on the top of the stack as an array and remove it + static void EncodeArray(LuaRunner self, int length, int depth) + { + Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Table, "Expected table on top of stack"); + + // Space for value + self.state.ForceMinimumStackCapacity(1); + + var tableIndex = self.state.StackTop; + + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)'['; + self.scratchBufferManager.MoveOffset(1); + + for (var ix = 1; ix <= length; ix++) + { + if (ix != 1) + { + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)','; + self.scratchBufferManager.MoveOffset(1); + } + + _ = self.state.RawGetInteger(null, tableIndex, ix); + Encode(self, depth + 1); + } + + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)']'; + self.scratchBufferManager.MoveOffset(1); + + // Remove table + self.state.Pop(1); + } + + // Encode the table on the top of the stack as an object and remove it + static void EncodeObject(LuaRunner self, int depth) + { + Debug.Assert(self.state.Type(self.state.StackTop) == LuaType.Table, "Expected table on top of stack"); + + // Space for key and value and a copy of key + self.state.ForceMinimumStackCapacity(3); + + var tableIndex = self.state.StackTop; + + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)'{'; + self.scratchBufferManager.MoveOffset(1); + + var firstValue = true; + + self.state.PushNil(); + while (self.state.Next(tableIndex) != 0) + { + LuaType keyType; + if ((keyType = self.state.Type(tableIndex + 1)) is not (LuaType.String or LuaType.Number)) + { + // Ignore non-string-ify-abile keys + + // Remove value + self.state.Pop(1); + + continue; + } + + if (!firstValue) + { + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)','; + self.scratchBufferManager.MoveOffset(1); + } + + // Copy key to top of stack + self.state.PushValue(tableIndex + 1); + + // Force the _copy_ of the key to be a string + // if it is not already one. + // + // We don't modify the original key value, so we + // can continue using it with Next(...) + if (keyType == LuaType.Number) + { + _ = self.state.CheckBuffer(tableIndex + 3, out _); + } + + // Encode key + Encode(self, depth + 1); + + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)':'; + self.scratchBufferManager.MoveOffset(1); + + // Encode value + Encode(self, depth + 1); + + firstValue = false; + } + + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)'}'; + self.scratchBufferManager.MoveOffset(1); + + // Remove table + self.state.Pop(1); + } + } + + /// + /// Entry point for cjson.decode from a Lua script. + /// + public int CJsonDecode(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var luaArgCount = state.StackTop; + if (luaArgCount != 1) + { + return state.RaiseError("bad argument to decode"); + } + + var argType = state.Type(1); + if (argType == LuaType.Number) + { + // We'd coerce this to a string, and then decode it, so just pass it back as is + // + // There are some cases where this wouldn't work, potentially, but they are super implementation + // specific so we can just pretend we made them work + return 1; + } + + if (argType != LuaType.String) + { + return state.RaiseError("bad argument to decode"); + } + + _ = state.CheckBuffer(1, out var buff); + + try + { + var parsed = JsonNode.Parse(buff, documentOptions: new JsonDocumentOptions { MaxDepth = 1000 }); + Decode(this, parsed); + + return 1; + } + catch (Exception e) + { + if (e.Message.Contains("maximum configured depth of 1000")) + { + // Maximum depth exceeded, munge to a compatible Redis error + return state.RaiseError("Found too many nested data structures (1001)"); + } + + // Invalid token is implied (and matches Redis error replies) + // + // Additinal error details can be gleaned from messages + return state.RaiseError($"Expected value but found invalid token. Inner Message = {e.Message}"); + } + + // Convert the JsonNode into a Lua value on the stack + static void Decode(LuaRunner self, JsonNode node) + { + if (node is JsonValue v) + { + DecodeValue(self, v); + } + else if (node is JsonArray a) + { + DecodeArray(self, a); + } + else if (node is JsonObject o) + { + DecodeObject(self, o); + } + else + { + _ = self.state.RaiseError($"Unexpected json node type: {node.GetType().Name}"); + } + } + + // Convert the JsonValue int to a Lua string, nil, or number on the stack + static void DecodeValue(LuaRunner self, JsonValue value) + { + // Reserve space for the value + self.state.ForceMinimumStackCapacity(1); + + switch (value.GetValueKind()) + { + case JsonValueKind.Null: self.state.PushNil(); break; + case JsonValueKind.True: self.state.PushBoolean(true); break; + case JsonValueKind.False: self.state.PushBoolean(false); break; + case JsonValueKind.Number: self.state.PushNumber(value.GetValue()); break; + case JsonValueKind.String: + var str = value.GetValue(); + + self.scratchBufferManager.Reset(); + var buf = self.scratchBufferManager.UTF8EncodeString(str); + + self.state.PushBuffer(buf); + break; + case JsonValueKind.Undefined: + case JsonValueKind.Object: + case JsonValueKind.Array: + default: + _ = self.state.RaiseError($"Unexpected json value kind: {value.GetValueKind()}"); + break; + } + } + + // Convert the JsonArray into a Lua table on the stack + static void DecodeArray(LuaRunner self, JsonArray arr) + { + // Reserve space for the table + self.state.ForceMinimumStackCapacity(1); + + self.state.CreateTable(arr.Count, 0); + + var tableIndex = self.state.StackTop; + + var storeAtIx = 1; + foreach (var item in arr) + { + // Places item on the stack + Decode(self, item); + + // Save into the table + self.state.RawSetInteger(tableIndex, storeAtIx); + storeAtIx++; + } + } + + // Convert the JsonObject into a Lua table on the stack + static void DecodeObject(LuaRunner self, JsonObject obj) + { + // Reserve space for table and key + self.state.ForceMinimumStackCapacity(2); + + self.state.CreateTable(0, obj.Count); + + var tableIndex = self.state.StackTop; + + foreach (var (key, value) in obj) + { + // Decode key to string + self.scratchBufferManager.Reset(); + var buf = self.scratchBufferManager.UTF8EncodeString(key); + self.state.PushBuffer(buf); + + // Decode value + Decode(self, value); + + self.state.RawSet(tableIndex); + } + } + } + + /// + /// Entry point for cmsgpack.pack from a Lua script. + /// + public int CMsgPackPack(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var numLuaArgs = state.StackTop; + + if (numLuaArgs == 0) + { + return state.RaiseError("bad argument to pack"); + } + + // Redis concatenates all the message packs together if there are multiple + // + // Somewhat odd, but we match that behavior + + scratchBufferManager.Reset(); + + for (var argIx = 1; argIx <= numLuaArgs; argIx++) + { + // Because each encode removes the encoded value + // we always encode the argument at position 1 + Encode(this, 1, 0); + } + + // After all encoding, stack should be empty + state.ExpectLuaStackEmpty(); + + var ret = scratchBufferManager.ViewFullArgSlice().ReadOnlySpan; + state.PushBuffer(ret); + + return 1; + + // Encode a single item at the top of the stack, and remove it + static void Encode(LuaRunner self, int stackIndex, int depth) + { + var type = self.state.Type(stackIndex); + switch (type) + { + case LuaType.Boolean: EncodeBool(self, stackIndex); break; + case LuaType.Number: EncodeNumber(self, stackIndex); break; + case LuaType.String: EncodeBytes(self, stackIndex); break; + case LuaType.Table: + + if (depth == 16) + { + // Redis treats a too deeply nested table as a null + // + // This is weird, but we match it + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = 0xC0; + self.scratchBufferManager.MoveOffset(1); + + self.state.Remove(stackIndex); + + return; + } + + EncodeTable(self, stackIndex, depth); + break; + + // Everything else maps to null, NOT an error + case LuaType.Function: + case LuaType.LightUserData: + case LuaType.Nil: + case LuaType.None: + case LuaType.Thread: + case LuaType.UserData: + default: EncodeNull(self, stackIndex); break; + } + } + + // Encode a null-ish value at stackIndex, and remove it + static void EncodeNull(LuaRunner self, int stackIndex) + { + Debug.Assert(self.state.Type(stackIndex) is not (LuaType.Boolean or LuaType.Number or LuaType.String or LuaType.Table), "Expected null-ish type"); + + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = 0xC0; + self.scratchBufferManager.MoveOffset(1); + + self.state.Remove(stackIndex); + } + + // Encode a boolean at stackIndex, and remove it + static void EncodeBool(LuaRunner self, int stackIndex) + { + Debug.Assert(self.state.Type(stackIndex) == LuaType.Boolean, "Expected boolean"); + + var value = (byte)(self.state.ToBoolean(stackIndex) ? 0xC3 : 0xC2); + + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = value; + self.scratchBufferManager.MoveOffset(1); + + self.state.Remove(stackIndex); + } + + // Encode a number at stackIndex, and remove it + static void EncodeNumber(LuaRunner self, int stackIndex) + { + Debug.Assert(self.state.Type(stackIndex) == LuaType.Number, "Expected number"); + + var numRaw = self.state.CheckNumber(stackIndex); + var isInt = numRaw == (long)numRaw; + + if (isInt) + { + EncodeInteger(self, (long)numRaw); + } + else + { + EncodeFloatingPoint(self, numRaw); + } + + self.state.Remove(stackIndex); + } + + // Encode an integer + static void EncodeInteger(LuaRunner self, long value) + { + // positive 7-bit fixint + if ((byte)(value & 0b0111_1111) == value) + { + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)value; + self.scratchBufferManager.MoveOffset(1); + + return; + } + + // negative 5-bit fixint + if ((sbyte)(value | 0b1110_0000) == value) + { + self.scratchBufferManager.ViewRemainingArgSlice(1).Span[0] = (byte)value; + self.scratchBufferManager.MoveOffset(1); + return; + } + + // 8-bit int + if (value is >= sbyte.MinValue and <= sbyte.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(2).Span; + + into[0] = 0xD0; + into[1] = (byte)value; + self.scratchBufferManager.MoveOffset(2); + return; + } + + // 8-bit uint + if (value is >= byte.MinValue and <= byte.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(2).Span; + + into[0] = 0xCC; + into[1] = (byte)value; + self.scratchBufferManager.MoveOffset(2); + return; + } + + // 16-bit int + if (value is >= short.MinValue and <= short.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(3).Span; + + into[0] = 0xD1; + BinaryPrimitives.WriteInt16BigEndian(into[1..], (short)value); + self.scratchBufferManager.MoveOffset(3); + return; + } + + // 16-bit uint + if (value is >= ushort.MinValue and <= ushort.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(3).Span; + + into[0] = 0xCD; + BinaryPrimitives.WriteUInt16BigEndian(into[1..], (ushort)value); + self.scratchBufferManager.MoveOffset(3); + return; + } + + // 32-bit int + if (value is >= int.MinValue and <= int.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(5).Span; + + into[0] = 0xD2; + BinaryPrimitives.WriteInt32BigEndian(into[1..], (int)value); + self.scratchBufferManager.MoveOffset(5); + return; + } + + // 32-bit uint + if (value is >= uint.MinValue and <= uint.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(5).Span; + + into[0] = 0xCE; + BinaryPrimitives.WriteUInt32BigEndian(into[1..], (uint)value); + self.scratchBufferManager.MoveOffset(5); + return; + } + + // 64-bit uint + if (value > uint.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(9).Span; + + into[0] = 0xCF; + BinaryPrimitives.WriteUInt64BigEndian(into[1..], (ulong)value); + self.scratchBufferManager.MoveOffset(9); + return; + } + + // 64-bit int + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(9).Span; + + into[0] = 0xD3; + BinaryPrimitives.WriteInt64BigEndian(into[1..], value); + self.scratchBufferManager.MoveOffset(9); + } + } + + // Encode a floating point value + static void EncodeFloatingPoint(LuaRunner self, double value) + { + // While Redis has code that attempts to pack doubles into floats + // it doesn't appear to do anything, so we just always write a double + + var into = self.scratchBufferManager.ViewRemainingArgSlice(9).Span; + + into[0] = 0xCB; + BinaryPrimitives.WriteDoubleBigEndian(into[1..], value); + self.scratchBufferManager.MoveOffset(9); + } + + // Encodes a string as at stackIndex, and remove it + static void EncodeBytes(LuaRunner self, int stackIndex) + { + Debug.Assert(self.state.Type(stackIndex) == LuaType.String, "Expected string"); + + _ = self.state.CheckBuffer(stackIndex, out var data); + + if (data.Length < 32) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(1 + data.Length).Span; + + into[0] = (byte)(0xA0 | data.Length); + data.CopyTo(into[1..]); + self.scratchBufferManager.MoveOffset(1 + data.Length); + } + else if (data.Length <= byte.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(2 + data.Length).Span; + + into[0] = 0xD9; + into[1] = (byte)data.Length; + data.CopyTo(into[2..]); + self.scratchBufferManager.MoveOffset(2 + data.Length); + } + else if (data.Length <= ushort.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(3 + data.Length).Span; + + into[0] = 0xDA; + BinaryPrimitives.WriteUInt16BigEndian(into[1..], (ushort)data.Length); + data.CopyTo(into[3..]); + self.scratchBufferManager.MoveOffset(3 + data.Length); + } + else + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(5 + data.Length).Span; + + into[0] = 0xDB; + BinaryPrimitives.WriteUInt32BigEndian(into[1..], (uint)data.Length); + data.CopyTo(into[5..]); + self.scratchBufferManager.MoveOffset(5 + data.Length); + } + + self.state.Remove(stackIndex); + } + + // Encode a table at stackIndex, and remove it + static void EncodeTable(LuaRunner self, int stackIndex, int depth) + { + Debug.Assert(self.state.Type(stackIndex) == LuaType.Table, "Expected table"); + + // Space for key and value + self.state.ForceMinimumStackCapacity(2); + + var tableIndex = stackIndex; + + // A zero-length table is serialized as an array + var isArray = true; + var count = 0; + var max = 0; + + var keyIndex = self.state.StackTop + 1; + + // Measure the table and figure out if we're creating a map or an array + self.state.PushNil(); + while (self.state.Next(tableIndex) != 0) + { + count++; + + // Remove value + self.state.Pop(1); + + double keyAsNum; + if (self.state.Type(keyIndex) != LuaType.Number || (keyAsNum = self.state.CheckNumber(keyIndex)) <= 0 || keyAsNum != (int)keyAsNum) + { + isArray = false; + } + else + { + if (keyAsNum > max) + { + max = (int)keyAsNum; + } + } + } + + if (isArray && count == max) + { + EncodeArray(self, stackIndex, depth, count); + } + else + { + EncodeMap(self, stackIndex, depth, count); + } + } + + // Encode a table at stackIndex into an array, and remove it + static void EncodeArray(LuaRunner self, int stackIndex, int depth, int count) + { + Debug.Assert(self.state.Type(stackIndex) == LuaType.Table, "Expected table"); + Debug.Assert(count >= 0, "Array should have positive length"); + + // Reserve space for value + self.state.ForceMinimumStackCapacity(1); + + var tableIndex = stackIndex; + var valueIndex = tableIndex + 1; + + // Encode length + if (count <= 15) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(1).Span; + into[0] = (byte)(0b1001_0000 | count); + self.scratchBufferManager.MoveOffset(1); + } + else if (count <= ushort.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(3).Span; + + into[0] = 0xDC; + BinaryPrimitives.WriteUInt16BigEndian(into[1..], (ushort)count); + self.scratchBufferManager.MoveOffset(3); + } + else + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(5).Span; + + into[0] = 0xDD; + BinaryPrimitives.WriteUInt32BigEndian(into[1..], (uint)count); + self.scratchBufferManager.MoveOffset(5); + } + + // Write each element out + for (var ix = 1; ix <= count; ix++) + { + _ = self.state.RawGetInteger(null, tableIndex, ix); + Encode(self, valueIndex, depth + 1); + } + + self.state.Remove(tableIndex); + } + + // Encode a table at stackIndex into a map, and remove it + static void EncodeMap(LuaRunner self, int stackIndex, int depth, int count) + { + Debug.Assert(self.state.Type(stackIndex) == LuaType.Table, "Expected table"); + Debug.Assert(count >= 0, "Map should have positive length"); + + // Reserve space for key, value, and copy of key + self.state.ForceMinimumStackCapacity(2); + + var tableIndex = stackIndex; + var keyIndex = self.state.StackTop + 1; + var valueIndex = keyIndex + 1; + var keyCopyIndex = valueIndex + 1; + + // Encode length + if (count <= 15) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(1).Span; + + into[0] = (byte)(0b1000_0000 | count); + self.scratchBufferManager.MoveOffset(1); + } + else if (count <= ushort.MaxValue) + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(3).Span; + + into[0] = 0xDE; + BinaryPrimitives.WriteUInt16BigEndian(into[1..], (ushort)count); + self.scratchBufferManager.MoveOffset(3); + } + else + { + var into = self.scratchBufferManager.ViewRemainingArgSlice(5).Span; + + into[0] = 0xDF; + BinaryPrimitives.WriteUInt32BigEndian(into[1..], (uint)count); + self.scratchBufferManager.MoveOffset(5); + } + + self.state.PushNil(); + while (self.state.Next(tableIndex) != 0) + { + // Make a copy of the key + self.state.PushValue(keyIndex); + + // Write the key + Encode(self, keyCopyIndex, depth + 1); + + // Write the value + Encode(self, valueIndex, depth + 1); + } + + self.state.Remove(tableIndex); + } + } + + /// + /// Entry point for cmsgpack.unpack from a Lua script. + /// + public int CMsgPackUnpack(nint luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + var numLuaArgs = state.StackTop; + + if (numLuaArgs == 0) + { + return state.RaiseError("bad argument to unpack"); + } + + _ = state.CheckBuffer(1, out var data); + + var decodedCount = 0; + while (!data.IsEmpty) + { + // Reserve space for the result + state.ForceMinimumStackCapacity(1); + + try + { + Decode(ref data, ref state); + decodedCount++; + } + catch (Exception e) + { + // Best effort at matching Redis behavior + return state.RaiseError($"Missing bytes in input. {e.Message}"); + } + } + + return decodedCount; + + // Decode a msg pack + static void Decode(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + var sigil = data[0]; + data = data[1..]; + + switch (sigil) + { + case 0xC0: DecodeNull(ref data, ref state); return; + case 0xC2: DecodeBoolean(false, ref data, ref state); return; + case 0xC3: DecodeBoolean(true, ref data, ref state); return; + // 7-bit positive integers handled below + // 5-bit negative integers handled below + case 0xCC: DecodeUInt8(ref data, ref state); return; + case 0xCD: DecodeUInt16(ref data, ref state); return; + case 0xCE: DecodeUInt32(ref data, ref state); return; + case 0xCF: DecodeUInt64(ref data, ref state); return; + case 0xD0: DecodeInt8(ref data, ref state); return; + case 0xD1: DecodeInt16(ref data, ref state); return; + case 0xD2: DecodeInt32(ref data, ref state); return; + case 0xD3: DecodeInt64(ref data, ref state); return; + case 0xCA: DecodeSingle(ref data, ref state); return; + case 0xCB: DecodeDouble(ref data, ref state); return; + // <= 31 byte strings handled below + case 0xD9: DecodeSmallString(ref data, ref state); return; + case 0xDA: DecodeMidString(ref data, ref state); return; + case 0xDB: DecodeLargeString(ref data, ref state); return; + // We treat bins as strings + case 0xC4: goto case 0xD9; + case 0xC5: goto case 0xDA; + case 0xC6: goto case 0xDB; + // <= 15 element arrays are handled below + case 0xDC: DecodeMidArray(ref data, ref state); return; + case 0xDD: DecodeLargeArray(ref data, ref state); return; + // <= 15 pair maps are handled below + case 0xDE: DecodeMidMap(ref data, ref state); return; + case 0xDF: DecodeLargeMap(ref data, ref state); return; + + default: + if ((sigil & 0b1000_0000) == 0) + { + DecodeTinyUInt(sigil, ref state); + return; + } + else if ((sigil & 0b1110_0000) == 0b1110_0000) + { + DecodeTinyInt(sigil, ref state); + return; + } + else if ((sigil & 0b1110_0000) == 0b1010_0000) + { + DecodeTinyString(sigil, ref data, ref state); + return; + } + else if ((sigil & 0b1111_0000) == 0b1001_0000) + { + DecodeSmallArray(sigil, ref data, ref state); + return; + } + else if ((sigil & 0b1111_0000) == 0b1000_0000) + { + DecodeSmallMap(sigil, ref data, ref state); + return; + } + + _ = state.RaiseError($"Unexpected MsgPack sigil {sigil}/x{sigil:X2}/b{sigil:B8}"); + return; + } + } + + // Decode a null push it to the stack + static void DecodeNull(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNil(); + } + + // Decode a boolean and push it to the stack + static void DecodeBoolean(bool b, ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushBoolean(b); + } + + // Decode a byte, moving past it in data and pushing it to the stack + static void DecodeUInt8(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNumber(data[0]); + data = data[1..]; + } + + // Decode a positive 7-bit value, pushing it to the stack + static void DecodeTinyUInt(byte sigil, ref LuaStateWrapper state) + { + state.PushNumber(sigil); + } + + // Decode a ushort, moving past it in data and pushing it to the stack + static void DecodeUInt16(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNumber(BinaryPrimitives.ReadUInt16BigEndian(data)); + data = data[2..]; + } + + // Decode a uint, moving past it in data and pushing it to the stack + static void DecodeUInt32(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNumber(BinaryPrimitives.ReadUInt32BigEndian(data)); + data = data[4..]; + } + + // Decode a ulong, moving past it in data and pushing it to the stack + static void DecodeUInt64(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNumber(BinaryPrimitives.ReadUInt64BigEndian(data)); + data = data[8..]; + } + + // Decode a negative 5-bit value, pushing it to the stack + static void DecodeTinyInt(byte sigil, ref LuaStateWrapper state) + { + var signExtended = (int)(0xFFFF_FF00 | sigil); + state.PushNumber(signExtended); + } - var argCount = state.StackTop; - if (argCount != 1) + // Decode a sbyte, moving past it in data and pushing it to the stack + static void DecodeInt8(ref ReadOnlySpan data, ref LuaStateWrapper state) { - state.PushConstantString(constStrs.ErrWrongNumberOfArgs); - return state.RaiseErrorFromStack(); + state.PushNumber((sbyte)data[0]); + data = data[1..]; } - if (!state.CheckBuffer(1, out var bytes)) + // Decode a short, moving past it in data and pushing it to the stack + static void DecodeInt16(ref ReadOnlySpan data, ref LuaStateWrapper state) { - bytes = default; + state.PushNumber(BinaryPrimitives.ReadInt16BigEndian(data)); + data = data[2..]; } - Span hashBytes = stackalloc byte[SessionScriptCache.SHA1Len / 2]; - Span hexRes = stackalloc byte[SessionScriptCache.SHA1Len]; + // Decode a int, moving past it in data and pushing it to the stack + static void DecodeInt32(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNumber(BinaryPrimitives.ReadInt32BigEndian(data)); + data = data[4..]; + } - SessionScriptCache.GetScriptDigest(bytes, hashBytes, hexRes); + // Decode a long, moving past it in data and pushing it to the stack + static void DecodeInt64(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNumber(BinaryPrimitives.ReadInt64BigEndian(data)); + data = data[8..]; + } - state.PushBuffer(hexRes); - return 1; - } + // Decode a float, moving past it in data and pushing it to the stack + static void DecodeSingle(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNumber(BinaryPrimitives.ReadSingleBigEndian(data)); + data = data[4..]; + } - /// - /// Entry point for redis.log(...) from a Lua script. - /// - public int Log(nint luaStatePtr) - { - state.CallFromLuaEntered(luaStatePtr); + // Decode a double, moving past it in data and pushing it to the stack + static void DecodeDouble(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + state.PushNumber(BinaryPrimitives.ReadDoubleBigEndian(data)); + data = data[8..]; + } - var argCount = state.StackTop; - if (argCount < 2) + // Decode a string size <= 31, moving past it in data and pushing it to the stack + static void DecodeTinyString(byte sigil, ref ReadOnlySpan data, ref LuaStateWrapper state) { - return LuaStaticError(constStrs.ErrRedisLogRequired); + var len = sigil & 0b0001_1111; + var str = data[..len]; + + state.PushBuffer(str); + data = data[len..]; } - if (state.Type(1) != LuaType.Number) + // Decode a string size <= 255, moving past it in data and pushing it to the stack + static void DecodeSmallString(ref ReadOnlySpan data, ref LuaStateWrapper state) { - return LuaStaticError(constStrs.ErrFirstArgMustBeNumber); + var len = data[0]; + data = data[1..]; + + var str = data[..len]; + + state.PushBuffer(str); + + data = data[str.Length..]; } - var rawLevel = state.CheckNumber(1); - if (rawLevel is not (0 or 1 or 2 or 3)) + // Decode a string size <= 65,535, moving past it in data and pushing it to the stack + static void DecodeMidString(ref ReadOnlySpan data, ref LuaStateWrapper state) { - return LuaStaticError(constStrs.ErrInvalidDebugLevel); + var len = BinaryPrimitives.ReadUInt16BigEndian(data); + data = data[2..]; + + var str = data[..(int)len]; + + state.PushBuffer(str); + + data = data[str.Length..]; } - if (logMode == LuaLoggingMode.Disable) + // Decode a string size <= 4,294,967,295, moving past it in data and pushing it to the stack + static void DecodeLargeString(ref ReadOnlySpan data, ref LuaStateWrapper state) { - return LuaStaticError(constStrs.ErrLoggingDisabled); + var len = BinaryPrimitives.ReadUInt32BigEndian(data); + data = data[4..]; + + if ((int)len < 0) + { + _ = state.RaiseError($"String length is too long: {len}"); + return; + } + + var str = data[..(int)len]; + + state.PushBuffer(str); + + data = data[str.Length..]; } - // When shipped as a service, allowing arbitrary writes to logs is dangerous - // so we support disabling it (while not breaking existing scripts) - if (logMode == LuaLoggingMode.Silent) + // Decode an array with <= 15 items, moving past it in data and pushing it to the stack + static void DecodeSmallArray(byte sigil, ref ReadOnlySpan data, ref LuaStateWrapper state) { - return 0; + // Reserve extra space for the temporary item + state.ForceMinimumStackCapacity(1); + + var len = sigil & 0b0000_1111; + + state.CreateTable(len, 0); + var arrayIndex = state.StackTop; + + for (var i = 1; i <= len; i++) + { + // Push the element onto the stack + Decode(ref data, ref state); + state.RawSetInteger(arrayIndex, i); + } } - // Even if enabled, if no logger was provided we can just bail - if (logger == null) + // Decode an array with <= 65,535 items, moving past it in data and pushing it to the stack + static void DecodeMidArray(ref ReadOnlySpan data, ref LuaStateWrapper state) { - return 0; + // Reserve extra space for the temporary item + state.ForceMinimumStackCapacity(1); + + var len = BinaryPrimitives.ReadUInt16BigEndian(data); + data = data[2..]; + + state.CreateTable(len, 0); + var arrayIndex = state.StackTop; + + for (var i = 1; i <= len; i++) + { + // Push the element onto the stack + Decode(ref data, ref state); + state.RawSetInteger(arrayIndex, i); + } } - // Construct and log the equivalent message - string logMessage; - if (argCount == 2) + // Decode an array with <= 4,294,967,295 items, moving past it in data and pushing it to the stack + static void DecodeLargeArray(ref ReadOnlySpan data, ref LuaStateWrapper state) { - if (state.CheckBuffer(2, out var buff)) + // Reserve extra space for the temporary item + state.ForceMinimumStackCapacity(1); + + var len = BinaryPrimitives.ReadUInt32BigEndian(data); + data = data[4..]; + + if ((int)len < 0) { - logMessage = Encoding.UTF8.GetString(buff); + _ = state.RaiseError($"Array length is too long: {len}"); + return; } - else + + state.CreateTable((int)len, 0); + var arrayIndex = state.StackTop; + + for (var i = 1; i <= len; i++) { - logMessage = ""; + // Push the element onto the stack + Decode(ref data, ref state); + state.RawSetInteger(arrayIndex, i); } } - else + + // Decode an map with <= 15 key-value pairs, moving past it in data and pushing it to the stack + static void DecodeSmallMap(byte sigil, ref ReadOnlySpan data, ref LuaStateWrapper state) { - var sb = new StringBuilder(); + // Reserve extra space for the temporary key & value + state.ForceMinimumStackCapacity(2); - for (var argIx = 2; argIx <= argCount; argIx++) + var len = sigil & 0b0000_1111; + + state.CreateTable(0, len); + var mapIndex = state.StackTop; + + for (var i = 1; i <= len; i++) { - if (state.CheckBuffer(argIx, out var buff)) - { - if (sb.Length != 0) - { - _ = sb.Append(' '); - } + // Push the key onto the stack + Decode(ref data, ref state); - _ = sb.Append(Encoding.UTF8.GetString(buff)); - } + // Push the value onto the stack + Decode(ref data, ref state); + + state.RawSet(mapIndex); } + } - logMessage = sb.ToString(); + // Decode a map with <= 65,535 key-value pairs, moving past it in data and pushing it to the stack + static void DecodeMidMap(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + // Reserve extra space for the temporary key & value + state.ForceMinimumStackCapacity(2); + + var len = BinaryPrimitives.ReadUInt16BigEndian(data); + data = data[2..]; + + state.CreateTable(0, len); + var mapIndex = state.StackTop; + + for (var i = 1; i <= len; i++) + { + // Push the key onto the stack + Decode(ref data, ref state); + + // Push the value onto the stack + Decode(ref data, ref state); + + state.RawSet(mapIndex); + } } - var logLevel = - rawLevel switch + // Decode a map with <= 4,294,967,295 key-value pairs, moving past it in data and pushing it to the stack + static void DecodeLargeMap(ref ReadOnlySpan data, ref LuaStateWrapper state) + { + // Reserve extra space for the temporary key & value + state.ForceMinimumStackCapacity(2); + + var len = BinaryPrimitives.ReadUInt32BigEndian(data); + data = data[4..]; + + if ((int)len < 0) { - 0 => LogLevel.Debug, - 1 => LogLevel.Information, - 2 => LogLevel.Warning, - // We validated this above, so really it's just 3 but the switch needs to be exhaustive - _ => LogLevel.Error, - }; + _ = state.RaiseError($"Map length is too long: {len}"); + return; + } - logger.Log(logLevel, "redis.log: {message}", logMessage.ToString()); + state.CreateTable(0, (int)len); + var mapIndex = state.StackTop; - return 0; + for (var i = 1; i <= len; i++) + { + // Push the key onto the stack + Decode(ref data, ref state); + + // Push the value onto the stack + Decode(ref data, ref state); + + state.RawSet(mapIndex); + } + } } /// @@ -1328,7 +3308,7 @@ private unsafe int ProcessRespResponse(byte respProtocolVersion, byte* respPtr, { var respEnd = respPtr + respLen; - var ret = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd); + var ret = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd); if (respPtr != respEnd) { @@ -1338,7 +3318,7 @@ private unsafe int ProcessRespResponse(byte respProtocolVersion, byte* respPtr, return ret; } - private unsafe int ProcessSingleResp3Term(byte respProtocolVersion, ref byte* respPtr, byte* respEnd) + private unsafe int ProcessSingleRespTerm(byte respProtocolVersion, ref byte* respPtr, byte* respEnd) { var indicator = (char)*respPtr; @@ -1440,7 +3420,7 @@ private unsafe int ProcessSingleResp3Term(byte respProtocolVersion, ref byte* re for (var itemIx = 0; itemIx < arrayItemCount; itemIx++) { // Pushes the item to the top of the stack - _ = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd); + _ = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd); // Store the item into the table state.RawSetInteger(curTop + 1, itemIx + 1); @@ -1470,10 +3450,10 @@ private unsafe int ProcessSingleResp3Term(byte respProtocolVersion, ref byte* re for (var pair = 0; pair < mapPairCount; pair++) { // Read key - _ = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd); + _ = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd); // Read value - _ = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd); + _ = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd); // Set t[k] = v state.RawSet(curTop + 3); @@ -1526,7 +3506,7 @@ private unsafe int ProcessSingleResp3Term(byte respProtocolVersion, ref byte* re for (var pair = 0; pair < setItemCount; pair++) { // Read value, which we use as a key - _ = ProcessSingleResp3Term(respProtocolVersion, ref respPtr, respEnd); + _ = ProcessSingleRespTerm(respProtocolVersion, ref respPtr, respEnd); // Unconditionally the value under the key is true state.PushBoolean(true); @@ -2471,7 +4451,34 @@ static unsafe void WriteString(LuaRunner runner, ref TResponse resp) { runner.state.KnownStringToBuffer(runner.state.StackTop, out var buf); - while (!RespWriteUtils.TryWriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd)) + // Strings can be veeeerrrrry large, so we can't use the short helpers + // Thus we write the full string directly + while (!RespWriteUtils.TryWriteBulkStringLength(buf, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + // Repeat while we have bytes left to write + while (!buf.IsEmpty) + { + // Copy bytes over + var destSpace = resp.BufferEnd - resp.BufferCur; + var copyLen = (int)(destSpace < buf.Length ? destSpace : buf.Length); + buf.Slice(0, copyLen).CopyTo(new Span(resp.BufferCur, copyLen)); + + // Advance + resp.BufferCur += copyLen; + + // Flush if we filled the buffer + if (destSpace == copyLen) + { + resp.SendAndReset(); + } + + // Move past the data we wrote out + buf = buf.Slice(copyLen); + } + + // End the string + while (!RespWriteUtils.TryWriteNewLine(ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); runner.state.Pop(1); @@ -2811,6 +4818,77 @@ private static (int Start, ulong[] Bitmap) InitializeNoScriptDetails() return (start, bitmap); } + + /// + /// Modifies to account for , and converts to bytes. + /// + /// Provided as an optimization, as often this can be memoized. + /// + private static ReadOnlyMemory PrepareLoaderBlockBytes(HashSet allowedFunctions) + { + // If nothing is explicitly allowed, fallback to our defaults + if (allowedFunctions.Count == 0) + { + allowedFunctions = DefaultAllowedFunctions; + } + + // Most of the time this list never changes, so reuse the work + var cache = CachedLoaderBlock; + if (cache != null && ReferenceEquals(cache.AllowedFunctions, allowedFunctions)) + { + return cache.LoaderBlockBytes; + } + + // Build the subset of a Lua table where we export all these functions + var wholeIncludes = new HashSet(StringComparer.OrdinalIgnoreCase); + var replacement = new StringBuilder(); + foreach (var wholeRef in allowedFunctions.Where(static x => !x.Contains('.'))) + { + if (!DefaultAllowedFunctions.Contains(wholeRef)) + { + // Skip functions not intentionally exported + continue; + } + + _ = replacement.AppendLine($" {wholeRef}={wholeRef};"); + _ = wholeIncludes.Add(wholeRef); + } + + // Partial includes (ie. os.clock) need special handling + var partialIncludes = allowedFunctions.Where(static x => x.Contains('.')).Select(static x => (Leading: x[..x.IndexOf('.')], Trailing: x[(x.IndexOf('.') + 1)..])); + foreach (var grouped in partialIncludes.GroupBy(static t => t.Leading, StringComparer.OrdinalIgnoreCase)) + { + if (wholeIncludes.Contains(grouped.Key)) + { + // Including a subset of something included in whole doesn't affect things + continue; + } + + if (!DefaultAllowedFunctions.Contains(grouped.Key)) + { + // Skip functions not intentionally exported + continue; + } + + _ = replacement.AppendLine($" {grouped.Key}={{"); + foreach (var part in grouped.Select(static t => t.Trailing).Distinct().OrderBy(static t => t)) + { + _ = replacement.AppendLine($" {part}={grouped.Key}.{part};"); + } + _ = replacement.AppendLine(" };"); + } + + var decl = replacement.ToString(); + var finalLoaderBlock = LoaderBlock.Replace("!!SANDBOX_ENV REPLACEMENT TARGET!!", decl); + + // Save off for next caller + // + // Inherently race-y, but that's fine - worst case we do a little extra work + var newCache = new LoaderBlockCache(allowedFunctions, Encoding.UTF8.GetBytes(finalLoaderBlock)); + CachedLoaderBlock = newCache; + + return newCache.LoaderBlockBytes; + } } /// @@ -2950,5 +5028,132 @@ internal static int AclCheckCommand(nint luaState) [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] internal static int SetResp(nint luaState) => CallbackContext.SetResp(luaState); + + /// + /// Entry point for calls to math.atan2. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Atan2(nint luaState) + => CallbackContext.Atan2(luaState); + + /// + /// Entry point for calls to math.cosh. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Cosh(nint luaState) + => CallbackContext.Cosh(luaState); + + /// + /// Entry point for calls to math.frexp. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Frexp(nint luaState) + => CallbackContext.Frexp(luaState); + + /// + /// Entry point for calls to math.ldexp. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Ldexp(nint luaState) + => CallbackContext.Ldexp(luaState); + + /// + /// Entry point for calls to math.log10. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Log10(nint luaState) + => CallbackContext.Log10(luaState); + + /// + /// Entry point for calls to math.pow. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Pow(nint luaState) + => CallbackContext.Pow(luaState); + + /// + /// Entry point for calls to math.sinh. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Sinh(nint luaState) + => CallbackContext.Sinh(luaState); + + /// + /// Entry point for calls to math.tanh. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Tanh(nint luaState) + => CallbackContext.Tanh(luaState); + + /// + /// Entry point for calls to table.maxn. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Maxn(nint luaState) + => CallbackContext.Maxn(luaState); + + /// + /// Entry point for calls to loadstring. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int LoadString(nint luaState) + => CallbackContext.LoadString(luaState); + + /// + /// Entry point for calls to cjson.encode. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int CJsonEncode(nint luaState) + => CallbackContext.CJsonEncode(luaState); + + /// + /// Entry point for calls to cjson.decode. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int CJsonDecode(nint luaState) + => CallbackContext.CJsonDecode(luaState); + + /// + /// Entry point for calls to bit.tobit. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int BitToBit(nint luaState) + => CallbackContext.BitToBit(luaState); + + /// + /// Entry point for calls to bit.tohex. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int BitToHex(nint luaState) + => CallbackContext.BitToHex(luaState); + + /// + /// Entry point for calls to garnet_bitop, which backs + /// bit.bnot, bit.bor, etc. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int Bitop(nint luaState) + => CallbackContext.Bitop(luaState); + + /// + /// Entry point for calls to bit.bswap. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int BitBswap(nint luaState) + => CallbackContext.BitBswap(luaState); + + /// + /// Entry point for calls to cmsgpack.pack. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int CMsgPackPack(nint luaState) + => CallbackContext.CMsgPackPack(luaState); + + /// + /// Entry point for calls to cmsgpack.unpack. + /// + [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] + internal static int CMsgPackUnpack(nint luaState) + => CallbackContext.CMsgPackUnpack(luaState); } } \ No newline at end of file diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs index 7404e1dce73..a3cf6e32f8d 100644 --- a/libs/server/Lua/LuaStateWrapper.cs +++ b/libs/server/Lua/LuaStateWrapper.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; @@ -443,18 +444,54 @@ internal void SetGlobal(ReadOnlySpan nullTerminatedGlobalName) /// /// This should be used for all LoadBuffers into Lua. /// - /// Note that this is different from pushing a buffer, as the loaded buffer is compiled. + /// Note that this is different from pushing a buffer, as the loaded buffer is compiled and executed. /// /// Maintains and to minimize p/invoke calls. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] internal LuaStatus LoadBuffer(ReadOnlySpan buffer) { - AssertLuaStackNotFull(); + AssertLuaStackNotFull(2); var ret = NativeMethods.LoadBuffer(state, buffer); - UpdateStackTop(1); + if (ret != LuaStatus.OK) + { + StackTop = NativeMethods.GetTop(state); + } + else + { + UpdateStackTop(1); + } + + AssertLuaStackExpected(); + + return ret; + } + + /// + /// This should be used for all LoadStrings into Lua. + /// + /// Note that this is different from pushing or loading buffer, as the loaded buffer is compiled but NOT executed. + /// + /// Maintains and to minimize p/invoke calls. + /// + internal LuaStatus LoadString(ReadOnlySpan buffer) + { + AssertLuaStackNotFull(2); + + var ret = NativeMethods.LoadString(state, buffer); + + if (ret != LuaStatus.OK) + { + StackTop = NativeMethods.GetTop(state); + } + else + { + UpdateStackTop(1); + } + + AssertLuaStackExpected(); return ret; } @@ -581,6 +618,20 @@ internal void PushValue(int stackIndex) UpdateStackTop(1); } + /// + /// This should be used for all Removes into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Remove(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + NativeMethods.Rotate(state, stackIndex, -1); + NativeMethods.Pop(state, 1); + + UpdateStackTop(-1); + } + // Rarely used /// @@ -624,6 +675,7 @@ internal void ClearStack() /// /// Clear the stack and raise an error with the given message. /// + [DoesNotReturn] internal int RaiseError(string msg) { ClearStack(); @@ -636,6 +688,7 @@ internal int RaiseError(string msg) /// /// Raise an error, where the top of the stack is the error message. /// + [DoesNotReturn] internal readonly int RaiseErrorFromStack() { Debug.Assert(StackTop != 0, "Expected error message on the stack"); diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs index 5035ff78fa3..6a784d9c060 100644 --- a/libs/server/Lua/NativeMethods.cs +++ b/libs/server/Lua/NativeMethods.cs @@ -46,6 +46,13 @@ internal static partial class NativeMethods [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] private static partial LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode); + /// + /// see: https://www.lua.org/manual/5.4/manual.html#luaL_loadstring + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial LuaStatus luaL_loadstring(lua_State lua_State, charptr_t buff); + /// /// see: https://www.lua.org/manual/5.4/manual.html#luaL_newstate /// @@ -186,6 +193,13 @@ internal static partial class NativeMethods [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] private static partial int lua_next(lua_State luaState, int tableIndex); + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_rotate + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial void lua_rotate(lua_State luaState, int stackIndex, int n); + // GC Transition suppressed - only do this after auditing the Lua method and confirming constant-ish, fast, runtime w/o allocations /// @@ -383,7 +397,7 @@ internal static unsafe ref byte PushBuffer(lua_State luaState, ReadOnlySpan - /// Push given span to stack, and compiles it. + /// Push given span to stack, compiles it, and executes it. /// /// Provided data is copied, and can be reused once this call returns. /// @@ -395,6 +409,19 @@ internal static unsafe LuaStatus LoadBuffer(lua_State luaState, ReadOnlySpan + /// Push given span to stack, and compiles it. + /// + /// Provided data is copied, and can be reused once this call returns. + /// + internal static unsafe LuaStatus LoadString(lua_State luaState, ReadOnlySpan str) + { + fixed (byte* ptr = str) + { + return luaL_loadstring(luaState, (charptr_t)ptr); + } + } + /// /// Get the top index on the stack. /// @@ -644,6 +671,15 @@ internal static void PushValue(lua_State luaState, int stackIndex) internal static void SetTop(lua_State lua_State, int top) => lua_settop(lua_State, top); + /// + /// Rotates elements above (and including) in steps in + /// the direction of the top of the stack. + /// + /// can be negative. + /// + internal static void Rotate(lua_State luaState, int stackIndex, int n) + => lua_rotate(luaState, stackIndex, n); + /// /// Raise an error, using the top of the stack as an error item. /// diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs index 226329e6c7d..cae4c8ef9ca 100644 --- a/libs/server/Lua/SessionScriptCache.cs +++ b/libs/server/Lua/SessionScriptCache.cs @@ -31,6 +31,7 @@ internal sealed class SessionScriptCache : IDisposable readonly int? memoryLimitBytes; readonly LuaTimeoutManager timeoutManager; readonly LuaLoggingMode logMode; + readonly HashSet allowedFunctions; LuaRunner timeoutRunningScript; LuaTimeoutManager.Registration timeoutRegistration; @@ -57,6 +58,7 @@ public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authen memoryManagementMode = storeWrapper.serverOptions.LuaOptions.MemoryManagementMode; memoryLimitBytes = storeWrapper.serverOptions.LuaOptions.GetMemoryLimitBytes(); logMode = storeWrapper.serverOptions.LuaOptions.LogMode; + allowedFunctions = storeWrapper.serverOptions.LuaOptions.AllowedFunctions; } public void Dispose() @@ -138,7 +140,7 @@ internal bool TryLoad(RespServerSession session, ReadOnlySpan source, Scri { var sourceOnHeap = source.ToArray(); - runner = new LuaRunner(memoryManagementMode, memoryLimitBytes, logMode, sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, storeWrapper.redisProtocolVersion, logger); + runner = new LuaRunner(memoryManagementMode, memoryLimitBytes, logMode, allowedFunctions, sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, storeWrapper.redisProtocolVersion, logger); // If compilation fails, an error is written out if (runner.CompileForSession(session)) diff --git a/test/Garnet.test/GarnetServerConfigTests.cs b/test/Garnet.test/GarnetServerConfigTests.cs index 500bf3ddba7..6a49537fbef 100644 --- a/test/Garnet.test/GarnetServerConfigTests.cs +++ b/test/Garnet.test/GarnetServerConfigTests.cs @@ -520,7 +520,7 @@ public void LuaLoggingOptions() } } - // Command line args + // JSON args { // No value is accepted { @@ -567,6 +567,92 @@ public void LuaLoggingOptions() } } + [Test] + public void LuaAllowedFunctions() + { + // Command line args + { + // No value is accepted + { + var args = new[] { "--lua" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(0, options.LuaAllowedFunctions.Count()); + } + + // One option works + { + var args = new[] { "--lua", "--lua-allowed-functions", "os" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(1, options.LuaAllowedFunctions.Count()); + ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("os")); + } + + // Multiple option works + { + var args = new[] { "--lua", "--lua-allowed-functions", "os,assert,rawget" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(3, options.LuaAllowedFunctions.Count()); + ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("os")); + ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("assert")); + ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("rawget")); + } + + // Invalid rejected + { + var args = new[] { "--lua", "--lua-allowed-functions" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsFalse(parseSuccessful); + } + } + + // JSON args + { + // No value is accepted + { + const string JSON = @"{ ""EnableLua"": true }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(0, options.LuaAllowedFunctions.Count()); + } + + // One option works + { + const string JSON = @"{ ""EnableLua"": true, ""LuaAllowedFunctions"": [""os""] }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(1, options.LuaAllowedFunctions.Count()); + ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("os")); + } + + // Multiple option works + { + const string JSON = @"{ ""EnableLua"": true, ""LuaAllowedFunctions"": [""os"", ""assert"", ""rawget""] }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableLua); + ClassicAssert.AreEqual(3, options.LuaAllowedFunctions.Count()); + ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("os")); + ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("assert")); + ClassicAssert.IsTrue(options.LuaAllowedFunctions.Contains("rawget")); + } + + // Invalid rejected + { + const string JSON = @"{ ""EnableLua"": true, ""LuaAllowedFunctions"": { } }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsFalse(parseSuccessful); + } + } + } + /// /// Import a garnet.conf file with the given contents /// diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs index 8f5cb39354c..fdf2ab9a6a9 100644 --- a/test/Garnet.test/LuaScriptRunnerTests.cs +++ b/test/Garnet.test/LuaScriptRunnerTests.cs @@ -59,7 +59,7 @@ public void CannotRunUnsafeScript() { runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); - ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message); + ClassicAssert.AreEqual("ERR Lua encountered an error: [string \"os.exit();\"]:1: attempt to call a nil value (field 'exit')", ex.Message); } // Try to include a new .net library @@ -518,7 +518,7 @@ public void RedisLogDisabled() #endif // Just because it's hard to test in LuaScriptTests, doing this here - using var runner = new LuaRunner(new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Disable), "redis.log(redis.LOG_WARNING, 'foo')"); + using var runner = new LuaRunner(new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Disable, []), "redis.log(redis.LOG_WARNING, 'foo')"); runner.CompileForRunner(); @@ -531,7 +531,7 @@ public void RedisLogSilent() { // Just because it's hard to test in LuaScriptTests, doing this here using var logger = new FakeLogger(); - using var runner = new LuaRunner(new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent), "redis.log(redis.LOG_WARNING, 'foo')", logger: logger); + using var runner = new LuaRunner(new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, []), "redis.log(redis.LOG_WARNING, 'foo')", logger: logger); runner.CompileForRunner(); _ = runner.RunForRunner(); @@ -539,6 +539,210 @@ public void RedisLogSilent() ClassicAssert.IsFalse(logger.Logs.Any(static x => x.Contains("redis.log"))); } + [Test] + public void AllowedFunctions() + { + List globalFuncs = ["xpcall", "tostring", "setmetatable", "next", "assert", "tonumber", "rawequal", "collectgarbage", "getmetatable", "rawset", "pcall", "coroutine", "type", "_G", "select", "unpack", "gcinfo", "pairs", "rawget", "loadstring", "ipairs", "_VERSION", "load", "error"]; + var exportedFuncs = + new Dictionary> + { + ["bit"] = ["tobit", "tohex", "bnot", "bor", "band", "bxor", "lshift", "rshift", "arshift", "rol", "ror", "bswap"], + ["cjson"] = ["encode", "decode"], + ["cmsgpack"] = ["pack", "unpack"], + ["math"] = ["abs", "acos", "asin", "atan", "atan2", "ceil", "cos", "cosh", "deg", "exp", "floor", "fmod", "frexp", "huge", "ldexp", "log", "log10", "max", "min", "modf", "pi", "pow", "rad", "random", "randomseed", "sin", "sinh", "sqrt", "tan", "tanh"], + ["os"] = ["clock"], + ["redis"] = ["call", "pcall", "error_reply", "status_reply", "sha1hex", "log", "LOG_DEBUG", "LOG_VERBOSE", "LOG_NOTICE", "LOG_WARNING", "setresp", "set_repl", "REPL_ALL", "REPL_AOF", "REPL_REPLICA", "REPL_SLAVE", "REPL_NONE", "replicate_commands", "breakpoint", "debug", "acl_check_cmd", "REDIS_VERSION", "REDIS_VERSION_NUM"], + ["string"] = ["byte", "char", "dump", "find", "format", "gmatch", "gsub", "len", "lower", "match", "rep", "reverse", "sub", "upper"], + ["struct"] = ["pack", "unpack", "size"], + ["table"] = ["concat", "insert", "maxn", "remove", "sort"], + }; + + // Check the supported globals + { + using var allRunner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, []), + @$"local ret = {{ }} + for k, v in pairs(_G) do + table.insert(ret, k) + end + return ret" + ); + + allRunner.CompileForRunner(); + // __readonly is special and fine to leak, so ignore it + var allDefined = ((object[])allRunner.RunForRunner()).Select(static x => (string)x).Except(["__readonly"]).ToList(); + + var expected = globalFuncs.Concat(exportedFuncs.Keys); + + var missing = expected.Except(allDefined).ToList(); + + // ARGV, KEYS, and redis are always available + var extra = allDefined.Except(expected).Except(["ARGV", "KEYS", "redis"]).ToList(); + + ClassicAssert.AreEqual(0, missing.Count, $"Missing globals: {string.Join(", ", missing)}"); + ClassicAssert.AreEqual(0, extra.Count, $"Extra globals: {string.Join(", ", extra)}"); + + foreach (var globalFunc in globalFuncs) + { + // These are special, just ignore them + if (globalFunc is "type" or "_G" or "_VERSION") + { + continue; + } + + var everythingExceptGlobalFunc = expected.Except([globalFunc]); + + using var withoutRunner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, everythingExceptGlobalFunc), + $"return type({globalFunc})" + ); + + withoutRunner.CompileForRunner(); + var withoutDefined = (string)withoutRunner.RunForRunner(); + + ClassicAssert.AreEqual("nil", withoutDefined, $"Global {globalFunc} available when it shouldn't have been"); + + using var withRunner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, ["type", globalFunc]), + $"return type({globalFunc})" + ); + + withRunner.CompileForRunner(); + var withDefined = (string)withRunner.RunForRunner(); + + ClassicAssert.AreNotEqual("nil", withDefined, $"Global {globalFunc} not available when it should have been"); + } + } + + // Check for the supported Lua functions which are under names in globals + + foreach (var (funcGroup, funcs) in exportedFuncs) + { + // Get all the keys under the funcGroup + { + using var runner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [funcGroup, "pairs", "table.insert"]), + @$"local ret = {{ }} + for k, v in pairs({funcGroup}) do + table.insert(ret, k) + end + return ret" + ); + + runner.CompileForRunner(); + // __readonly is special and fine to leak, so ignore it + var defined = ((object[])runner.RunForRunner()).Select(static x => (string)x).Except(["__readonly"]).ToList(); + + var missing = funcs.Except(defined).ToList(); + var extra = defined.Except(funcs).ToList(); + + ClassicAssert.AreEqual(0, missing.Count, $"Missing funcs in {funcGroup}: {string.Join(", ", missing)}"); + ClassicAssert.AreEqual(0, extra.Count, $"Extra funcs in {funcGroup}: {string.Join(", ", extra)}"); + } + + // Check all expected funcs are defined + { + var allTypes = string.Join(", ", funcs.Select(x => $"type({funcGroup}.{x})")); + + using var runner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [funcGroup, "type"]), + $"return {{ {allTypes} }}" + ); + + runner.CompileForRunner(); + var defined = (object[])runner.RunForRunner(); + + for (var i = 0; i < funcs.Count; i++) + { + var forFunc = funcs[i]; + var funcType = (string)defined[i]; + + ClassicAssert.AreNotEqual("nil", funcType, $"{funcGroup}.{forFunc} is not defined when it should be"); + } + } + + // Check NOT including top level group causes all functions to be unavailable + { + var otherGroups = exportedFuncs.Keys.Except(["_G", funcGroup]); + + using var runner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [.. otherGroups, "type"]), + $"return type({funcGroup})" + ); + + runner.CompileForRunner(); + var defined = (string)runner.RunForRunner(); + + ClassicAssert.AreEqual("nil", defined, $"{funcGroup} is defined when it should not be"); + } + + // Check allowing just the one func + foreach (var func in funcs) + { + using var runner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [$"{funcGroup}.{func}", "type"]), + $"return type({funcGroup}.{func})" + ); + + runner.CompileForRunner(); + var defined = (string)runner.RunForRunner(); + + ClassicAssert.AreNotEqual("nil", defined, $"{funcGroup}.{func} is not defined when it should be"); + } + + // Check that disallowing just the one func in the group works + foreach (var func in funcs) + { + var others = funcs.Except([func]).Select(x => $"{funcGroup}.{x}"); + + string typeStatement; + if (others.Any()) + { + typeStatement = $"type({funcGroup}.{func})"; + } + else + { + // If a group of function is completely removed, the table is nulled out + typeStatement = $"type({funcGroup})"; + } + + using var runner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, [.. others, "type"]), + $"return {typeStatement}" + ); + + runner.CompileForRunner(); + var defined = (string)runner.RunForRunner(); + + ClassicAssert.AreEqual("nil", defined, $"{funcGroup}.{func} is defined when it should not be"); + } + } + } + + [Test] + public void InternalFunctionsIgnoredInAllowedFunctions() + { + // Check if an internal implementation detail (garnet_call in this case) can be allowed + using var allRunner = + new LuaRunner( + new(LuaMemoryManagementMode.Native, "", Timeout.InfiniteTimeSpan, LuaLoggingMode.Silent, ["tostring", "garnet_call"]), + @$"return tostring(garnet_call)" + ); + + allRunner.CompileForRunner(); + + var res = (string)allRunner.RunForRunner(); + ClassicAssert.AreEqual("nil", res); + } + private sealed class FakeLogger : ILogger, IDisposable { private readonly List logs = new(); diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index d78bf02373d..02cf2719703 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -2,10 +2,13 @@ // Licensed under the MIT license. using System; +using System.Buffers.Binary; +using System.Collections.Generic; using System.Diagnostics; using System.Globalization; using System.IO; using System.Linq; +using System.Numerics; using System.Text; using System.Text.RegularExpressions; using System.Threading; @@ -24,7 +27,7 @@ namespace Garnet.test [TestFixture(LuaMemoryManagementMode.Tracked, "", "")] [TestFixture(LuaMemoryManagementMode.Tracked, "13m", "")] [TestFixture(LuaMemoryManagementMode.Managed, "", "")] - [TestFixture(LuaMemoryManagementMode.Managed, "16m", "")] + [TestFixture(LuaMemoryManagementMode.Managed, "17m", "")] public class LuaScriptTests { /// @@ -95,6 +98,13 @@ public override string ToString() private string aclFile; private GarnetServer server; + /// + /// Temporarily disable errors raised from Lua until longjmp work is completed. + /// + /// TODO: Delete all of this + /// + private static bool CanTestLuaErrors { get; } = OperatingSystem.IsWindows() && Environment.Version.Major <= 8; + public LuaScriptTests(LuaMemoryManagementMode allocMode, string limitBytes, string limitTimeout) { this.allocMode = allocMode; @@ -512,9 +522,10 @@ public void RedisPCall() { // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif + if (!CanTestLuaErrors) + { + Assert.Ignore($"Ignoring test when running in .NET9."); + } using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); @@ -533,9 +544,10 @@ public void RedisSha1Hex() { // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif + if (!CanTestLuaErrors) + { + Assert.Ignore($"Ignoring test when running in .NET9."); + } using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); @@ -566,9 +578,10 @@ public void RedisLog() { // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif + if (!CanTestLuaErrors) + { + Assert.Ignore($"Ignoring test when running in .NET9."); + } using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); @@ -649,12 +662,6 @@ public void RedisDebugAndBreakpoint() [Test] public void RedisAclCheckCmd() { - // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 - // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif - // Note this path is more heavily exercised in ACL tests using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); @@ -663,17 +670,20 @@ public void RedisAclCheckCmd() using var denyRedis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(authUsername: "deny")); var denyDB = denyRedis.GetDatabase(0); - var noArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd()")); - ClassicAssert.IsTrue(noArgs.Message.StartsWith("ERR Please specify at least one argument for this redis lib call")); + if (CanTestLuaErrors) + { + var noArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd()")); + ClassicAssert.IsTrue(noArgs.Message.StartsWith("ERR Please specify at least one argument for this redis lib call")); - var invalidCmdArgType = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd({123})")); - ClassicAssert.IsTrue(invalidCmdArgType.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + var invalidCmdArgType = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd({123})")); + ClassicAssert.IsTrue(invalidCmdArgType.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); - var invalidCmd = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd('nope')")); - ClassicAssert.IsTrue(invalidCmd.Message.StartsWith("ERR Invalid command passed to redis.acl_check_cmd()")); + var invalidCmd = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd('nope')")); + ClassicAssert.IsTrue(invalidCmd.Message.StartsWith("ERR Invalid command passed to redis.acl_check_cmd()")); - var invalidArgType = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd('GET', {123})")); - ClassicAssert.IsTrue(invalidArgType.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + var invalidArgType = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.acl_check_cmd('GET', {123})")); + ClassicAssert.IsTrue(invalidArgType.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } var canRun = (bool)db.ScriptEvaluate("return redis.acl_check_cmd('GET')"); ClassicAssert.IsTrue(canRun); @@ -709,23 +719,20 @@ public void RedisAclCheckCmd() [Test] public void RedisSetResp() { - // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 - // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif - using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); - var noArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp()")); - ClassicAssert.IsTrue(noArgs.Message.StartsWith("ERR redis.setresp() requires one argument.")); + if (CanTestLuaErrors) + { + var noArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp()")); + ClassicAssert.IsTrue(noArgs.Message.StartsWith("ERR redis.setresp() requires one argument.")); - var tooManyArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp(1, 2)")); - ClassicAssert.IsTrue(tooManyArgs.Message.StartsWith("ERR redis.setresp() requires one argument.")); + var tooManyArgs = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp(1, 2)")); + ClassicAssert.IsTrue(tooManyArgs.Message.StartsWith("ERR redis.setresp() requires one argument.")); - var badArg = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp({123})")); - ClassicAssert.IsTrue(badArg.Message.StartsWith("ERR RESP version must be 2 or 3.")); + var badArg = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp({123})")); + ClassicAssert.IsTrue(badArg.Message.StartsWith("ERR RESP version must be 2 or 3.")); + } var resp2 = db.ScriptEvaluate("redis.setresp(2)"); ClassicAssert.IsTrue(resp2.IsNull); @@ -733,8 +740,11 @@ public void RedisSetResp() var resp3 = db.ScriptEvaluate("redis.setresp(3)"); ClassicAssert.IsTrue(resp3.IsNull); - var badRespVersion = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp(1)")); - ClassicAssert.IsTrue(badRespVersion.Message.StartsWith("ERR RESP version must be 2 or 3.")); + if (CanTestLuaErrors) + { + var badRespVersion = ClassicAssert.Throws(() => db.ScriptEvaluate("redis.setresp(1)")); + ClassicAssert.IsTrue(badRespVersion.Message.StartsWith("ERR RESP version must be 2 or 3.")); + } } [Test] @@ -965,9 +975,10 @@ public void RedisCallErrors() { // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif + if (!CanTestLuaErrors) + { + Assert.Ignore($"Ignoring test when running in .NET9."); + } // Testing that our error replies for redis.call match Redis behavior // @@ -1215,6 +1226,13 @@ public void IntentionalOOM() return; } + // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 + // Once the issue is resolved the #if can be removed permanently. + if (!CanTestLuaErrors) + { + Assert.Ignore($"Ignoring test when running in .NET9."); + } + const string ScriptOOMText = @" local foo = 'abcdefghijklmnopqrstuvwxyz' if @Ctrl == 'OOM' then @@ -1242,11 +1260,6 @@ public void IntentionalOOM() [Test] public void Issue939() { - // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 - // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif // See: https://github.com/microsoft/garnet/issues/939 const string Script = @" @@ -1326,9 +1339,35 @@ public void Issue939() } } - // Finally, check that nil is an illegal argument - var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('GET', nil)")); - ClassicAssert.True(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + if (CanTestLuaErrors) + { + // Finally, check that nil is an illegal argument + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('GET', nil)")); + ClassicAssert.True(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + } + + [Test] + public void Issue1079() + { + // Repeated submission of invalid Lua scripts shouldn't be cached, and thus should produce the same compilation error each time + // + // They also shouldn't break the session for future executions + + const string BrokenScript = "return \"hello lua"; + const string FixedScript = "return \"hello lua\""; + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var brokenExc1 = ClassicAssert.Throws(() => db.Execute("EVAL", BrokenScript, 0)); + ClassicAssert.True(brokenExc1.Message.StartsWith("Compilation error: ")); + + var brokenExc2 = ClassicAssert.Throws(() => db.Execute("EVAL", BrokenScript, 0)); + ClassicAssert.AreEqual(brokenExc1.Message, brokenExc2.Message); + + var success = (string)db.Execute("EVAL", FixedScript, 0); + ClassicAssert.AreEqual("hello lua", success); } [TestCase(2)] @@ -1639,9 +1678,10 @@ public void NoScriptCommandsForbidden() { // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif + if (!CanTestLuaErrors) + { + Assert.Ignore($"Ignoring test when running in .NET9."); + } ClassicAssert.True(RespCommandsInfo.TryGetRespCommandsInfo(out var allCommands, externalOnly: true)); @@ -1660,9 +1700,10 @@ public void IntentionalTimeout() { // This is a temporary fix to address a regression in .NET9, an open issue can be found here - https://github.com/dotnet/runtime/issues/111242 // Once the issue is resolved the #if can be removed permanently. -#if NET9_0_OR_GREATER - Assert.Ignore($"Ignoring test when running in .NET9."); -#endif + if (!CanTestLuaErrors) + { + Assert.Ignore($"Ignoring test when running in .NET9."); + } const string TimeoutScript = @" local count = 0 @@ -1791,11 +1832,933 @@ public void Resp3ToLuaConversions(RedisProtocol connectionProtocol) } } + [Test] + public void Bit() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // tobit + { + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tobit()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("tobit")); + + // Extra arguments are legal, but ignored + + var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tobit({})")); + ClassicAssert.True(badTypeExc.Message.Contains("bad argument") && badTypeExc.Message.Contains("tobit")); + } + + // Rules are suprisingly subtle, so test a bunch of tricky values + (string Value, string Expected)[] expectedValues = [ + ("0", "0"), + ("1", "1"), + ("1.1", "1"), + ("1.5", "2"), + ("1.9", "2"), + ("0.1", "0"), + ("0.5", "0"), + ("0.9", "1"), + ("-1.1", "-1"), + ("-1.5", "-2"), + ("-1.9", "-2"), + ("-0.1", "0"), + ("-0.5", "0"), + ("-0.9", "-1"), + (int.MinValue.ToString(), int.MinValue.ToString()), + (int.MaxValue.ToString(), int.MaxValue.ToString()), + ((1L + int.MaxValue).ToString(), int.MinValue.ToString()), + ((-1L + int.MinValue).ToString(), int.MaxValue.ToString()), + (double.MaxValue.ToString(), "-1"), + (double.MinValue.ToString(), "-1"), + (float.MaxValue.ToString(), "-447893512"), + (float.MinValue.ToString(), "-447893512"), + ]; + foreach (var (value, expected) in expectedValues) + { + var actual = (string)db.ScriptEvaluate($"return bit.tobit({value})"); + ClassicAssert.AreEqual(expected, actual, $"bit.tobit conversion for {value} was incorrect"); + } + } + + // tohex + { + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tohex()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("tohex")); + + // Extra arguments are legal, but ignored + + var badType1Exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tohex({})")); + ClassicAssert.True(badType1Exc.Message.Contains("bad argument") && badType1Exc.Message.Contains("tohex")); + + var badType2Exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.tohex(1, {})")); + ClassicAssert.True(badType2Exc.Message.Contains("bad argument") && badType2Exc.Message.Contains("tohex")); + } + + // Make sure casing is handled correctly + for (var h = 0; h < 16; h++) + { + var lower = (string)db.ScriptEvaluate($"return bit.tohex({h}, 1)"); + var upper = (string)db.ScriptEvaluate($"return bit.tohex({h}, -1)"); + + ClassicAssert.AreEqual(h.ToString("x1"), lower); + ClassicAssert.AreEqual(h.ToString("X1"), upper); + } + + // Run through some weird values + (string Value, int? N, string Expected)[] expectedValues = [ + ("0", null, "00000000"), + ("0", 16, "00000000"), + ("0", -8, "00000000"), + ("123456", null, "0001e240"), + ("123456", 5, "1e240"), + ("123456", -5, "1E240"), + (int.MinValue.ToString(), null, "80000000"), + (int.MaxValue.ToString(), null, "7fffffff"), + ((1L + int.MaxValue).ToString(), null, "80000000"), + ((-1L + int.MinValue).ToString(), null, "7fffffff"), + (double.MaxValue.ToString(), 1, "f"), + (double.MinValue.ToString(), -1, "F"), + (float.MaxValue.ToString(), null, "e54daff8"), + (float.MinValue.ToString(), null, "e54daff8"), + ]; + foreach (var (value, length, expected) in expectedValues) + { + var actual = length != null ? + (string)db.ScriptEvaluate($"return bit.tohex({value},{length})") : + (string)db.ScriptEvaluate($"return bit.tohex({value})"); + + ClassicAssert.AreEqual(expected, actual, $"bit.tohex result for ({value},{length}) was incorrect"); + } + } + + // bswap + { + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.bswap()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("bswap")); + + // Extra arguments are legal, but ignored + + var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.bswap({})")); + ClassicAssert.True(badTypeExc.Message.Contains("bad argument") && badTypeExc.Message.Contains("bswap")); + } + + // Just brute force a bunch of trial values + foreach (var a in new[] { 0, 1, 2, 4 }) + { + foreach (var b in new[] { 8, 16, 32, 128 }) + { + foreach (var c in new[] { 0, 2, 8, 32, }) + { + foreach (var d in new[] { 1, 4, 16, 64 }) + { + var input = a | (b << 8) | (c << 16) | (d << 32); + var expected = BinaryPrimitives.ReverseEndianness(input); + + var actual = (int)db.ScriptEvaluate($"return bit.bswap({input})"); + ClassicAssert.AreEqual(expected, actual); + } + } + } + } + } + + // bnot + { + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.bnot()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("bnot")); + + // Extra arguments are legal, but ignored + + var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return bit.bnot({})")); + ClassicAssert.True(badTypeExc.Message.Contains("bad argument") && badTypeExc.Message.Contains("bnot")); + } + + foreach (var input in new int[] { 0, 1, 2, 4, 8, 32, 64, 128, 256, 0x70F0_F0F0, 0x6BCD_EF01, int.MinValue, int.MaxValue, -1 }) + { + var expected = ~input; + + var actual = (int)db.ScriptEvaluate($"return bit.bnot({input})"); + ClassicAssert.AreEqual(expected, actual); + } + } + + // band, bor, bxor + { + (int Base, string Name, Func Op)[] ops = [ + (0, "bor", static (a, b) => a | b), + (-1, "band", static (a, b) => a & b), + (0, "bxor", static (a, b) => a ^ b), + ]; + + foreach (var op in ops) + { + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains(op.Name)); + + var badType1Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}({{}})")); + ClassicAssert.True(badType1Exc.Message.Contains("bad argument") && badType1Exc.Message.Contains(op.Name)); + + var badType2Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}(1, {{}})")); + ClassicAssert.True(badType2Exc.Message.Contains("bad argument") && badType2Exc.Message.Contains(op.Name)); + + var badType3Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}(1, 2, {{}})")); + ClassicAssert.True(badType3Exc.Message.Contains("bad argument") && badType3Exc.Message.Contains(op.Name)); + } + + // Gin up some unusual values and test them in different combinations + var nextArg = 0x0102_0304; + for (var numArgs = 1; numArgs <= 4; numArgs++) + { + var args = new List(); + while (args.Count < numArgs) + { + args.Add(nextArg); + nextArg *= 2; + nextArg += args.Count; + } + + var expected = op.Base; + foreach (var arg in args) + { + expected = op.Op(expected, arg); + } + + var actual = (int)db.ScriptEvaluate($"return bit.{op.Name}({string.Join(", ", args)})"); + ClassicAssert.AreEqual(expected, actual); + } + } + } + + // lshift, rshift, arshift, rol, ror + { + (string Name, Func Op)[] ops = [ + ("lshift", static (x, n) => x << n), + ("rshift", static (x, n) => (int)((uint)x >> n)), + ("arshift", static (x, n) => x >> n), + ("rol", static (x, n) => (int)BitOperations.RotateLeft((uint)x, n)), + ("ror", static (x, n) => (int)BitOperations.RotateRight((uint)x, n)), + ]; + + foreach (var op in ops) + { + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains(op.Name)); + + var badType1Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}({{}})")); + ClassicAssert.True(badType1Exc.Message.Contains("bad argument") && badType1Exc.Message.Contains(op.Name)); + + var badType2Exc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return bit.{op.Name}(1, {{}})")); + ClassicAssert.True(badType2Exc.Message.Contains("bad argument") && badType2Exc.Message.Contains(op.Name)); + } + + // Extra args are allowed, but ignored + + for (var shift = 0; shift < 16; shift++) + { + const int Value = 0x1234_5678; + + var expected = op.Op(Value, shift); + var actual = (int)db.ScriptEvaluate($"return bit.{op.Name}({Value}, {shift})"); + + ClassicAssert.AreEqual(expected, actual, $"Incorrect value for bit.{op.Name}({Value}, {shift})"); + } + } + } + } + + [Test] + public void CJson() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // Encoding + { + // TODO: Once refactored to avoid longjmp issues, restore on Linux + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.encode()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("encode")); + + var twoArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.encode(1, 2)")); + ClassicAssert.True(twoArgExc.Message.Contains("bad argument") && twoArgExc.Message.Contains("encode")); + + var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.encode((function() end))")); + ClassicAssert.True(badTypeExc.Message.Contains("Cannot serialise")); + } + + var nilResp = (string)db.ScriptEvaluate("return cjson.encode(nil)"); + ClassicAssert.AreEqual("null", nilResp); + + var boolResp = (string)db.ScriptEvaluate("return cjson.encode(true)"); + ClassicAssert.AreEqual("true", boolResp); + + var doubleResp = (string)db.ScriptEvaluate("return cjson.encode(1.23)"); + ClassicAssert.AreEqual("1.23", doubleResp); + + var simpleStrResp = (string)db.ScriptEvaluate("return cjson.encode('hello')"); + ClassicAssert.AreEqual("\"hello\"", simpleStrResp); + + var encodedStrResp = (string)db.ScriptEvaluate("return cjson.encode('\"foo\" \\\\bar\\\\')"); + ClassicAssert.AreEqual("\"\\\"foo\\\" \\\\bar\\\\\"", encodedStrResp); + + var emptyTableResp = (string)db.ScriptEvaluate("return cjson.encode({})"); + ClassicAssert.AreEqual("{}", emptyTableResp); + + var keyedTableResp = (string)db.ScriptEvaluate("return cjson.encode({key=123})"); + ClassicAssert.AreEqual("{\"key\":123}", keyedTableResp); + + var indexedTableResp = (string)db.ScriptEvaluate("return cjson.encode({123, 'foo'})"); + ClassicAssert.AreEqual("[123,\"foo\"]", indexedTableResp); + + var mixedTableResp = (string)db.ScriptEvaluate("local ret = {123}; ret.bar = 'foo'; return cjson.encode(ret)"); + ClassicAssert.AreEqual("{\"1\":123,\"bar\":\"foo\"}", mixedTableResp); + + // Ordering here is undefined, just doing the brute force approach for ease of implementation + var nestedTableResp = (string)db.ScriptEvaluate("return cjson.encode({num=1,str='hello',arr={1,2,3,4},obj={foo='bar'}})"); + string[] nestedTableRespParts = [ + "\"arr\":[1,2,3,4]", + "\"num\":1", + "\"str\":\"hello\"", + "\"obj\":{\"foo\":\"bar\"}", + ]; + var possibleNestedTableResps = new List(); + for (var a = 0; a < nestedTableRespParts.Length; a++) + { + for (var b = 0; b < nestedTableRespParts.Length; b++) + { + for (var c = 0; c < nestedTableRespParts.Length; c++) + { + for (var d = 0; d < nestedTableRespParts.Length; d++) + { + if (a == b || a == c || a == d || b == c || b == d || c == d) + { + continue; + } + + possibleNestedTableResps.Add($"{{{nestedTableRespParts[a]},{nestedTableRespParts[b]},{nestedTableRespParts[c]},{nestedTableRespParts[d]}}}"); + } + } + } + } + ClassicAssert.True(possibleNestedTableResps.Contains(nestedTableResp)); + + var nestArrayResp = (string)db.ScriptEvaluate("return cjson.encode({1,'hello',{1,2,3,4},{foo='bar'}})"); + ClassicAssert.AreEqual("[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]", nestArrayResp); + + var deeplyNestedButLegal = + (string)db.ScriptEvaluate( +@"local nested = 1 +for x = 1, 1000 do + local newNested = {} + newNested[1] = nested; + nested = newNested +end + +return cjson.encode(nested)"); + ClassicAssert.AreEqual(new string('[', 1000) + 1 + new string(']', 1000), deeplyNestedButLegal); + + // TODO: Once refactored to avoid longjmp issues, restore on Linux + if (CanTestLuaErrors) + { + + var deeplyNestedExc = + ClassicAssert.Throws( + () => db.ScriptEvaluate( +@"local nested = 1 +for x = 1, 1001 do + local newNested = {} + newNested[1] = nested; + nested = newNested +end + +return cjson.encode(nested)")); + ClassicAssert.True(deeplyNestedExc.Message.Contains("Cannot serialise, excessive nesting (1001)")); + } + } + + // Decoding + { + // TODO: Once refactored to avoid longjmp issues, restore on Linux + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.decode()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("decode")); + + var twoArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.decode(1, 2)")); + ClassicAssert.True(twoArgExc.Message.Contains("bad argument") && twoArgExc.Message.Contains("decode")); + + var badTypeExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.decode({})")); + ClassicAssert.True(badTypeExc.Message.Contains("bad argument") && badTypeExc.Message.Contains("decode")); + + var badFormatExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cjson.decode('hello world')")); + ClassicAssert.True(badFormatExc.Message.Contains("Expected value but found invalid token")); + } + + var numberDecode = (string)db.ScriptEvaluate("return cjson.decode(123)"); + ClassicAssert.AreEqual("123", numberDecode); + + var boolDecode = (string)db.ScriptEvaluate("return type(cjson.decode('true'))"); + ClassicAssert.AreEqual("boolean", boolDecode); + + var stringDecode = (string)db.ScriptEvaluate("return cjson.decode('\"hello world\"')"); + ClassicAssert.AreEqual("hello world", stringDecode); + + var mapDecode = (string)db.ScriptEvaluate("return cjson.decode('{\"hello\":\"world\"}').hello"); + ClassicAssert.AreEqual("world", mapDecode); + + var arrayDecode = (string)db.ScriptEvaluate("return cjson.decode('[123]')[1]"); + ClassicAssert.AreEqual("123", arrayDecode); + + var complexMapDecodeArr = (string[])db.ScriptEvaluate("return cjson.decode('{\"arr\":[1,2,3,4],\"num\":1,\"str\":\"hello\",\"obj\":{\"foo\":\"bar\"}}').arr"); + ClassicAssert.True(complexMapDecodeArr.SequenceEqual(["1", "2", "3", "4"])); + var complexMapDecodeNum = (string)db.ScriptEvaluate("return cjson.decode('{\"arr\":[1,2,3,4],\"num\":1,\"str\":\"hello\",\"obj\":{\"foo\":\"bar\"}}').num"); + ClassicAssert.AreEqual("1", complexMapDecodeNum); + var complexMapDecodeStr = (string)db.ScriptEvaluate("return cjson.decode('{\"arr\":[1,2,3,4],\"num\":1,\"str\":\"hello\",\"obj\":{\"foo\":\"bar\"}}').str"); + ClassicAssert.AreEqual("hello", complexMapDecodeStr); + var complexMapDecodeObj = (string)db.ScriptEvaluate("return cjson.decode('{\"arr\":[1,2,3,4],\"num\":1,\"str\":\"hello\",\"obj\":{\"foo\":\"bar\"}}').obj.foo"); + ClassicAssert.AreEqual("bar", complexMapDecodeObj); + + var complexArrDecodeNum = (string)db.ScriptEvaluate("return cjson.decode('[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]')[1]"); + ClassicAssert.AreEqual("1", complexArrDecodeNum); + var complexArrDecodeStr = (string)db.ScriptEvaluate("return cjson.decode('[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]')[2]"); + ClassicAssert.AreEqual("hello", complexArrDecodeStr); + var complexArrDecodeArr = (string[])db.ScriptEvaluate("return cjson.decode('[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]')[3]"); + ClassicAssert.True(complexArrDecodeArr.SequenceEqual(["1", "2", "3", "4"])); + var complexArrDecodeObj = (string)db.ScriptEvaluate("return cjson.decode('[1,\"hello\",[1,2,3,4],{\"foo\":\"bar\"}]')[4].foo"); + ClassicAssert.AreEqual("bar", complexArrDecodeObj); + + // Redis cuts us off at 1000 levels of recursion, so check that we're matching that + var deeplyNestedButLegal = (RedisResult[])db.ScriptEvaluate($"return cjson.decode('{new string('[', 1000)}{new string(']', 1000)}')"); + var deeplyNestedButLegalCur = deeplyNestedButLegal; + for (var i = 1; i < 1000; i++) + { + ClassicAssert.AreEqual(1, deeplyNestedButLegalCur.Length); + deeplyNestedButLegalCur = (RedisResult[])deeplyNestedButLegalCur[0]; + } + ClassicAssert.AreEqual(0, deeplyNestedButLegalCur.Length); + + if (CanTestLuaErrors) + { + var deeplyNestedExc = ClassicAssert.Throws(() => db.ScriptEvaluate($"return cjson.decode('{new string('[', 1001)}{new string(']', 1001)}')")); + ClassicAssert.True(deeplyNestedExc.Message.Contains("Found too many nested data structures")); + } + } + } + + [Test] + public void CMsgPackPack() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // TODO: Once refactored to avoid longjmp issues, restore on Linux + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cmsgpack.pack()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("pack")); + } + + // Multiple args are legal, and concat + + var nullResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(nil)"); + ClassicAssert.True(nullResp.SequenceEqual(new byte[] { 0xC0 })); + + var trueResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(true)"); + ClassicAssert.True(trueResp.SequenceEqual(new byte[] { 0xC3 })); + + var falseResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(false)"); + ClassicAssert.True(falseResp.SequenceEqual(new byte[] { 0xC2 })); + + var tinyUInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(0)"); + ClassicAssert.True(tinyUInt1Resp.SequenceEqual(new byte[] { 0x00 })); + + var tinyUInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(127)"); + ClassicAssert.True(tinyUInt2Resp.SequenceEqual(new byte[] { 0x7F })); + + var tinyInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-1)"); + ClassicAssert.True(tinyInt1Resp.SequenceEqual(new byte[] { 0xFF })); + + var tinyInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-32)"); + ClassicAssert.True(tinyInt2Resp.SequenceEqual(new byte[] { 0xE0 })); + + var smallUInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(128)"); + ClassicAssert.True(smallUInt1Resp.SequenceEqual(new byte[] { 0xCC, 0x80 })); + + var smallUInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(255)"); + ClassicAssert.True(smallUInt2Resp.SequenceEqual(new byte[] { 0xCC, 0xFF })); + + var smallInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-33)"); + ClassicAssert.True(smallInt1Resp.SequenceEqual(new byte[] { 0xD0, 0xDF })); + + var smallInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-128)"); + ClassicAssert.True(smallInt2Resp.SequenceEqual(new byte[] { 0xD0, 0x80 })); + + var midUInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(32768)"); + ClassicAssert.True(midUInt1Resp.SequenceEqual(new byte[] { 0xCD, 0x80, 0x00 })); + + var midUInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(65535)"); + ClassicAssert.True(midUInt2Resp.SequenceEqual(new byte[] { 0xCD, 0xFF, 0xFF })); + + var midInt1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-129)"); + ClassicAssert.True(midInt1Resp.SequenceEqual(new byte[] { 0xD1, 0xFF, 0x7F })); + + var midInt2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-32768)"); + ClassicAssert.True(midInt2Resp.SequenceEqual(new byte[] { 0xD1, 0x80, 0x00 })); + + var uint1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(2147483648)"); + ClassicAssert.True(uint1Resp.SequenceEqual(new byte[] { 0xCE, 0x80, 0x00, 0x00, 0x00 })); + + var uint2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(4294967295)"); + ClassicAssert.True(uint2Resp.SequenceEqual(new byte[] { 0xCE, 0xFF, 0xFF, 0xFF, 0xFF })); + + var int1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-32769)"); + ClassicAssert.True(int1Resp.SequenceEqual(new byte[] { 0xD2, 0xFF, 0xFF, 0x7F, 0xFF })); + + var int2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-2147483648)"); + ClassicAssert.True(int2Resp.SequenceEqual(new byte[] { 0xD2, 0x80, 0x00, 0x00, 0x00 })); + + var bigUIntResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(4294967296)"); + ClassicAssert.True(bigUIntResp.SequenceEqual(new byte[] { 0xCF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00 })); + + var bigIntResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(-2147483649)"); + ClassicAssert.True(bigIntResp.SequenceEqual(new byte[] { 0xD3, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, 0xFF, 0xFF, 0xFF })); + + var floatResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(0.1)"); + ClassicAssert.True(floatResp.SequenceEqual(new byte[] { 0xCB, 0x3F, 0xB9, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9A })); + + var tinyString1Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('')"); + ClassicAssert.True(tinyString1Resp.SequenceEqual(new byte[] { 0xA0, })); + + var tinyString2Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('0', 31)}')"); + ClassicAssert.True(tinyString2Resp.SequenceEqual(new byte[] { 0xBF }.Concat(Enumerable.Repeat((byte)'0', 31)))); + + var shortString1Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('a', 32)}')"); + ClassicAssert.True(shortString1Resp.SequenceEqual(new byte[] { 0xD9, 0x20 }.Concat(Enumerable.Repeat((byte)'a', 32)))); + + var shortString2Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('a', 255)}')"); + ClassicAssert.True(shortString2Resp.SequenceEqual(new byte[] { 0xD9, 0xFF }.Concat(Enumerable.Repeat((byte)'a', 255)))); + + var midString1Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('b', 256)}')"); + ClassicAssert.True(midString1Resp.SequenceEqual(new byte[] { 0xDA, 0x01, 0x00 }.Concat(Enumerable.Repeat((byte)'b', 256)))); + + var midString2Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('b', 65535)}')"); + ClassicAssert.True(midString2Resp.SequenceEqual(new byte[] { 0xDA, 0xFF, 0xFF }.Concat(Enumerable.Repeat((byte)'b', 65535)))); + + var longStringResp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack('{new string('c', 65536)}')"); + ClassicAssert.True(longStringResp.SequenceEqual(new byte[] { 0xDB, 0x00, 0x01, 0x00, 0x00 }.Concat(Enumerable.Repeat((byte)'c', 65536)))); + + var emptyTableResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({})"); + ClassicAssert.True(emptyTableResp.SequenceEqual(new byte[] { 0x90 })); + + var smallArray1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({1})"); + ClassicAssert.True(smallArray1Resp.SequenceEqual(new byte[] { 0x91, 0x01 })); + + var smallArray2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})"); + ClassicAssert.True(smallArray2Resp.SequenceEqual(new byte[] { 0x9F, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F })); + + var midArray1Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{{string.Join(", ", Enumerable.Repeat(1, 16))}}})"); + ClassicAssert.True(midArray1Resp.SequenceEqual(new byte[] { 0xDC, 0x00, 0x10 }.Concat(Enumerable.Repeat((byte)0x01, 16)))); + + var midArray2Resp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{{string.Join(", ", Enumerable.Repeat(2, ushort.MaxValue))}}})"); + ClassicAssert.True(midArray2Resp.SequenceEqual(new byte[] { 0xDC, 0xFF, 0xFF }.Concat(Enumerable.Repeat((byte)0x02, ushort.MaxValue)))); + + var bigArrayResp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{{string.Join(", ", Enumerable.Repeat(3, ushort.MaxValue + 1))}}})"); + ClassicAssert.True(bigArrayResp.SequenceEqual(new byte[] { 0xDD, 0x00, 0x01, 0x00, 0x00 }.Concat(Enumerable.Repeat((byte)0x03, ushort.MaxValue + 1)))); + + var smallMap1Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({a=1})"); + ClassicAssert.True(smallMap1Resp.SequenceEqual(new byte[] { 0x81, 0xA1, 0x61, 0x01 })); + + var smallMap2Resp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({a=1,b=2,c=3,d=4,e=5,f=6,g=7,h=8,i=9,j=10,k=11,l=12,m=13,n=14,o=15})"); + ClassicAssert.AreEqual(46, smallMap2Resp.Length); + ClassicAssert.AreEqual(0x8F, smallMap2Resp[0]); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x61, 0x01 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x62, 0x02 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x63, 0x03 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x64, 0x04 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x65, 0x05 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x66, 0x06 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x67, 0x07 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x68, 0x08 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x69, 0x09 })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6A, 0x0A })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6B, 0x0B })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6C, 0x0C })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6D, 0x0D })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6E, 0x0E })); + ClassicAssert.AreNotEqual(-1, smallMap2Resp.AsSpan().IndexOf(new byte[] { 0xA1, 0x6F, 0x0F })); + + var midKeys = string.Join(", ", Enumerable.Range(0, 16).Select(static x => $"m_{(char)('A' + x)}={x}")); + var midMapResp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{ {midKeys} }})"); + ClassicAssert.AreEqual(83, midMapResp.Length); + ClassicAssert.AreEqual(0xDE, midMapResp[0]); + ClassicAssert.AreEqual(0, midMapResp[1]); + ClassicAssert.AreEqual(16, midMapResp[2]); + for (var val = 0; val < 16; val++) + { + var keyPart = (byte)('A' + val); + var expected = new byte[] { 0xA3, (byte)'m', (byte)'_', keyPart, (byte)val }; + ClassicAssert.AreNotEqual(-1, midMapResp.AsSpan().IndexOf(expected)); + } + + var bigKeys = string.Join(", ", Enumerable.Range(0, ushort.MaxValue + 1).Select(static x => $"f_{x:X4}=4")); + var bigMapResp = (byte[])db.ScriptEvaluate($"return cmsgpack.pack({{ {bigKeys} }})"); + ClassicAssert.AreEqual(524_293, bigMapResp.Length); + ClassicAssert.AreEqual(0xDF, bigMapResp[0]); + ClassicAssert.AreEqual(0, bigMapResp[1]); + ClassicAssert.AreEqual(1, bigMapResp[2]); + ClassicAssert.AreEqual(0, bigMapResp[3]); + ClassicAssert.AreEqual(0, bigMapResp[4]); + for (var val = 0; val <= ushort.MaxValue; val++) + { + var keyStr = val.ToString("X4"); + + var expected = new byte[] { 0xA6, (byte)'f', (byte)'_', (byte)keyStr[0], (byte)keyStr[1], (byte)keyStr[2], (byte)keyStr[3], 4 }; + ClassicAssert.AreNotEqual(-1, bigMapResp.AsSpan().IndexOf(expected)); + } + + var complexResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack({4, { key='value', arr={5, 6} }, true, 1.23})"); + ClassicAssert.AreEqual(30, complexResp.Length); + ClassicAssert.AreEqual(0x94, complexResp[0]); + ClassicAssert.AreEqual(0x04, complexResp[1]); + var complexNestedMap = complexResp.AsSpan().Slice(2, 18); + ClassicAssert.AreEqual(0x82, complexNestedMap[0]); + complexNestedMap = complexNestedMap[1..]; + ClassicAssert.AreNotEqual(-1, complexNestedMap.IndexOf(new byte[] { 0b1010_0011, (byte)'k', (byte)'e', (byte)'y', 0b1010_0101, (byte)'v', (byte)'a', (byte)'l', (byte)'u', (byte)'e' })); + ClassicAssert.AreNotEqual(-1, complexNestedMap.IndexOf(new byte[] { 0b1010_0011, (byte)'a', (byte)'r', (byte)'r', 0b1001_0010, 0x05, 0x06 })); + ClassicAssert.AreEqual(0xC3, complexResp[20]); + ClassicAssert.AreEqual(0xCB, complexResp[21]); + ClassicAssert.AreEqual(1.23, BinaryPrimitives.ReadDoubleBigEndian(complexResp.AsSpan()[22..])); + + var concatedResp = (byte[])db.ScriptEvaluate("return cmsgpack.pack(1, 2, 3, 4, {5})"); + ClassicAssert.True(concatedResp.SequenceEqual(new byte[] { 0x01, 0x02, 0x03, 0x04, 0b1001_0001, 0x05 })); + + // Rather than an error, Redis converts a too deeply nested object into a null (very strange) + // + // We match that behavior + + var infiniteNestMap = (byte[])db.ScriptEvaluate("local a = {}; a.ref = a; return cmsgpack.pack(a)"); + ClassicAssert.True(infiniteNestMap.SequenceEqual(new byte[] { 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0x81, 0xA3, 0x72, 0x65, 0x66, 0xC0 })); + + var infiniteNestArr = (byte[])db.ScriptEvaluate("local a = {}; a[1] = a; return cmsgpack.pack(a)"); + ClassicAssert.True(infiniteNestArr.SequenceEqual(new byte[] { 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0x91, 0xC0 })); + } + + [Test] + public void CMsgPackUnpack() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // TODO: Once refactored to avoid longjmp issues, restore on Linux + if (CanTestLuaErrors) + { + var noArgExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cmsgpack.unpack()")); + ClassicAssert.True(noArgExc.Message.Contains("bad argument") && noArgExc.Message.Contains("unpack")); + + // Multiple arguments are allowed, but ignored + + // Table ends before it should + var badDataExc = ClassicAssert.Throws(() => db.ScriptEvaluate("return cmsgpack.unpack('\\220\\0\\96')")); + ClassicAssert.True(badDataExc.Message.Contains("Missing bytes in input")); + } + + var nullResp = (string)db.ScriptEvaluate($"return type(cmsgpack.unpack({ToLuaString(0xC0)}))"); + ClassicAssert.AreEqual("nil", nullResp); + + var trueResp = (string)db.ScriptEvaluate($"return type(cmsgpack.unpack({ToLuaString(0xC3)}))"); + ClassicAssert.AreEqual("boolean", trueResp); + + var falseResp = (string)db.ScriptEvaluate($"return type(cmsgpack.unpack({ToLuaString(0xC2)}))"); + ClassicAssert.AreEqual("boolean", falseResp); + + var tinyUInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x00)})"); + ClassicAssert.AreEqual(0, tinyUInt1Resp); + + var tinyUInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x7F)})"); + ClassicAssert.AreEqual(127, tinyUInt2Resp); + + var tinyInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xFF)})"); + ClassicAssert.AreEqual(-1, tinyInt1Resp); + + var tinyInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xE0)})"); + ClassicAssert.AreEqual(-32, tinyInt2Resp); + + var smallUInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCC, 0x80)})"); + ClassicAssert.AreEqual(128, smallUInt1Resp); + + var smallUInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCC, 0xFF)})"); + ClassicAssert.AreEqual(255, smallUInt2Resp); + + var smallInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD0, 0xDF)})"); + ClassicAssert.AreEqual(-33, smallInt1Resp); + + var smallInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD0, 0x80)})"); + ClassicAssert.AreEqual(-128, smallInt2Resp); + + var midUInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCD, 0x80, 0x00)})"); + ClassicAssert.AreEqual(32768, midUInt1Resp); + + var midUInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCD, 0xFF, 0xFF)})"); + ClassicAssert.AreEqual(65535, midUInt2Resp); + + var midInt1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD1, 0xFF, 0x7F)})"); + ClassicAssert.AreEqual(-129, midInt1Resp); + + var midInt2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD1, 0x80, 0x00)})"); + ClassicAssert.AreEqual(-32768, midInt2Resp); + + var uint1Resp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCE, 0x80, 0x00, 0x00, 0x00)})"); + ClassicAssert.AreEqual(2147483648, uint1Resp); + + var uint2Resp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCE, 0xFF, 0xFF, 0xFF, 0xFF)})"); + ClassicAssert.AreEqual(4294967295L, uint2Resp); + + var int1Resp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD2, 0xFF, 0xFF, 0x7F, 0xFF)})"); + ClassicAssert.AreEqual(-32769, int1Resp); + + var int2Resp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD2, 0x80, 0x00, 0x00, 0x00)})"); + ClassicAssert.AreEqual(-2147483648, int2Resp); + + var bigUIntResp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xCF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00)})"); + ClassicAssert.AreEqual(4294967296L, bigUIntResp); + + var bigIntResp = (long)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xD3, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F, 0xFF, 0xFF, 0xFF)})"); + ClassicAssert.AreEqual(-2147483649L, bigIntResp); + + var floatResp = (string)db.ScriptEvaluate($"return tostring(cmsgpack.unpack({ToLuaString(0xCB, 0x3F, 0xB9, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9A)}))"); + ClassicAssert.AreEqual("0.1", floatResp); + + var tinyString1Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0xA0)})"); + ClassicAssert.AreEqual("", tinyString1Resp); + + var tinyString2Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xBF }.Concat(Enumerable.Repeat((byte)'0', 31)).ToArray())})"); + ClassicAssert.AreEqual(new string('0', 31), tinyString2Resp); + + var shortString1Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xD9, 0x20 }.Concat(Enumerable.Repeat((byte)'a', 32)).ToArray())})"); + ClassicAssert.AreEqual(new string('a', 32), shortString1Resp); + + var shortString2Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xD9, 0xFF }.Concat(Enumerable.Repeat((byte)'a', 255)).ToArray())})"); + ClassicAssert.AreEqual(new string('a', 255), shortString2Resp); + + var midString1Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDA, 0x01, 0x00 }.Concat(Enumerable.Repeat((byte)'b', 256)).ToArray())})"); + ClassicAssert.AreEqual(new string('b', 256), midString1Resp); + + var midString2Resp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDA, 0xFF, 0xFF }.Concat(Enumerable.Repeat((byte)'b', 65535)).ToArray())})"); + ClassicAssert.AreEqual(new string('b', 65535), midString2Resp); + + var longStringResp = (string)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDB, 0x00, 0x01, 0x00, 0x00 }.Concat(Enumerable.Repeat((byte)'c', 65536)).ToArray())})"); + ClassicAssert.AreEqual(new string('c', 65536), longStringResp); + + var emptyTableResp = (int)db.ScriptEvaluate($"return #cmsgpack.unpack({ToLuaString(0x90)})"); + ClassicAssert.AreEqual(0, emptyTableResp); + + var smallArray1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x91, 0x01)})[1]"); + ClassicAssert.AreEqual(1, smallArray1Resp); + + var smallArray2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x9F, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F)})[14]"); + ClassicAssert.AreEqual(14, smallArray2Resp); + + var midArray1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDC, 0x00, 0x10 }.Concat(Enumerable.Repeat((byte)0x01, 16)).ToArray())})[16]"); + ClassicAssert.AreEqual(1, midArray1Resp); + + var midArray2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDC, 0xFF, 0xFF }.Concat(Enumerable.Repeat((byte)0x02, ushort.MaxValue)).ToArray())})[1000]"); + ClassicAssert.AreEqual(2, midArray2Resp); + + var bigArrayResp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(new byte[] { 0xDD, 0x00, 0x01, 0x00, 0x00 }.Concat(Enumerable.Repeat((byte)0x03, ushort.MaxValue + 1)).ToArray())})[20000]"); + ClassicAssert.AreEqual(3, bigArrayResp); + + var smallMap1Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x81, 0xA1, 0x61, 0x01)}).a"); + ClassicAssert.AreEqual(1, smallMap1Resp); + + var smallMap2Resp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x8F, 0xA1, 0x61, 0x01, 0xA1, 0x62, 0x02, 0xA1, 0x63, 0x03, 0xA1, 0x64, 0x04, 0xA1, 0x65, 0x05, 0xA1, 0x66, 0x06, 0xA1, 0x67, 0x07, 0xA1, 0x68, 0x08, 0xA1, 0x69, 0x09, 0xA1, 0x6A, 0x0A, 0xA1, 0x6B, 0x0B, 0xA1, 0x6C, 0x0C, 0xA1, 0x6D, 0x0D, 0xA1, 0x6E, 0x0E, 0xA1, 0x6F, 0x0F)}).o"); + ClassicAssert.AreEqual(15, smallMap2Resp); + + var midMapBytes = new List { 0xDE, 0x00, 0x10 }; + for (var val = 0; val < 16; val++) + { + var keyPart = (byte)('A' + val); + midMapBytes.AddRange([0xA3, (byte)'m', (byte)'_', keyPart, (byte)val]); + } + var midMapResp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(midMapBytes.ToArray())}).m_F"); + ClassicAssert.AreEqual(5, midMapResp); + + var bigMapBytes = new List { 0xDF, 0x00, 0x01, 0x00, 0x00 }; + for (var val = 0; val <= ushort.MaxValue; val++) + { + var keyStr = val.ToString("X4"); + + bigMapBytes.AddRange([0xA6, (byte)'f', (byte)'_', (byte)keyStr[0], (byte)keyStr[1], (byte)keyStr[2], (byte)keyStr[3], 4]); + } + var bigMapRes = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(bigMapBytes.ToArray())}).f_0123"); + ClassicAssert.AreEqual(4, bigMapRes); + + var nestedResp = (int)db.ScriptEvaluate($"return cmsgpack.unpack({ToLuaString(0x94, 0x04, 0x82, 0b1010_0011, (byte)'k', (byte)'e', (byte)'y', 0b1010_0101, (byte)'v', (byte)'a', (byte)'l', (byte)'u', (byte)'e', 0b1010_0011, (byte)'a', (byte)'r', (byte)'r', 0b1001_0010, 0x05, 0x06, 0xC3, 0xCB, 0x3F, 0xF3, 0xAE, 0x14, 0x7A, 0xE1, 0x47, 0xAE)})[2].arr[2]"); + ClassicAssert.AreEqual(6, nestedResp); + + var multiResp = (int[])db.ScriptEvaluate($"local a, b, c, d, e = cmsgpack.unpack({ToLuaString(0x01, 0x02, 0x03, 0x04, 0b1001_0001, 0x05)}); return {{e[1], d, c, b, a}}"); + ClassicAssert.True(multiResp.SequenceEqual([5, 4, 3, 2, 1])); + + // Helper for encoding a byte array into something that can be passed to Lua + static string ToLuaString(params byte[] data) + { + var ret = new StringBuilder(); + _ = ret.Append('\''); + + foreach (var b in data) + { + _ = ret.Append($"\\{b}"); + } + + _ = ret.Append('\''); + + return ret.ToString(); + } + } + + [Test] + public void Struct() + { + // Redis struct.pack/unpack/size is a subset of Lua's string.pack/unpack/packsize; so just testing for basic functionality + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var packRes = (byte[])db.ScriptEvaluate("return struct.pack('HH', 1, 2)"); + ClassicAssert.True(packRes.SequenceEqual(new byte[] { 0x01, 0x00, 0x02, 0x00 })); + + var unpackRes = (int[])db.ScriptEvaluate("return { struct.unpack('HH', '\\01\\00\\02\\00') }"); + ClassicAssert.True(unpackRes.SequenceEqual([1, 2, 5])); + + var sizeRes = (int)db.ScriptEvaluate("return struct.size('HH')"); + ClassicAssert.AreEqual(4, sizeRes); + } + + [Test] + public void MathFunctions() + { + // There are a number of "weird" math functions Redis supports that don't have direct .NET equivalents + // + // Doing some basic testing on these implementations + // + // We don't actually guarantee bit-for-bit or char-for-char equivalence, but "close" is worth attempting + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // Frexp + { + var a = (string)db.ScriptEvaluate("local a,b = math.frexp(0); return tostring(a)..'|'..tostring(b)"); + ClassicAssert.AreEqual("0|0", a); + + var b = (string)db.ScriptEvaluate("local a,b = math.frexp(1); return tostring(a)..'|'..tostring(b)"); + ClassicAssert.AreEqual("0.5|1", b); + + var c = (string)db.ScriptEvaluate($"local a,b = math.frexp({double.MaxValue}); return tostring(a)..'|'..tostring(b)"); + ClassicAssert.AreEqual("1|1024", c); + + var d = (string)db.ScriptEvaluate($"local a,b = math.frexp({double.MinValue}); return tostring(a)..'|'..tostring(b)"); + ClassicAssert.AreEqual("-1|1024", d); + + var e = (string)db.ScriptEvaluate("local a,b = math.frexp(1234.56); return tostring(a)..'|'..tostring(b)"); + ClassicAssert.AreEqual("0.6028125|11", e); + + var f = (string)db.ScriptEvaluate("local a,b = math.frexp(-7890.12); return tostring(a)..'|'..tostring(b)"); + ClassicAssert.AreEqual("-0.9631494140625|13", f); + } + + // Ldexp + { + var a = (string)db.ScriptEvaluate("return tostring(math.ldexp(0, 0))"); + ClassicAssert.AreEqual("0", a); + + var b = (string)db.ScriptEvaluate("return tostring(math.ldexp(1, 1))"); + ClassicAssert.AreEqual("2", b); + + var c = (string)db.ScriptEvaluate("return tostring(math.ldexp(0, 1))"); + ClassicAssert.AreEqual("0", c); + + var d = (string)db.ScriptEvaluate("return tostring(math.ldexp(1, 0))"); + ClassicAssert.AreEqual("1", d); + + var e = (string)db.ScriptEvaluate($"return tostring(math.ldexp({double.MaxValue}, 0))"); + ClassicAssert.AreEqual("1.7976931348623e+308", e); + + var f = (string)db.ScriptEvaluate($"return tostring(math.ldexp({double.MaxValue}, 1))"); + ClassicAssert.AreEqual("inf", f); + + var g = (string)db.ScriptEvaluate($"return tostring(math.ldexp({double.MinValue}, 0))"); + ClassicAssert.AreEqual("-1.7976931348623e+308", g); + + var h = (string)db.ScriptEvaluate($"return tostring(math.ldexp({double.MinValue}, 1))"); + ClassicAssert.AreEqual("-inf", h); + + var i = (string)db.ScriptEvaluate($"return tostring(math.ldexp(1.234, 1.234))"); + ClassicAssert.AreEqual("2.468", i); + + var j = (string)db.ScriptEvaluate($"return tostring(math.ldexp(-5.6798, 9.0123))"); + ClassicAssert.AreEqual("-2908.0576", j); + } + } + + [Test] + public void Maxn() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var empty = (int)db.ScriptEvaluate("return table.maxn({})"); + ClassicAssert.AreEqual(0, empty); + + var single = (int)db.ScriptEvaluate("return table.maxn({4})"); + ClassicAssert.AreEqual(1, single); + + var multiple = (int)db.ScriptEvaluate("return table.maxn({-1, 1, 2, 5})"); + ClassicAssert.AreEqual(4, multiple); + + var keyed = (int)db.ScriptEvaluate("return table.maxn({foo='bar',fizz='buzz',hello='world'})"); + ClassicAssert.AreEqual(0, keyed); + + var mixed = (int)db.ScriptEvaluate("return table.maxn({-1, 1, foo='bar', 3})"); + ClassicAssert.AreEqual(3, mixed); + + var allNegative = (int)db.ScriptEvaluate("local x = {}; x[-1] = 1; x[-2]=2; return table.maxn(x)"); + ClassicAssert.AreEqual(0, allNegative); + } + + [Test] + public void LoadString() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var basic = (int)db.ScriptEvaluate("local x = loadstring('return 123'); return x()"); + ClassicAssert.AreEqual(123, basic); + + // TODO: Once refactored to avoid longjmp issues, restore on Linux + if (CanTestLuaErrors) + { + var rejectNullExc = ClassicAssert.Throws(() => db.ScriptEvaluate("local x = loadstring('return \"\\0\"'); return x()")); + ClassicAssert.True(rejectNullExc.Message.Contains("bad argument to loadstring, interior null byte")); + } + } + [Test] [Ignore("Long running, disabled by default")] public void StressTimeouts() { - // Psuedo-repeatably random + // Pseudo-repeatably random const int SEED = 2025_01_30_00; const int DurationMS = 60 * 60 * 1_000; diff --git a/test/Garnet.test/TestUtils.cs b/test/Garnet.test/TestUtils.cs index d2a392cae98..09c004723a2 100644 --- a/test/Garnet.test/TestUtils.cs +++ b/test/Garnet.test/TestUtils.cs @@ -228,6 +228,7 @@ public static GarnetServer CreateGarnetServer( string luaMemoryLimit = "", TimeSpan? luaTimeout = null, LuaLoggingMode luaLoggingMode = LuaLoggingMode.Enable, + IEnumerable luaAllowedFunctions = null, string unixSocketPath = null, UnixFileMode unixSocketPermission = default, int slowLogThreshold = 0, @@ -312,7 +313,7 @@ public static GarnetServer CreateGarnetServer( EnableReadCache = enableReadCache, EnableObjectStoreReadCache = enableObjectStoreReadCache, ReplicationOffsetMaxLag = asyncReplay ? -1 : 0, - LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, luaTimeout ?? Timeout.InfiniteTimeSpan, luaLoggingMode, logger) : null, + LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, luaTimeout ?? Timeout.InfiniteTimeSpan, luaLoggingMode, luaAllowedFunctions ?? [], logger) : null, UnixSocketPath = unixSocketPath, UnixSocketPermission = unixSocketPermission, SlowLogThreshold = slowLogThreshold @@ -528,6 +529,7 @@ public static GarnetServerOptions GetGarnetServerOptions( string luaMemoryLimit = "", TimeSpan? luaTimeout = null, LuaLoggingMode luaLoggingMode = LuaLoggingMode.Enable, + IEnumerable luaAllowedFunctions = null, string unixSocketPath = null) { if (useAzureStorage) @@ -631,7 +633,7 @@ public static GarnetServerOptions GetGarnetServerOptions( ClusterPassword = authPassword, EnableLua = enableLua, ReplicationOffsetMaxLag = asyncReplay ? -1 : 0, - LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, luaTimeout ?? Timeout.InfiniteTimeSpan, luaLoggingMode, logger) : null, + LuaOptions = enableLua ? new LuaOptions(luaMemoryMode, luaMemoryLimit, luaTimeout ?? Timeout.InfiniteTimeSpan, luaLoggingMode, luaAllowedFunctions ?? [], logger) : null, UnixSocketPath = unixSocketPath, ReplicaDisklessSync = enableDisklessSync, ReplicaDisklessSyncDelay = 1