|
16 | 16 | #include "CPerfStatManager.h" |
17 | 17 | #include "lua/CLuaCallback.h" |
18 | 18 | #include "Utils.h" |
| 19 | +#include <lua/CLuaShared.h> |
19 | 20 |
|
20 | 21 | void CLuaDatabaseDefs::LoadFunctions() |
21 | 22 | { |
22 | 23 | constexpr static const std::pair<const char*, lua_CFunction> functions[]{ |
23 | | - {"dbConnect", DbConnect}, |
| 24 | + {"dbConnect", ArgumentParser<DbConnect>}, |
24 | 25 | {"dbExec", DbExec}, |
25 | 26 | {"dbQuery", DbQuery}, |
26 | 27 | {"dbFree", DbFree}, |
@@ -225,106 +226,121 @@ void CLuaDatabaseDefs::DbFreeCallback(CDbJobData* pJobData, void* pContext) |
225 | 226 | } |
226 | 227 | } |
227 | 228 |
|
228 | | -int CLuaDatabaseDefs::DbConnect(lua_State* luaVM) |
| 229 | +std::variant<CDatabaseConnectionElement*, bool> CLuaDatabaseDefs::DbConnect(lua_State* luaVM, std::string type, std::string host, |
| 230 | + std::optional<std::string> username, std::optional<std::string> password, |
| 231 | + std::optional<std::string> options, std::optional<CLuaFunctionRef> callback) |
229 | 232 | { |
230 | | - // element dbConnect ( string type, string host, string username, string password, string options ) |
231 | | - SString strType; |
232 | | - SString strHost; |
233 | | - SString strUsername; |
234 | | - SString strPassword; |
235 | | - SString strOptions; |
| 233 | + if (!username.has_value()) |
| 234 | + username = ""; |
| 235 | + if (!password.has_value()) |
| 236 | + password = ""; |
| 237 | + if (!options.has_value()) |
| 238 | + options = ""; |
236 | 239 |
|
237 | | - CScriptArgReader argStream(luaVM); |
238 | | - argStream.ReadString(strType); |
239 | | - argStream.ReadString(strHost); |
240 | | - argStream.ReadString(strUsername, ""); |
241 | | - argStream.ReadString(strPassword, ""); |
242 | | - argStream.ReadString(strOptions, ""); |
| 240 | + CResource* resource = &lua_getownerresource(luaVM); |
243 | 241 |
|
244 | | - if (!argStream.HasErrors()) |
| 242 | + if (type == "sqlite" && !host.empty()) |
245 | 243 | { |
246 | | - CResource* pThisResource = m_pLuaManager->GetVirtualMachineResource(luaVM); |
247 | | - if (pThisResource) |
| 244 | + // If path starts with :/ then use global database directory |
| 245 | + if (!host.rfind(":/", 0)) |
248 | 246 | { |
249 | | - // If type is sqlite, and has a host, try to resolve path |
250 | | - if (strType == "sqlite" && !strHost.empty()) |
| 247 | + host = host.substr(1); |
| 248 | + if (!IsValidFilePath(host.c_str())) |
251 | 249 | { |
252 | | - // If path starts with :/ then use global database directory |
253 | | - if (strHost.BeginsWith(":/")) |
254 | | - { |
255 | | - strHost = strHost.SubStr(1); |
256 | | - if (!IsValidFilePath(strHost)) |
257 | | - { |
258 | | - argStream.SetCustomError(SString("host path %s not valid", *strHost)); |
259 | | - } |
260 | | - else |
261 | | - { |
262 | | - strHost = PathJoin(g_pGame->GetConfig()->GetGlobalDatabasesPath(), strHost); |
263 | | - } |
264 | | - } |
265 | | - else |
266 | | - { |
267 | | - std::string strAbsPath; |
268 | | - |
269 | | - // Parse path |
270 | | - CResource* pPathResource = pThisResource; |
271 | | - if (CResourceManager::ParseResourcePathInput(strHost, pPathResource, &strAbsPath)) |
272 | | - { |
273 | | - strHost = strAbsPath; |
274 | | - CheckCanModifyOtherResource(argStream, pThisResource, pPathResource); |
275 | | - } |
276 | | - else |
277 | | - { |
278 | | - argStream.SetCustomError(SString("host path %s not found", *strHost)); |
279 | | - } |
280 | | - } |
| 250 | + SString err("host path %s not valid", host.c_str()); |
| 251 | + throw LuaFunctionError(err.c_str()); |
281 | 252 | } |
282 | 253 |
|
283 | | - if (!argStream.HasErrors()) |
284 | | - { |
285 | | - if (strType == "mysql") |
286 | | - pThisResource->SetUsingDbConnectMysql(true); |
287 | | - |
288 | | - // Add logging options |
289 | | - bool bLoggingEnabled; |
290 | | - SString strLogTag; |
291 | | - SString strQueueName; |
292 | | - // Set default values if required |
293 | | - GetOption<CDbOptionsMap>(strOptions, "log", bLoggingEnabled, 1); |
294 | | - GetOption<CDbOptionsMap>(strOptions, "tag", strLogTag, "script"); |
295 | | - GetOption<CDbOptionsMap>(strOptions, "queue", strQueueName, (strType == "mysql") ? strHost : DB_SQLITE_QUEUE_NAME_DEFAULT); |
296 | | - SetOption<CDbOptionsMap>(strOptions, "log", bLoggingEnabled); |
297 | | - SetOption<CDbOptionsMap>(strOptions, "tag", strLogTag); |
298 | | - SetOption<CDbOptionsMap>(strOptions, "queue", strQueueName); |
299 | | - // Do connect |
300 | | - SConnectionHandle connection = g_pGame->GetDatabaseManager()->Connect(strType, strHost, strUsername, strPassword, strOptions); |
301 | | - if (connection == INVALID_DB_HANDLE) |
302 | | - { |
303 | | - argStream.SetCustomError(g_pGame->GetDatabaseManager()->GetLastErrorMessage()); |
304 | | - } |
305 | | - else |
306 | | - { |
307 | | - // Use an element to wrap the connection for auto disconnected when the resource stops |
308 | | - // Don't set a parent because the element should not be accessible from other resources |
309 | | - CDatabaseConnectionElement* pElement = new CDatabaseConnectionElement(NULL, connection); |
310 | | - CElementGroup* pGroup = pThisResource->GetElementGroup(); |
311 | | - if (pGroup) |
312 | | - { |
313 | | - pGroup->Add(pElement); |
314 | | - } |
| 254 | + host = PathJoin(g_pGame->GetConfig()->GetGlobalDatabasesPath(), host); |
| 255 | + } |
| 256 | + else |
| 257 | + { |
| 258 | + std::string absPath; |
315 | 259 |
|
316 | | - lua_pushelement(luaVM, pElement); |
317 | | - return 1; |
318 | | - } |
| 260 | + // Parse path |
| 261 | + CResource* pathResource = resource; |
| 262 | + if (CResourceManager::ParseResourcePathInput(host, pathResource, &absPath)) |
| 263 | + { |
| 264 | + host = absPath; |
| 265 | + auto [status, err] = CheckCanModifyOtherResource(resource, pathResource); |
| 266 | + if (!status) |
| 267 | + throw LuaFunctionError(err.c_str()); |
319 | 268 | } |
| 269 | + SString err("host path %s not found", host.c_str()); |
| 270 | + throw LuaFunctionError(err.c_str()); |
320 | 271 | } |
321 | 272 | } |
| 273 | + |
| 274 | + if (type == "mysql") |
| 275 | + resource->SetUsingDbConnectMysql(true); |
| 276 | + |
| 277 | + // Add logging options |
| 278 | + bool loggingEnabled; |
| 279 | + std::string logTag; |
| 280 | + std::string queueName; |
| 281 | + // Set default values if required |
| 282 | + GetOption<CDbOptionsMap>(*options, "log", loggingEnabled, 1); |
| 283 | + GetOption<CDbOptionsMap>(*options, "tag", logTag, "script"); |
| 284 | + GetOption<CDbOptionsMap>(*options, "queue", queueName, (type == "mysql") ? host.c_str() : DB_SQLITE_QUEUE_NAME_DEFAULT); |
| 285 | + SetOption<CDbOptionsMap>(*options, "log", loggingEnabled); |
| 286 | + SetOption<CDbOptionsMap>(*options, "tag", logTag); |
| 287 | + SetOption<CDbOptionsMap>(*options, "queue", queueName); |
| 288 | + |
| 289 | + const auto CreateConnection = [](CResource* resource, const SConnectionHandle& handle) |
| 290 | + -> CDatabaseConnectionElement* |
| 291 | + { |
| 292 | + // Use an element to wrap the connection for auto disconnected when the resource stops |
| 293 | + // Don't set a parent because the element should not be accessible from other resources |
| 294 | + auto* element = new CDatabaseConnectionElement(nullptr, handle); |
| 295 | + CElementGroup* group = resource->GetElementGroup(); |
| 296 | + if (group) |
| 297 | + group->Add(element); |
| 298 | + return element; |
| 299 | + }; |
322 | 300 |
|
323 | | - if (argStream.HasErrors()) |
324 | | - m_pScriptDebugging->LogCustom(luaVM, argStream.GetFullErrorMessage()); |
| 301 | + if (callback.has_value()) |
| 302 | + { |
| 303 | + const auto taskFunc = [type = type, host, username = username.value(), password = password.value(), options = options.value()] |
| 304 | + { |
| 305 | + return g_pGame->GetDatabaseManager()->Connect( |
| 306 | + type, |
| 307 | + host, |
| 308 | + username, |
| 309 | + password, |
| 310 | + options |
| 311 | + ); |
| 312 | + }; |
| 313 | + const auto readyFunc = [CreateConnection = CreateConnection, resource = resource, luaFunctionRef = callback.value()](const SConnectionHandle& handle) |
| 314 | + { |
| 315 | + CLuaMain* luaMain = m_pLuaManager->GetVirtualMachine(luaFunctionRef.GetLuaVM()); |
| 316 | + if (!luaMain) |
| 317 | + return; |
325 | 318 |
|
326 | | - lua_pushboolean(luaVM, false); |
327 | | - return 1; |
| 319 | + CLuaArguments arguments; |
| 320 | + |
| 321 | + if (handle == INVALID_DB_HANDLE) |
| 322 | + { |
| 323 | + auto lastError = g_pGame->GetDatabaseManager()->GetLastErrorMessage(); |
| 324 | + m_pScriptDebugging->LogCustom(luaMain->GetVM(), lastError.c_str()); |
| 325 | + arguments.PushBoolean(false); |
| 326 | + } |
| 327 | + |
| 328 | + arguments.PushElement(CreateConnection(resource, handle)); |
| 329 | + arguments.Call(luaMain, luaFunctionRef); |
| 330 | + }; |
| 331 | + |
| 332 | + CLuaShared::GetAsyncTaskScheduler()->PushTask(taskFunc, readyFunc); |
| 333 | + return true; |
| 334 | + } |
| 335 | + |
| 336 | + // Do connect |
| 337 | + SConnectionHandle connection = g_pGame->GetDatabaseManager()->Connect( |
| 338 | + type, host, *username, *password, *options |
| 339 | + ); |
| 340 | + if (connection == INVALID_DB_HANDLE) |
| 341 | + throw LuaFunctionError(g_pGame->GetDatabaseManager()->GetLastErrorMessage().c_str()); |
| 342 | + |
| 343 | + return CreateConnection(resource, connection); |
328 | 344 | } |
329 | 345 |
|
330 | 346 | // This method has an OOP counterpart - don't forget to update the OOP code too! |
|
0 commit comments