diff --git a/Client/core/CVersionUpdater.Util.hpp b/Client/core/CVersionUpdater.Util.hpp index d6f7677c9ab..1267386cd20 100644 --- a/Client/core/CVersionUpdater.Util.hpp +++ b/Client/core/CVersionUpdater.Util.hpp @@ -601,8 +601,8 @@ namespace m_ArgMap.SetFromString(strSettings); // If build is 30 days old, default no report logging m_ArgMap.Get("filter2", strFilter, GetBuildAge() < 30 ? "+all" : "-all"); - m_ArgMap.Get("min", iMinSize, DEFAULT_MIN_SIZE); - m_ArgMap.Get("max", iMaxSize, DEFAULT_MAX_SIZE); + m_ArgMap.Get(std::string("min"), iMinSize, DEFAULT_MIN_SIZE); + m_ArgMap.Get(std::string("max"), iMaxSize, DEFAULT_MAX_SIZE); SaveReportSettings(); } diff --git a/Client/loader/Utils.cpp b/Client/loader/Utils.cpp index 9d3ead9eba9..2ac1f853eda 100644 --- a/Client/loader/Utils.cpp +++ b/Client/loader/Utils.cpp @@ -1579,7 +1579,7 @@ void CheckAndShowModelProblems() CArgMap argMap; argMap.SetFromString(GetApplicationSetting("diagnostics", "gta-model-fail")); argMap.Get("reason", strReason); - argMap.Get("id", iModelId); + argMap.Get(std::string("id"), iModelId); SetApplicationSetting("diagnostics", "gta-model-fail", ""); if (iModelId) @@ -1604,9 +1604,9 @@ void CheckAndShowUpgradeProblems() int iModelId = 0, iUpgradeId, iFrame; CArgMap argMap; argMap.SetFromString(GetApplicationSetting("diagnostics", "gta-upgrade-fail")); - argMap.Get("vehid", iModelId); - argMap.Get("upgid", iUpgradeId); - argMap.Get("frame", iFrame); + argMap.Get(std::string("vehid"), iModelId); + argMap.Get(std::string("upgid"), iUpgradeId); + argMap.Get(std::string("frame"), iFrame); SetApplicationSetting("diagnostics", "gta-upgrade-fail", ""); if (iModelId) diff --git a/Server/dbconmy/CDatabaseConnectionMySql.cpp b/Server/dbconmy/CDatabaseConnectionMySql.cpp index a1543dc38d1..4fd8ec94f80 100644 --- a/Server/dbconmy/CDatabaseConnectionMySql.cpp +++ b/Server/dbconmy/CDatabaseConnectionMySql.cpp @@ -85,11 +85,11 @@ CDatabaseConnectionMySql::CDatabaseConnectionMySql(CDatabaseType* pManager, cons // Parse options string CArgMap optionsMap("=", ";"); optionsMap.SetFromString(strOptions); - optionsMap.Get("autoreconnect", m_bAutomaticReconnect, 1); - optionsMap.Get("batch", m_bAutomaticTransactionsEnabled, 1); - optionsMap.Get("multi_statements", m_bMultipleStatements, 0); - optionsMap.Get("use_ssl", m_bUseSSL, 0); - optionsMap.Get("get_server_public_key", getServerPublicKey, 1); + optionsMap.Get("autoreconnect"s, m_bAutomaticReconnect, 1); + optionsMap.Get("batch"s, m_bAutomaticTransactionsEnabled, 1); + optionsMap.Get("multi_statements"s, m_bMultipleStatements, 0); + optionsMap.Get("use_ssl"s, m_bUseSSL, 0); + optionsMap.Get("get_server_public_key"s, getServerPublicKey, 1); SString strHostname; SString strDatabaseName; @@ -103,7 +103,7 @@ CDatabaseConnectionMySql::CDatabaseConnectionMySql(CDatabaseType* pManager, cons argMap.SetFromString(strHost); argMap.Get("dbname", strDatabaseName, ""); argMap.Get("host", strHostname, "localhost"); - argMap.Get("port", iPort, 0); + argMap.Get("port"s, iPort, 0); argMap.Get("unix_socket", strUnixSocket, ""); argMap.Get("charset", strCharset, ""); diff --git a/Server/mods/deathmatch/logic/lua/CLuaFunctionParseHelpers.cpp b/Server/mods/deathmatch/logic/lua/CLuaFunctionParseHelpers.cpp index ef25750c053..a4a57f98f13 100644 --- a/Server/mods/deathmatch/logic/lua/CLuaFunctionParseHelpers.cpp +++ b/Server/mods/deathmatch/logic/lua/CLuaFunctionParseHelpers.cpp @@ -746,6 +746,14 @@ void CheckCanModifyOtherResource(CScriptArgReader& argStream, CResource* pThisRe "Access denied"); } +std::pair CheckCanModifyOtherResource(CResource* thisResource, CResource* otherResource) +{ + if (GetResourceModifyScope(thisResource, otherResource) != eResourceModifyScope::NONE) + return {true, ""}; + + return {false, SString("ModifyOtherObjects in ACL denied resource %s to access %s", thisResource->GetName().c_str(), otherResource->GetName().c_str())}; +} + // // Set error if pThisResource does not have permission to modify every resource in resourceList // diff --git a/Server/mods/deathmatch/logic/lua/CLuaFunctionParseHelpers.h b/Server/mods/deathmatch/logic/lua/CLuaFunctionParseHelpers.h index 52a4536012e..b399f84fe36 100644 --- a/Server/mods/deathmatch/logic/lua/CLuaFunctionParseHelpers.h +++ b/Server/mods/deathmatch/logic/lua/CLuaFunctionParseHelpers.h @@ -402,8 +402,10 @@ enum class eResourceModifyScope }; eResourceModifyScope GetResourceModifyScope(CResource* pThisResource, CResource* pOtherResource); -void CheckCanModifyOtherResource(CScriptArgReader& argStream, CResource* pThisResource, CResource* pOtherResource); -void CheckCanModifyOtherResources(CScriptArgReader& argStream, CResource* pThisResource, std::initializer_list resourceList); +void CheckCanModifyOtherResource(CScriptArgReader& argStream, CResource* pThisResource, CResource* pOtherResource); +std::pair CheckCanModifyOtherResource(CResource* thisResource, CResource* otherResource); + +void CheckCanModifyOtherResources(CScriptArgReader& argStream, CResource* pThisResource, std::initializer_list resourceList); void CheckCanAccessOtherResourceFile(CScriptArgReader& argStream, CResource* pThisResource, CResource* pOtherResource, const SString& strAbsPath, bool* pbReadOnly = nullptr); diff --git a/Server/mods/deathmatch/logic/luadefs/CLuaDatabaseDefs.cpp b/Server/mods/deathmatch/logic/luadefs/CLuaDatabaseDefs.cpp index bb5a1d1d160..ec42ba3fd47 100644 --- a/Server/mods/deathmatch/logic/luadefs/CLuaDatabaseDefs.cpp +++ b/Server/mods/deathmatch/logic/luadefs/CLuaDatabaseDefs.cpp @@ -16,11 +16,12 @@ #include "CPerfStatManager.h" #include "lua/CLuaCallback.h" #include "Utils.h" +#include void CLuaDatabaseDefs::LoadFunctions() { constexpr static const std::pair functions[]{ - {"dbConnect", DbConnect}, + {"dbConnect", ArgumentParser}, {"dbExec", DbExec}, {"dbQuery", DbQuery}, {"dbFree", DbFree}, @@ -225,106 +226,122 @@ void CLuaDatabaseDefs::DbFreeCallback(CDbJobData* pJobData, void* pContext) } } -int CLuaDatabaseDefs::DbConnect(lua_State* luaVM) +std::variant CLuaDatabaseDefs::DbConnect(lua_State* luaVM, std::string type, std::string host, + std::optional username, std::optional password, + std::optional options, std::optional callback) { - // element dbConnect ( string type, string host, string username, string password, string options ) - SString strType; - SString strHost; - SString strUsername; - SString strPassword; - SString strOptions; + if (!username.has_value()) + username = ""; + if (!password.has_value()) + password = ""; + if (!options.has_value()) + options = ""; - CScriptArgReader argStream(luaVM); - argStream.ReadString(strType); - argStream.ReadString(strHost); - argStream.ReadString(strUsername, ""); - argStream.ReadString(strPassword, ""); - argStream.ReadString(strOptions, ""); + CResource* resource = &lua_getownerresource(luaVM); - if (!argStream.HasErrors()) + if (type == "sqlite" && !host.empty()) { - CResource* pThisResource = m_pLuaManager->GetVirtualMachineResource(luaVM); - if (pThisResource) + // If path starts with :/ then use global database directory + if (!host.rfind(":/", 0)) { - // If type is sqlite, and has a host, try to resolve path - if (strType == "sqlite" && !strHost.empty()) + host = host.substr(1); + if (!IsValidFilePath(host.c_str())) { - // If path starts with :/ then use global database directory - if (strHost.BeginsWith(":/")) - { - strHost = strHost.SubStr(1); - if (!IsValidFilePath(strHost)) - { - argStream.SetCustomError(SString("host path %s not valid", *strHost)); - } - else - { - strHost = PathJoin(g_pGame->GetConfig()->GetGlobalDatabasesPath(), strHost); - } - } - else - { - std::string strAbsPath; - - // Parse path - CResource* pPathResource = pThisResource; - if (CResourceManager::ParseResourcePathInput(strHost, pPathResource, &strAbsPath)) - { - strHost = strAbsPath; - CheckCanModifyOtherResource(argStream, pThisResource, pPathResource); - } - else - { - argStream.SetCustomError(SString("host path %s not found", *strHost)); - } - } + SString err("host path %s not valid", host.c_str()); + throw LuaFunctionError(err.c_str(), false); } - if (!argStream.HasErrors()) - { - if (strType == "mysql") - pThisResource->SetUsingDbConnectMysql(true); - - // Add logging options - bool bLoggingEnabled; - SString strLogTag; - SString strQueueName; - // Set default values if required - GetOption(strOptions, "log", bLoggingEnabled, 1); - GetOption(strOptions, "tag", strLogTag, "script"); - GetOption(strOptions, "queue", strQueueName, (strType == "mysql") ? strHost : DB_SQLITE_QUEUE_NAME_DEFAULT); - SetOption(strOptions, "log", bLoggingEnabled); - SetOption(strOptions, "tag", strLogTag); - SetOption(strOptions, "queue", strQueueName); - // Do connect - SConnectionHandle connection = g_pGame->GetDatabaseManager()->Connect(strType, strHost, strUsername, strPassword, strOptions); - if (connection == INVALID_DB_HANDLE) - { - argStream.SetCustomError(g_pGame->GetDatabaseManager()->GetLastErrorMessage()); - } - else - { - // Use an element to wrap the connection for auto disconnected when the resource stops - // Don't set a parent because the element should not be accessible from other resources - CDatabaseConnectionElement* pElement = new CDatabaseConnectionElement(NULL, connection); - CElementGroup* pGroup = pThisResource->GetElementGroup(); - if (pGroup) - { - pGroup->Add(pElement); - } + host = PathJoin(g_pGame->GetConfig()->GetGlobalDatabasesPath(), host); + } + else + { + std::string absPath; - lua_pushelement(luaVM, pElement); - return 1; - } + // Parse path + CResource* pathResource = resource; + if (!CResourceManager::ParseResourcePathInput(host, pathResource, &absPath)) + { + SString err("host path %s not found", host.c_str()); + throw LuaFunctionError(err.c_str(), false); } + + host = absPath; + auto [status, err] = CheckCanModifyOtherResource(resource, pathResource); + if (!status) + throw LuaFunctionError(err.c_str(), false); } } + + if (type == "mysql") + resource->SetUsingDbConnectMysql(true); + + // Add logging options + bool loggingEnabled; + std::string logTag; + std::string queueName; + // Set default values if required + GetOption(*options, "log", loggingEnabled, 1); + GetOption(*options, "tag", logTag, "script"); + GetOption(*options, "queue", queueName, (type == "mysql") ? host.c_str() : DB_SQLITE_QUEUE_NAME_DEFAULT); + SetOption(*options, "log", loggingEnabled); + SetOption(*options, "tag", logTag); + SetOption(*options, "queue", queueName); + + const auto CreateConnection = [](CResource* resource, const SConnectionHandle& handle) + -> CDatabaseConnectionElement* + { + // Use an element to wrap the connection for auto disconnected when the resource stops + // Don't set a parent because the element should not be accessible from other resources + auto* element = new CDatabaseConnectionElement(nullptr, handle); + CElementGroup* group = resource->GetElementGroup(); + if (group) + group->Add(element); + return element; + }; - if (argStream.HasErrors()) - m_pScriptDebugging->LogCustom(luaVM, argStream.GetFullErrorMessage()); + if (callback.has_value()) + { + const auto taskFunc = [type = type, host, username = username.value(), password = password.value(), options = options.value()] + { + return g_pGame->GetDatabaseManager()->Connect( + type, + host, + username, + password, + options + ); + }; + const auto readyFunc = [CreateConnection = CreateConnection, resource = resource, luaFunctionRef = callback.value()](const SConnectionHandle& handle) + { + CLuaMain* luaMain = m_pLuaManager->GetVirtualMachine(luaFunctionRef.GetLuaVM()); + if (!luaMain) + return; - lua_pushboolean(luaVM, false); - return 1; + CLuaArguments arguments; + + if (handle == INVALID_DB_HANDLE) + { + auto lastError = g_pGame->GetDatabaseManager()->GetLastErrorMessage(); + m_pScriptDebugging->LogCustom(luaMain->GetVM(), lastError.c_str()); + arguments.PushBoolean(false); + } + + arguments.PushElement(CreateConnection(resource, handle)); + arguments.Call(luaMain, luaFunctionRef); + }; + + CLuaShared::GetAsyncTaskScheduler()->PushTask(taskFunc, readyFunc); + return true; + } + + // Do connect + SConnectionHandle connection = g_pGame->GetDatabaseManager()->Connect( + type, host, *username, *password, *options + ); + if (connection == INVALID_DB_HANDLE) + throw LuaFunctionError(g_pGame->GetDatabaseManager()->GetLastErrorMessage().c_str(), false); + + return CreateConnection(resource, connection); } // This method has an OOP counterpart - don't forget to update the OOP code too! diff --git a/Server/mods/deathmatch/logic/luadefs/CLuaDatabaseDefs.h b/Server/mods/deathmatch/logic/luadefs/CLuaDatabaseDefs.h index 015ed4f76b7..e9f690fba81 100644 --- a/Server/mods/deathmatch/logic/luadefs/CLuaDatabaseDefs.h +++ b/Server/mods/deathmatch/logic/luadefs/CLuaDatabaseDefs.h @@ -22,7 +22,9 @@ class CLuaDatabaseDefs : public CLuaDefs static void DbFreeCallback(CDbJobData* pJobData, void* pContext); static void DbExecCallback(CDbJobData* pJobData, void* pContext); - LUA_DECLARE(DbConnect); + static std::variant DbConnect(lua_State* luaVM, std::string type, std::string host, std::optional username, + std::optional password, std::optional options, + std::optional callback); LUA_DECLARE(DbExec); LUA_DECLARE_OOP(DbQuery); LUA_DECLARE(DbFree); @@ -36,4 +38,4 @@ class CLuaDatabaseDefs : public CLuaDefs LUA_DECLARE(ExecuteSQLSelect); LUA_DECLARE(ExecuteSQLUpdate); LUA_DECLARE(ExecuteSQLQuery); -}; \ No newline at end of file +}; diff --git a/Server/mods/deathmatch/utils/CMasterServerAnnouncer.h b/Server/mods/deathmatch/utils/CMasterServerAnnouncer.h index d846d391097..eff94733d30 100644 --- a/Server/mods/deathmatch/utils/CMasterServerAnnouncer.h +++ b/Server/mods/deathmatch/utils/CMasterServerAnnouncer.h @@ -133,7 +133,7 @@ class CMasterServer : public CRefCountable { CArgMap argMap; argMap.SetFromString(result.pData); - SString strOkMessage = argMap.Get("ok_message"); + SString strOkMessage = argMap.Get("ok_message"s); // Log successful initial announcement if (result.iErrorCode == 200) diff --git a/Shared/sdk/SharedUtil.Misc.h b/Shared/sdk/SharedUtil.Misc.h index d64f155a7cd..363cf0cc196 100644 --- a/Shared/sdk/SharedUtil.Misc.h +++ b/Shared/sdk/SharedUtil.Misc.h @@ -728,6 +728,8 @@ namespace SharedUtil CArgMap(const SString& strArgSep = "=", const SString& strPartsSep = "&", const SString& strExtraDisallowedChars = ""); void SetEscapeCharacter(char cEscapeCharacter); void Merge(const CArgMap& other, bool bAllowMultiValues = false); + void SetFromString(const char* line, bool allowMultiValues = false); + void SetFromString(const std::string& line, bool allowMultiValues = false); void SetFromString(const SString& strLine, bool bAllowMultiValues = false); void MergeFromString(const SString& strLine, bool bAllowMultiValues = false); SString ToString() const; @@ -740,6 +742,17 @@ namespace SharedUtil void Insert(const SString& strInCmd, int iValue); // Insert a key int value void Insert(const SString& strInCmd, const SString& strInValue); // Insert a key string value bool Contains(const SString& strInCmd) const; // Test if key exists + + std::string Get(const char*& inCmd) const noexcept; + bool Get(const char*& inCmd, std::string& out, const char* defaultValue = "") const noexcept; + bool Get(const char*& inCmd, std::vector& outList) const; + bool Get(const char*& inCmd, int& value, int defaultValue = 0) const noexcept; + + std::string Get(const std::string& inCmd) const noexcept; + bool Get(const std::string& inCmd, std::string& out, const char* defaultValue = "") const noexcept; + bool Get(const std::string& inCmd, std::vector& outList) const; + bool Get(const std::string& inCmd, int& value, int defaultValue = 0) const noexcept; + bool Get(const SString& strInCmd, SString& strOut, const char* szDefault = "") const; // First result as string SString Get(const SString& strInCmd) const; // First result as string bool Get(const SString& strInCmd, std::vector& outList) const; // All results as strings @@ -757,6 +770,15 @@ namespace SharedUtil strText = temp.ToString(); } + template + void SetOption(std::string& text, const std::string& key, const U& value) + { + T temp; + temp.SetFromString(text); + temp.Set(key, value); + text = temp.ToString(); + } + template void GetOption(const SString& strText, const SString& strKey, SString& strOutValue, const char* szDefault = "") { @@ -765,6 +787,14 @@ namespace SharedUtil temp.Get(strKey, strOutValue, szDefault); } + template + void GetOption(const std::string& text, const std::string& key, std::string& outValue, const char* defaultStr = "") + { + T temp; + temp.SetFromString(text); + temp.Get(key, outValue, defaultStr); + } + template void GetOption(const SString& strText, const SString& strKey, int& iOutValue, int iDefault = 0) { @@ -773,6 +803,14 @@ namespace SharedUtil temp.Get(strKey, iOutValue, iDefault); } + template + void GetOption(const std::string& text, const std::string& key, int& outValue, int defaultStr = 0) + { + T temp; + temp.SetFromString(text); + temp.Get(key, outValue, defaultStr); + } + // Coerce to a bool from an int template void GetOption(const SString& strText, const SString& strKey, bool& bOutValue, int iDefault = 0) @@ -784,6 +822,16 @@ namespace SharedUtil bOutValue = (iOutValue != 0); } + template + void GetOption(const std::string& text, const std::string& key, bool& outValue, int defaultStr = 0) + { + T temp; + temp.SetFromString(text); + int outInt; + temp.Get(key, outInt, defaultStr); + outValue = outInt != 0; + } + // Coerce to other types from an int template void GetOption(const SString& strText, const SString& strKey, U& outValue, int iDefault = 0) @@ -795,6 +843,16 @@ namespace SharedUtil outValue = static_cast(iOutValue); } + template + void GetOption(const std::string& text, const std::string& key, U& outValue, int defaultStr = 0) + { + T temp; + temp.SetFromString(text); + int outInt; + temp.Get(key, outInt, defaultStr); + outValue = static_cast(outInt); + } + // Comma separated set of numbers template void GetOption(const SString& strText, const SString& strKey, const char* szSeperator, std::set& outValues) @@ -808,6 +866,25 @@ namespace SharedUtil MapInsert(outValues, static_cast(atoi(numberList[i]))); } + template + void GetOption(const std::string& text, const std::string& key, const char* separator, std::set& outValues) + { + SString numbers; + { + std::string temp; + GetOption(text, key, temp); + numbers = std::move(temp); + } + std::vector numberList; + numbers.Split(separator, numberList); + for (const auto& number : numberList) + { + if (number.empty()) + continue; + MapInsert(outValues, static_cast(atoi(number.c_str()))); + } + } + /////////////////////////////////////////////////////////////// // // CMappedContainer diff --git a/Shared/sdk/SharedUtil.Misc.hpp b/Shared/sdk/SharedUtil.Misc.hpp index 5697da49f49..8d64553dfc4 100644 --- a/Shared/sdk/SharedUtil.Misc.hpp +++ b/Shared/sdk/SharedUtil.Misc.hpp @@ -364,8 +364,9 @@ SString SharedUtil::GetPostUpdateConnect() CArgMap argMap; argMap.SetFromString(strPostUpdateConnect); - SString strHost = argMap.Get("host"); - time_t timeThen = (time_t)std::atoll(argMap.Get("time")); + SString strHost = argMap.Get(std::string("host")); + std::string timeKey = argMap.Get(std::string("time")); + time_t timeThen = (time_t)std::atoll(timeKey.c_str()); // Expire after 5 mins double seconds = difftime(time(NULL), timeThen); @@ -674,7 +675,7 @@ bool SharedUtil::ProcessPendingBrowseToSolution() if (!argMap.Get("type", strType)) return false; argMap.Get("message", strMessageBoxMessage); - argMap.Get("flags", iFlags); + argMap.Get(std::string("flags"), iFlags); argMap.Get("ecode", strErrorCode); ClearPendingBrowseToSolution(); @@ -1647,12 +1648,23 @@ namespace SharedUtil void CArgMap::Merge(const CArgMap& other, bool bAllowMultiValues) { MergeFromString(other.ToString(), bAllowMultiValues); } + void CArgMap::SetFromString(const char* line, bool allowMultiValues) + { + return SetFromString(std::string(line), allowMultiValues); + } + void CArgMap::SetFromString(const SString& strLine, bool bAllowMultiValues) { m_Map.clear(); MergeFromString(strLine, bAllowMultiValues); } + void CArgMap::SetFromString(const std::string& line, bool allowMultiValues) + { + m_Map.clear(); + MergeFromString(line, allowMultiValues); + } + void CArgMap::MergeFromString(const SString& strLine, bool bAllowMultiValues) { std::vector parts; @@ -1729,6 +1741,68 @@ namespace SharedUtil // Test if key exists bool CArgMap::Contains(const SString& strCmd) const { return MapFind(m_Map, Escape(strCmd)) != NULL; } + std::string CArgMap::Get(const char*& inCmd) const noexcept + { + return Get(std::string(inCmd)); + } + + bool CArgMap::Get(const char*& inCmd, std::string& out, const char* defaultValue) const noexcept + { + return Get(std::string(inCmd), out, defaultValue); + } + + bool CArgMap::Get(const char*& inCmd, std::vector& outList) const + { + return Get(std::string(inCmd), outList); + } + + bool CArgMap::Get(const char*& inCmd, int& value, int defaultValue) const noexcept + { + return Get(std::string(inCmd), value, defaultValue); + } + + std::string CArgMap::Get(const std::string& inCmd) const noexcept + { + std::string result; + Get(inCmd, result); + return result; + } + + bool CArgMap::Get(const std::string& inCmd, std::string& out, const char* defaultValue) const noexcept + { + assert(defaultValue); + const SString* result = MapFind(m_Map, Escape(inCmd)); + if (result) + { + out = Unescape(*result); + return true; + } + out = defaultValue; + return false; + } + + bool CArgMap::Get(const std::string& inCmd, std::vector& outList) const + { + std::vector newItems; + MultiFind(m_Map, Escape(inCmd), &newItems); + for (auto& item : newItems) + item = Unescape(item); + outList.insert(outList.end(), newItems.begin(), newItems.end()); + return !newItems.empty(); + } + + bool CArgMap::Get(const std::string& inCmd, int& value, int defaultValue) const noexcept + { + std::string result; + if (Get(inCmd, result)) + { + value = atoi(result.c_str()); + return true; + } + value = defaultValue; + return false; + } + // First result as string bool CArgMap::Get(const SString& strCmd, SString& strOut, const char* szDefault) const {