diff --git a/CMakeLists.txt b/CMakeLists.txt index f0e22c88c7..084d03720f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,7 @@ endif () find_package(Threads REQUIRED) find_package(PugiXML CONFIG REQUIRED) +find_package(OpenSSL REQUIRED) # Selects LuaJIT if user defines or auto-detected if (DEFINED USE_LUAJIT AND NOT USE_LUAJIT) @@ -69,7 +70,7 @@ find_package(Boost 1.66.0 REQUIRED COMPONENTS system iostreams) option(ENABLE_TESTING "Build unit tests" OFF) -include_directories(${Boost_INCLUDE_DIRS} ${Crypto++_INCLUDE_DIR} ${LUA_INCLUDE_DIR} ${MYSQL_INCLUDE_DIR} ${PUGIXML_INCLUDE_DIR}) +include_directories(${Boost_INCLUDE_DIRS} ${Crypto++_INCLUDE_DIR} ${LUA_INCLUDE_DIR} ${MYSQL_INCLUDE_DIR} ${PUGIXML_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR}) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) diff --git a/data/lib/core/core.lua b/data/lib/core/core.lua index 912d80597a..533de9b5eb 100644 --- a/data/lib/core/core.lua +++ b/data/lib/core/core.lua @@ -20,3 +20,4 @@ dofile('data/lib/core/teleport.lua') dofile('data/lib/core/tile.lua') dofile('data/lib/core/vocation.lua') dofile('data/lib/core/quests.lua') +dofile('data/lib/core/json.lua') diff --git a/data/lib/core/json.lua b/data/lib/core/json.lua new file mode 100644 index 0000000000..c86acb7119 --- /dev/null +++ b/data/lib/core/json.lua @@ -0,0 +1,419 @@ +-- +-- json.lua +-- +-- Copyright (c) 2019 rxi +-- +-- Permission is hereby granted, free of charge, to any person obtaining a copy of +-- this software and associated documentation files (the "Software"), to deal in +-- the Software without restriction, including without limitation the rights to +-- use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +-- of the Software, and to permit persons to whom the Software is furnished to do +-- so, subject to the following conditions: +-- +-- The above copyright notice and this permission notice shall be included in all +-- copies or substantial portions of the Software. +-- +-- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +-- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +-- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +-- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +-- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +-- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +-- SOFTWARE. +-- + +json = { _version = "0.1.1" } + +------------------------------------------------------------------------------- +-- Encode +------------------------------------------------------------------------------- + +local encode + +local escape_char_map = { + [ "\\" ] = "\\\\", + [ "\"" ] = "\\\"", + [ "\b" ] = "\\b", + [ "\f" ] = "\\f", + [ "\n" ] = "\\n", + [ "\r" ] = "\\r", + [ "\t" ] = "\\t", +} + +local escape_char_map_inv = { [ "\\/" ] = "/" } +for k, v in pairs(escape_char_map) do + escape_char_map_inv[v] = k +end + + +local function make_indent(state) + return string.rep(" ", state.currentIndentLevel * state.indent) +end + + +local function escape_char(c) + return escape_char_map[c] or string.format("\\u%04x", c:byte()) +end + + +local function encode_nil() + return "null" +end + + +local function encode_table(val, state) + local res = {} + local stack = state.stack + local pretty = state.indent > 0 + + local close_indent = make_indent(state) + local comma = pretty and ",\n" or "," + local colon = pretty and ": " or ":" + local open_brace = pretty and "{\n" or "{" + local close_brace = pretty and ("\n" .. close_indent .. "}") or "}" + local open_bracket = pretty and "[\n" or "[" + local close_bracket = pretty and ("\n" .. close_indent .. "]") or "]" + + -- Circular reference? + if stack[val] then error("circular reference") end + + stack[val] = true + + if rawget(val, 1) ~= nil or next(val) == nil then + -- Treat as array -- check keys are valid and it is not sparse + local n = 0 + for k in pairs(val) do + if type(k) ~= "number" then + error("invalid table: mixed or invalid key types") + end + n = n + 1 + end + if n ~= #val then + error("invalid table: sparse array") + end + -- Encode + for _, v in ipairs(val) do + state.currentIndentLevel = state.currentIndentLevel + 1 + table.insert(res, make_indent(state) .. encode(v, state)) + state.currentIndentLevel = state.currentIndentLevel - 1 + end + stack[val] = nil + return open_bracket .. table.concat(res, comma) .. close_bracket + + else + -- Treat as an object + for k, v in pairs(val) do + if type(k) ~= "string" then + error("invalid table: mixed or invalid key types") + end + state.currentIndentLevel = state.currentIndentLevel + 1 + table.insert(res, make_indent(state) .. encode(k, state) .. colon .. encode(v, state)) + state.currentIndentLevel = state.currentIndentLevel - 1 + end + stack[val] = nil + return open_brace .. table.concat(res, comma) .. close_brace + end +end + + +local function encode_string(val) + return '"' .. val:gsub('[%z\1-\31\\"]', escape_char) .. '"' +end + + +local function encode_number(val) + -- Check for NaN, -inf and inf + if val ~= val or val <= -math.huge or val >= math.huge then + error("unexpected number value '" .. tostring(val) .. "'") + end + return string.format("%.14g", val) +end + + +local type_func_map = { + [ "nil" ] = encode_nil, + [ "table" ] = encode_table, + [ "string" ] = encode_string, + [ "number" ] = encode_number, + [ "boolean" ] = tostring, +} + + +encode = function(val, state) + local t = type(val) + local f = type_func_map[t] + if f then + return f(val, state) + end + error("unexpected type '" .. t .. "'") +end + +function json.encode(val, indent) + local state = { + indent = indent or 0, + currentIndentLevel = 0, + stack = {} + } + return encode(val, state) +end + + +------------------------------------------------------------------------------- +-- Decode +------------------------------------------------------------------------------- + +local parse + +local function create_set(...) + local res = {} + for i = 1, select("#", ...) do + res[ select(i, ...) ] = true + end + return res +end + +local space_chars = create_set(" ", "\t", "\r", "\n") +local delim_chars = create_set(" ", "\t", "\r", "\n", "]", "}", ",") +local escape_chars = create_set("\\", "/", '"', "b", "f", "n", "r", "t", "u") +local literals = create_set("true", "false", "null") + +local literal_map = { + [ "true" ] = true, + [ "false" ] = false, + [ "null" ] = nil, +} + + +local function next_char(str, idx, set, negate) + for i = idx, #str do + if set[str:sub(i, i)] ~= negate then + return i + end + end + return #str + 1 +end + + +local function decode_error(str, idx, msg) + local line_count = 1 + local col_count = 1 + for i = 1, idx - 1 do + col_count = col_count + 1 + if str:sub(i, i) == "\n" then + line_count = line_count + 1 + col_count = 1 + end + end + error( string.format("%s at line %d col %d", msg, line_count, col_count) ) +end + + +local function codepoint_to_utf8(n) + -- http://scripts.sil.org/cms/scripts/page.php?site_id=nrsi&id=iws-appendixa + local f = math.floor + if n <= 0x7f then + return string.char(n) + elseif n <= 0x7ff then + return string.char(f(n / 64) + 192, n % 64 + 128) + elseif n <= 0xffff then + return string.char(f(n / 4096) + 224, f(n % 4096 / 64) + 128, n % 64 + 128) + elseif n <= 0x10ffff then + return string.char(f(n / 262144) + 240, f(n % 262144 / 4096) + 128, + f(n % 4096 / 64) + 128, n % 64 + 128) + end + error( string.format("invalid unicode codepoint '%x'", n) ) +end + + +local function parse_unicode_escape(s) + local n1 = tonumber( s:sub(3, 6), 16 ) + local n2 = tonumber( s:sub(9, 12), 16 ) + -- Surrogate pair? + if n2 then + return codepoint_to_utf8((n1 - 0xd800) * 0x400 + (n2 - 0xdc00) + 0x10000) + else + return codepoint_to_utf8(n1) + end +end + + +local function parse_string(str, i) + local has_unicode_escape = false + local has_surrogate_escape = false + local has_escape = false + local last + for j = i + 1, #str do + local x = str:byte(j) + + if x < 32 then + decode_error(str, j, "control character in string") + end + + if last == 92 then -- "\\" (escape char) + if x == 117 then -- "u" (unicode escape sequence) + local hex = str:sub(j + 1, j + 5) + if not hex:find("%x%x%x%x") then + decode_error(str, j, "invalid unicode escape in string") + end + if hex:find("^[dD][89aAbB]") then + has_surrogate_escape = true + else + has_unicode_escape = true + end + else + local c = string.char(x) + if not escape_chars[c] then + decode_error(str, j, "invalid escape char '" .. c .. "' in string") + end + has_escape = true + end + last = nil + + elseif x == 34 then -- '"' (end of string) + local s = str:sub(i + 1, j - 1) + if has_surrogate_escape then + s = s:gsub("\\u[dD][89aAbB]..\\u....", parse_unicode_escape) + end + if has_unicode_escape then + s = s:gsub("\\u....", parse_unicode_escape) + end + if has_escape then + s = s:gsub("\\.", escape_char_map_inv) + end + return s, j + 1 + + else + last = x + end + end + decode_error(str, i, "expected closing quote for string") +end + + +local function parse_number(str, i) + local x = next_char(str, i, delim_chars) + local s = str:sub(i, x - 1) + local n = tonumber(s) + if not n then + decode_error(str, i, "invalid number '" .. s .. "'") + end + return n, x +end + + +local function parse_literal(str, i) + local x = next_char(str, i, delim_chars) + local word = str:sub(i, x - 1) + if not literals[word] then + decode_error(str, i, "invalid literal '" .. word .. "'") + end + return literal_map[word], x +end + + +local function parse_array(str, i) + local res = {} + local n = 1 + i = i + 1 + while 1 do + local x + i = next_char(str, i, space_chars, true) + -- Empty / end of array? + if str:sub(i, i) == "]" then + i = i + 1 + break + end + -- Read token + x, i = parse(str, i) + res[n] = x + n = n + 1 + -- Next token + i = next_char(str, i, space_chars, true) + local chr = str:sub(i, i) + i = i + 1 + if chr == "]" then break end + if chr ~= "," then decode_error(str, i, "expected ']' or ','") end + end + return res, i +end + + +local function parse_object(str, i) + local res = {} + i = i + 1 + while 1 do + local key, val + i = next_char(str, i, space_chars, true) + -- Empty / end of object? + if str:sub(i, i) == "}" then + i = i + 1 + break + end + -- Read key + if str:sub(i, i) ~= '"' then + decode_error(str, i, "expected string for key") + end + key, i = parse(str, i) + -- Read ':' delimiter + i = next_char(str, i, space_chars, true) + if str:sub(i, i) ~= ":" then + decode_error(str, i, "expected ':' after key") + end + i = next_char(str, i + 1, space_chars, true) + -- Read value + val, i = parse(str, i) + -- Set + res[key] = val + -- Next token + i = next_char(str, i, space_chars, true) + local chr = str:sub(i, i) + i = i + 1 + if chr == "}" then break end + if chr ~= "," then decode_error(str, i, "expected '}' or ','") end + end + return res, i +end + + +local char_func_map = { + [ '"' ] = parse_string, + [ "0" ] = parse_number, + [ "1" ] = parse_number, + [ "2" ] = parse_number, + [ "3" ] = parse_number, + [ "4" ] = parse_number, + [ "5" ] = parse_number, + [ "6" ] = parse_number, + [ "7" ] = parse_number, + [ "8" ] = parse_number, + [ "9" ] = parse_number, + [ "-" ] = parse_number, + [ "t" ] = parse_literal, + [ "f" ] = parse_literal, + [ "n" ] = parse_literal, + [ "[" ] = parse_array, + [ "{" ] = parse_object, +} + + +parse = function(str, idx) + local chr = str:sub(idx, idx) + local f = char_func_map[chr] + if f then + return f(str, idx) + end + decode_error(str, idx, "unexpected character '" .. chr .. "'") +end + + +function json.decode(str) + if type(str) ~= "string" then + error("expected argument of type string, got " .. type(str)) + end + local res, idx = parse(str, next_char(str, 1, space_chars, true)) + idx = next_char(str, idx, space_chars, true) + if idx <= #str then + decode_error(str, idx, "trailing garbage") + end + return res +end \ No newline at end of file diff --git a/data/talkactions/scripts/httpclient.lua b/data/talkactions/scripts/httpclient.lua new file mode 100644 index 0000000000..81c60affb6 --- /dev/null +++ b/data/talkactions/scripts/httpclient.lua @@ -0,0 +1,159 @@ +local function isJson(contentType) + return contentType and string.find(contentType:lower(), "json") +end + +local function parseJson(bodyData) + if not bodyData then + print("parseJson: bodyData is invalid") + return nil + end + + local status, data = + pcall( + function() + return json.decode(bodyData) + end + ) + + if not status then + print("parseJson: failed to parse json data") + return nil + end + + if not data then + print("parseJson: data is invalid") + return nil + end + + return data +end + +local function httpCallback(param) + + if isJson(param["contentType"]) then + jsonData = parseJson(param["bodyData"]) + + if type(jsonData) == "table" then + print(dump(jsonData)) + else + print(jsonData) + end + else + if type(param) == "table" then + print(dump(param)) + else + print(param) + end + end + + print(" I am http callback being called") +end + +local function connect(httpClientRequest) + local url = "https://www.example.com" + httpClientRequest:connect(url, httpCallback) +end + +local function trace(httpClientRequest) + local url = "https://www.example.com" + httpClientRequest:trace(url, httpCallback) +end + +local function options(httpClientRequest) + local url = "http://httpbin.org" + httpClientRequest:options(url, httpCallback) +end + +local function head(httpClientRequest) + local url = "http://httpbin.org/get" + httpClientRequest:head(url, httpCallback) +end + +local function delete(httpClientRequest) + local url = "https://httpbin.org/delete" + local token = "abcdef123456789" + local data = "I want to delete something" + local fields = { + accept = "application/json", + authorization = "Bearer " .. token + } + + httpClientRequest:delete(url, httpCallback, fields, data) +end + +local function get(httpClientRequest) + local url = "https://httpbin.org/get" + local token = "abcdef123456789" + local fields = { + accept = "application/json", + authorization = "Bearer " .. token + } + + httpClientRequest:get(url, httpCallback, fields) +end + +local function post(httpClientRequest) + local url = "https://httpbin.org/post" + local token = "abcdef123456789" + local data = "I want to post something" + local fields = { + accept = "application/json", + authorization = "Bearer " .. token + } + + httpClientRequest:post(url, httpCallback, fields, data) +end + +local function patch(httpClientRequest) + local url = "https://httpbin.org/patch" + local token = "abcdef123456789" + local data = "I want to patch something" + local fields = { + accept = "application/json", + authorization = "Bearer " .. token + } + + httpClientRequest:patch(url, httpCallback, fields, data) +end + +local function put(httpClientRequest) + local url = "https://httpbin.org/patch" + local token = "abcdef123456789" + local data = "I want to put something" + local fields = { + accept = "application/json", + authorization = "Bearer " .. token + } + + httpClientRequest:put(url, httpCallback, fields, data) +end + +function onSay(player, words, param) + local httpClientRequest = HttpClientRequest() + local httpmethod = param + + if httpmethod == 'connect' then + connect(httpClientRequest) + elseif httpmethod == 'trace' then + trace(httpClientRequest) + elseif httpmethod == 'options' then + options(httpClientRequest) + elseif httpmethod == 'head' then + head(httpClientRequest) + elseif httpmethod == 'delete' then + delete(httpClientRequest) + elseif httpmethod == 'get' then + get(httpClientRequest) + elseif httpmethod == 'post' then + post(httpClientRequest) + elseif httpmethod == 'patch' then + patch(httpClientRequest) + elseif httpmethod == 'put' then + put(httpClientRequest) + else + player:sendTextMessage(MESSAGE_STATUS_CONSOLE_BLUE, string.format("Invalid http method. Available: connect, trace, options, head, delete, get, post, patch and put", words)) + return false + end + + return true +end diff --git a/data/talkactions/talkactions.xml b/data/talkactions/talkactions.xml index 11469b48d4..eca80cd8e1 100644 --- a/data/talkactions/talkactions.xml +++ b/data/talkactions/talkactions.xml @@ -34,6 +34,7 @@ + diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e5bedaa438..2e8436c494 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -26,6 +26,7 @@ set(tfs_SRC ${CMAKE_CURRENT_LIST_DIR}/guild.cpp ${CMAKE_CURRENT_LIST_DIR}/house.cpp ${CMAKE_CURRENT_LIST_DIR}/housetile.cpp + ${CMAKE_CURRENT_LIST_DIR}/httpclient.cpp ${CMAKE_CURRENT_LIST_DIR}/inbox.cpp ${CMAKE_CURRENT_LIST_DIR}/iologindata.cpp ${CMAKE_CURRENT_LIST_DIR}/iomap.cpp @@ -109,6 +110,8 @@ set(tfs_HDR ${CMAKE_CURRENT_LIST_DIR}/guild.h ${CMAKE_CURRENT_LIST_DIR}/house.h ${CMAKE_CURRENT_LIST_DIR}/housetile.h + ${CMAKE_CURRENT_LIST_DIR}/httpclient.h + ${CMAKE_CURRENT_LIST_DIR}/httpclientlib.h ${CMAKE_CURRENT_LIST_DIR}/inbox.h ${CMAKE_CURRENT_LIST_DIR}/iologindata.h ${CMAKE_CURRENT_LIST_DIR}/iomap.h @@ -179,6 +182,7 @@ target_link_libraries(tfslib PRIVATE ${Crypto++_LIBRARIES} ${LUA_LIBRARIES} ${MYSQL_CLIENT_LIBS} + ${OPENSSL_LIBRARIES} ) add_custom_target(format COMMAND /usr/bin/clang-format -style=file -i ${tfs_HDR} ${tfs_SRC} ${tfs_MAIN}) diff --git a/src/game.cpp b/src/game.cpp index 6de690c024..31d2a9f4b5 100644 --- a/src/game.cpp +++ b/src/game.cpp @@ -15,6 +15,7 @@ #include "events.h" #include "globalevent.h" #include "housetile.h" +#include "httpclient.h" #include "inbox.h" #include "iologindata.h" #include "iomarket.h" @@ -132,6 +133,7 @@ void Game::setGameState(GameState_t newState) g_scheduler.stop(); g_databaseTasks.stop(); g_dispatcher.stop(); + g_http.stop(); break; } @@ -4926,6 +4928,7 @@ void Game::shutdown() g_scheduler.shutdown(); g_databaseTasks.shutdown(); g_dispatcher.shutdown(); + g_http.shutdown(); map.spawns.clear(); raids.clear(); diff --git a/src/httpclient.cpp b/src/httpclient.cpp new file mode 100644 index 0000000000..ecd84cd97b --- /dev/null +++ b/src/httpclient.cpp @@ -0,0 +1,252 @@ +// Copyright 2023 The Forgotten Server Authors.All rights reserved. +// Use of this source code is governed by the GPL-2.0 License that can be found +// in the LICENSE file. + +#include "otpch.h" + +#include "httpclient.h" + +#include "httpclientlib.h" + +#include "tasks.h" +extern Dispatcher g_dispatcher; + +void HttpClient::threadMain() +{ + HttpClientLib::Request requestsHandler( + [this](const HttpClientLib::HttpResponse_ptr& response) { clientRequestSuccessCallback(response); }, + [this](const HttpClientLib::HttpResponse_ptr& response) { clientRequestFailureCallback(response); }); + + std::unique_lock requestLockUnique(requestLock, std::defer_lock); + while (getState() != THREAD_STATE_TERMINATED) { + requestLockUnique.lock(); + + if (pendingRequests.empty() && pendingResponses.empty()) { + requestSignal.wait(requestLockUnique); + } + + if (!pendingRequests.empty() || !pendingResponses.empty()) { + bool shouldUnlock = false; + + if (!pendingRequests.empty()) { + HttpClientLib::HttpRequest_ptr pendingRequest = std::move(pendingRequests.front()); + pendingRequests.pop_front(); + + shouldUnlock = true; + dispatchRequest(requestsHandler, pendingRequest); + } + + if (!pendingResponses.empty()) { + HttpClientLib::HttpResponse_ptr pendingResponse = std::move(pendingResponses.front()); + pendingResponses.pop_front(); + + requestLockUnique.unlock(); + shouldUnlock = false; + processResponse(pendingResponse); + } + + if (shouldUnlock) { + requestLockUnique.unlock(); + } + } else { + requestLockUnique.unlock(); + } + } +} + +void HttpClient::dispatchRequest(HttpClientLib::Request& requestsHandler, HttpClientLib::HttpRequest_ptr& request) +{ + bool succesfullyDispatched = false; + switch (request->method) { + case HttpClientLib::HttpMethod::HTTP_CONNECT: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.connect(request->url, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_TRACE: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.trace(request->url, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_OPTIONS: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.options(request->url, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_HEAD: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.head(request->url, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_DELETE: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.delete_(request->url, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_GET: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.get(request->url, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_POST: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.post(request->url, request->data, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_PATCH: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.patch(request->url, request->data, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_PUT: + requestsHandler.setTimeout(request->timeout); + succesfullyDispatched = requestsHandler.put(request->url, request->data, request->fields); + break; + + case HttpClientLib::HttpMethod::HTTP_NONE: + default: + break; + } + + if (request->method != HttpClientLib::HTTP_NONE && succesfullyDispatched) { + requests.emplace(std::make_pair(requestsHandler.getRequestId(), std::move(request))); + } +} + +void HttpClient::clientRequestSuccessCallback(const HttpClientLib::HttpResponse_ptr& response) +{ + // std::cout << std::string("HTTP Response received: " + std::to_string(response->statusCode) + " (" + + // std::to_string(response->responseTimeMs) + "ms) id " + std::to_string(response->requestId)) << std::endl; + addResponse(response); + + std::string headerStr(reinterpret_cast(response->headerData.data()), response->headerData.size()); + std::string bodyStr(reinterpret_cast(response->bodyData.data()), response->bodyData.size()); + + // Print the string to the console + // std::cout << headerStr << std::endl; + // std::cout << bodyStr << std::endl; +} + +void HttpClient::clientRequestFailureCallback(const HttpClientLib::HttpResponse_ptr& response) +{ + std::cout << std::string("HTTP Response failed (" + response->errorMessage + ")") << std::endl; + addResponse(response); +} + +void HttpClient::processResponse(const HttpClientLib::HttpResponse_ptr& response) +{ + auto httpRequestIt = requests.find(response->requestId); + if (httpRequestIt == requests.end()) { + return; + } + + HttpClientLib::HttpRequest_ptr& httpRequest = httpRequestIt->second; + + if (httpRequest->callbackData.isLuaCallback()) { + luaClientRequestCallback(httpRequest->callbackData); + } + + if (httpRequest->callbackData.callbackFunction) { + g_dispatcher.addTask(createTask(std::bind(httpRequest->callbackData.callbackFunction, response))); + } + + requests.erase(response->requestId); +} + +void HttpClient::luaClientRequestCallback(HttpClientLib::HttpRequestCallbackData &callbackData) +{ + LuaScriptInterface *scriptInterface = callbackData.scriptInterface; + int32_t callbackId = callbackData.callbackId; + + lua_State *luaState = scriptInterface->getLuaState(); + if (!luaState) { + return; + } + + if (callbackId > 0) { + callbackData.callbackFunction = [callbackId, scriptInterface](const HttpClientLib::HttpResponse_ptr &response) { + lua_State *luaState = scriptInterface->getLuaState(); + if (!luaState) { + return; + } + + if (!LuaScriptInterface::reserveScriptEnv()) { + luaL_unref(luaState, LUA_REGISTRYINDEX, callbackId); + std::cout << "[Error - HttpClient::luaClientRequestCallback] Call stack overflow" << std::endl; + return; + } + + // push function + lua_rawgeti(luaState, LUA_REGISTRYINDEX, callbackId); + + // push parameters + lua_createtable(luaState, 0, 11); + + LuaScriptInterface::setField(luaState, "requestId", response->requestId); + LuaScriptInterface::setField(luaState, "version", response->version); + LuaScriptInterface::setField(luaState, "statusCode", response->statusCode); + LuaScriptInterface::setField(luaState, "location", response->location); + LuaScriptInterface::setField(luaState, "contentType", response->contentType); + LuaScriptInterface::setField(luaState, "responseTimeMs", response->responseTimeMs); + LuaScriptInterface::setField(luaState, "headerData", response->headerData); + LuaScriptInterface::setField(luaState, "bodySize", response->bodySize); + LuaScriptInterface::setField(luaState, "bodyData", response->bodyData); + LuaScriptInterface::setField(luaState, "success", response->success); + LuaScriptInterface::setField(luaState, "errorMessage", response->errorMessage); + + LuaScriptInterface::setMetatable(luaState, -1, "HttpResponse"); + + int parameter = luaL_ref(luaState, LUA_REGISTRYINDEX); + lua_rawgeti(luaState, LUA_REGISTRYINDEX, parameter); + + ScriptEnvironment *env = scriptInterface->getScriptEnv(); + auto scriptId = env->getScriptId(); + env->setScriptId(scriptId, scriptInterface); + + scriptInterface->callFunction(1); // callFunction already reset the reserved + // script env (resetScriptEnv) + + // free resources + luaL_unref(luaState, LUA_REGISTRYINDEX, callbackId); + luaL_unref(luaState, LUA_REGISTRYINDEX, parameter); + }; + } +} + +void HttpClient::addResponse(const HttpClientLib::HttpResponse_ptr& response) +{ + bool signal = false; + requestLock.lock(); + if (getState() == THREAD_STATE_RUNNING) { + signal = pendingResponses.empty(); + pendingResponses.emplace_back(response); + } + requestLock.unlock(); + + if (signal) { + requestSignal.notify_one(); + } +} + +void HttpClient::addRequest(const HttpClientLib::HttpRequest_ptr& request) +{ + bool signal = false; + requestLock.lock(); + if (getState() == THREAD_STATE_RUNNING) { + signal = pendingRequests.empty(); + pendingRequests.emplace_back(request); + } + requestLock.unlock(); + + if (signal) { + requestSignal.notify_one(); + } +} + +void HttpClient::shutdown() +{ + requestLock.lock(); + setState(THREAD_STATE_TERMINATED); + requestLock.unlock(); + requestSignal.notify_one(); +} diff --git a/src/httpclient.h b/src/httpclient.h new file mode 100644 index 0000000000..17b6173025 --- /dev/null +++ b/src/httpclient.h @@ -0,0 +1,89 @@ +// Copyright 2023 The Forgotten Server Authors. All rights reserved. +// Use of this source code is governed by the GPL-2.0 License that can be found +// in the LICENSE file. + +#ifndef HTTP_H +#define HTTP_H + +#include "luascript.h" +#include "thread_holder_base.h" + +namespace HttpClientLib { +class Request; +class HttpResponse; +using HttpResponse_ptr = std::shared_ptr; + +enum HttpMethod +{ + HTTP_NONE, + HTTP_CONNECT, + HTTP_TRACE, + HTTP_OPTIONS, + HTTP_HEAD, + HTTP_DELETE, + HTTP_GET, + HTTP_POST, + HTTP_PATCH, + HTTP_PUT +}; + +class HttpRequestCallbackData +{ +public: + std::function callbackFunction; + int32_t scriptId = -1; + int32_t callbackId = -1; + LuaScriptInterface* scriptInterface; + + bool isLuaCallback() { return scriptId != -1 && callbackId != -1; } +}; + +class HttpRequest +{ +public: + HttpMethod method; + std::string url; + std::string data; + std::unordered_map fields; + + uint32_t timeout = 0; + HttpRequestCallbackData callbackData; + + HttpRequest() {} +}; + +using HttpRequest_ptr = std::shared_ptr; +} // namespace HttpClientLib + +class HttpClient : public ThreadHolder +{ +public: + HttpClient() {} + void threadMain(); + void shutdown(); + + void addRequest(const HttpClientLib::HttpRequest_ptr& request); + +private: + void clientRequestSuccessCallback(const HttpClientLib::HttpResponse_ptr& response); + void luaClientRequestCallback(HttpClientLib::HttpRequestCallbackData &callbackData); + + void clientRequestFailureCallback(const HttpClientLib::HttpResponse_ptr& response); + + void dispatchRequest(HttpClientLib::Request& requestsHandler, HttpClientLib::HttpRequest_ptr& request); + void processResponse(const HttpClientLib::HttpResponse_ptr& response); + + void addResponse(const HttpClientLib::HttpResponse_ptr& response); + + std::list pendingRequests; + std::list pendingResponses; + + std::mutex requestLock; + std::condition_variable requestSignal; + + std::map requests; +}; + +extern HttpClient g_http; + +#endif diff --git a/src/httpclientlib.h b/src/httpclientlib.h new file mode 100644 index 0000000000..75777a13b7 --- /dev/null +++ b/src/httpclientlib.h @@ -0,0 +1,924 @@ +/**************************************************************************** + * + * Copyright (c) 2023, Danilo Pucci + * + * This file is part of the HTTP Client header-only library. + * + * Source Code: + * https://github.com/danilopucci/httpclient/ + * + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so,subject to the + * following conditions: + * + * + The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + Credit is appreciated, but not required, if you find this project + * useful enough to include in your application, product, device, etc. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + ***************************************************************************/ + +#ifndef HTTPCLIENTLIB_H +#define HTTPCLIENTLIB_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Disable old-style-cast warnings for this file +#if defined(__clang__) || defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wold-style-cast" +#endif + +namespace HttpClientLib { +class HttpConnection; +class HttpResponse; + +using HttpResponse_ptr = std::shared_ptr; +using HttpResponse_cb = std::function; +using HttpFailure_cb = std::function; + +class HttpUrl +{ +public: + HttpUrl(const std::string& url_) : url(url_) + { + parseUrl(url); + }; + + bool isValid() { return valid; } + + bool isProtocolSecure() const { return secure; } + + std::string url; + std::string target; + + std::string protocol; + std::string host; + int port; + std::string path; + std::string query; + std::string fragment; + +private: + void parseUrl(const std::string& url) + { + static const std::regex urlRegex(R"(^(https?:\/\/)?([^\/:]+)(:\d+)?(\/.*)?$)"); + valid = false; + + std::smatch matches; + if (std::regex_match(url, matches, urlRegex)) { + auto& scheme = matches[1]; + auto& host = matches[2]; + auto& port = matches[3]; + auto& arguments = matches[4]; + + setProtocol(scheme); + setHost(host); + setPort(port); + + if (arguments.matched) { + target = arguments.str(); + + parsePath(arguments.str()); + parseQuery(arguments.str()); + parseFragment(arguments.str()); + } else { + target = "/"; + } + + valid = true; + } + } + + void parsePath(const std::string& arguments) + { + static const std::regex pathRegex(R"(/([^?#]*))"); + std::smatch match; + if (std::regex_search(arguments, match, pathRegex)) { + setPath(match[1]); + } + } + + void parseQuery(const std::string& arguments) + { + static const std::regex queryRegex(R"(\?([^#]*))"); + std::smatch match; + if (std::regex_search(arguments, match, queryRegex)) { + setQuery(match[1]); + } + } + + void parseFragment(const std::string& arguments) + { + static const std::regex fragmentRegex(R"(#(.*))"); + std::smatch match; + if (std::regex_search(arguments, match, fragmentRegex)) { + setFragment(match[1]); + } + } + + void setProtocol(const std::ssub_match& match) + { + if (match.matched) { + protocol = match.str(); + boost::algorithm::to_lower(protocol); + } else { + protocol = "http://"; + } + + secure = protocol == "https://"; + } + + void setHost(const std::ssub_match& match) + { + if (match.matched) { + host = match.str(); + } + } + + void setPort(const std::ssub_match& match) + { + if (match.matched) { + port = match.str().empty() ? 0 : std::stoi(match.str().substr(1)); + } else { + if (protocol.find("https://") != std::string::npos) { + port = 443; + } else if (protocol.find("http://") != std::string::npos) { + port = 80; + } + } + } + + void setPath(const std::ssub_match& match) + { + if (match.matched) { + path = match.str(); + } + } + + void setQuery(const std::ssub_match& match) + { + if (match.matched) { + query = match.str(); + } + } + + void setFragment(const std::ssub_match& match) + { + if (match.matched) { + fragment = match.str(); + } + } + +private: + bool valid; + bool secure; +}; + +class HttpResponse +{ +public: + uint32_t requestId; + int version; + int statusCode; + std::string location; + std::string contentType; + uint32_t responseTimeMs; + + std::string headerData; + + size_t bodySize; + std::string bodyData; + + bool success; + std::string errorMessage; + +private: + void buildHeaderData(const boost::beast::http::response_parser& response) + { + auto& responseHeader = response.get(); + statusCode = responseHeader.result_int(); + version = responseHeader.version(); + location = std::string(responseHeader[boost::beast::http::field::location]); + contentType = std::string(responseHeader[boost::beast::http::field::content_type]); + + auto headers = responseHeader.base(); + for (const auto& header : headers) { + boost::beast::string_view headerName = header.name_string(); + boost::beast::string_view headerValue = header.value(); + + std::string headerString = std::string(headerName) + ": " + std::string(headerValue) + "\n"; + headerData.insert(headerData.end(), headerString.begin(), headerString.end()); + } + + bodySize = 0; + if (responseHeader.has_content_length()) { + bodySize = std::stoul(std::string(responseHeader[boost::beast::http::field::content_length])); + } + } + + inline void buildBodyData(const boost::beast::http::response_parser& response) + { + bodyData = boost::beast::buffers_to_string(response.get().body().data()); + } + + void setResponseTime(uint32_t responseTimeMs_) { responseTimeMs = responseTimeMs_; } + + void setRequestId(uint32_t requestId_) { requestId = requestId_; } + + friend class HttpConnectionBase; + friend class HttpConnection; + friend class HttpsConnection; +}; + +class HttpConnectionBase : public std::enable_shared_from_this +{ +public: + HttpConnectionBase(boost::asio::any_io_executor executor, uint32_t id, HttpResponse_cb responseCallback, + HttpFailure_cb failureCallback) : + resolver(executor), + stream(executor), + timer(executor), + id(id), + responseData(std::make_shared()), + responseCallback(responseCallback), + failureCallback(failureCallback) + { + setTimeout(30000); + responseData->setRequestId(id); + } + + virtual ~HttpConnectionBase() {} + + inline void setTimeout(int timeout_) { timeout = timeout_; } + + virtual void create(const boost::beast::http::request& request_, + const std::string& url, uint32_t port, bool skipBody = false) + { + timer.expires_after(std::chrono::milliseconds(timeout)); + timer.async_wait(std::bind(&HttpConnectionBase::onTimeout, shared_from_this(), std::placeholders::_1)); + + request = request_; + connectionStart = std::chrono::steady_clock::now(); + response.skip(skipBody); + resolve(url, port); + } + + virtual void resolve(const std::string& url, uint32_t port) + { + resolver.async_resolve(url, std::to_string(port), + std::bind(&HttpConnectionBase::onResolve, this->shared_from_this(), + std::placeholders::_1, std::placeholders::_2)); + } + + virtual void onResolve(const boost::system::error_code& error, boost::asio::ip::tcp::resolver::results_type results) + { + if (!error) { + connect(results); + } else { + onError("Failed to resolve to HTTP address: " + error.message()); + } + } + + virtual void connect(const boost::asio::ip::tcp::resolver::results_type& results) + { + boost::asio::async_connect( + stream, results.begin(), results.end(), + std::bind(&HttpConnectionBase::onConnect, this->shared_from_this(), std::placeholders::_1)); + } + + virtual void onConnect(const boost::system::error_code& error) + { + if (!error) { + writeRequest(); + } else { + onError("Failed to connect to HTTP socket: " + error.message()); + } + } + + virtual void writeRequest() + { + boost::beast::http::async_write( + stream, request, + std::bind(&HttpConnectionBase::onRequestWrite, this->shared_from_this(), std::placeholders::_1)); + } + + virtual void onRequestWrite(const boost::beast::error_code& error) + { + if (!error) { + readHeader(); + } else { + close(); + onError("Failed to write HTTP request: " + error.message()); + } + } + + virtual void readHeader() + { + buffer.max_size(MAX_HEADER_CHUNCK_SIZE); + boost::beast::http::async_read_header( + stream, buffer, response, + std::bind(&HttpConnectionBase::onReadHeader, this->shared_from_this(), std::placeholders::_1)); + } + + virtual void onReadHeader(const boost::beast::error_code& error) + { + if (!error || response.is_header_done()) { + responseData->setRequestId(id); + responseData->buildHeaderData(response); + readBody(); + } else { + close(); + onError("Failed to read HTTP header: " + error.message()); + } + } + + virtual void readBody() + { + timer.expires_after(std::chrono::milliseconds(timeout)); + buffer.max_size(MAX_BODY_CHUNCK_SIZE); + boost::beast::http::async_read_some( + stream, buffer, response, + std::bind(&HttpConnectionBase::onReadBody, this->shared_from_this(), std::placeholders::_1)); + } + + virtual void onReadBody(const boost::beast::error_code& error) + { + if (error && error != boost::beast::http::error::end_of_stream) { + close(); + onError("Failed to read HTTP body: " + error.message()); + return; + } + + if (error == boost::beast::http::error::end_of_stream || response.is_done()) { + responseData->setResponseTime(calculateResponseTime()); + responseData->buildBodyData(response); + responseData->success = true; + onSuccess(responseData); + + close(); + return; + } + + readBody(); + } + + virtual void close() + { + timer.cancel(); + boost::system::error_code ec; + stream.close(ec); + } + + virtual void onShutdown() {} + + void onTimeout(const boost::system::error_code& error) + { + if (!error) { + close(); + onError("Failed on HTTP: timeout"); + } + } + + void inline onError(const std::string& reason) + { + if (failureCallback) { + if (!responseData) { + responseData = std::make_shared(); + } + + responseData->success = false; + responseData->errorMessage = reason; + + failureCallback(responseData); + } + } + + void inline onSuccess(const HttpResponse_ptr& responseData) + { + if (responseCallback) { + responseCallback(responseData); + } + } + +protected: + boost::asio::ip::tcp::resolver resolver; + boost::asio::ip::tcp::socket stream; + + boost::asio::steady_timer timer; + + int timeout; + uint32_t id; + + boost::beast::flat_buffer buffer; + boost::beast::http::request request; + boost::beast::http::response_parser response; + + std::chrono::steady_clock::time_point connectionStart; + + const int MAX_HEADER_CHUNCK_SIZE = 8 * 1024; + const int MAX_BODY_CHUNCK_SIZE = 64 * 1024; + + HttpResponse_ptr responseData; + HttpResponse_cb responseCallback; + HttpFailure_cb failureCallback; + + uint32_t calculateResponseTime() + { + std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); + std::chrono::duration duration = end - connectionStart; + return static_cast(duration.count()); + } +}; + +class HttpConnection : public HttpConnectionBase +{ +public: + HttpConnection(boost::asio::io_context& ioContext, uint32_t id, HttpResponse_cb responseCallback, + HttpFailure_cb failureCallback) : + HttpConnectionBase(boost::asio::make_strand(ioContext), id, responseCallback, failureCallback) + {} +}; + +class HttpsConnection : public HttpConnectionBase +{ +public: + HttpsConnection(boost::asio::io_context& ioContext, uint32_t id, boost::asio::ssl::context& sslContext, + HttpResponse_cb responseCallback, HttpFailure_cb failureCallback) : + HttpConnectionBase(boost::asio::make_strand(ioContext), id, responseCallback, failureCallback), + sslStream(stream, sslContext) + {} + + void create(const boost::beast::http::request& request_, const std::string& url, + uint32_t port, bool skipBody = false) override + { + sslStream.set_verify_mode(boost::asio::ssl::verify_peer); + sslStream.set_verify_callback([](bool, boost::asio::ssl::verify_context&) { return true; }); + + if (!SSL_set_tlsext_host_name(sslStream.native_handle(), url.c_str())) { + boost::beast::error_code ec2(static_cast(::ERR_get_error()), boost::asio::error::get_ssl_category()); + onError("HTTPS error" + ec2.message()); + return; + } + + timer.expires_after(std::chrono::milliseconds(timeout)); + timer.async_wait(std::bind(&HttpConnectionBase::onTimeout, shared_from_this(), std::placeholders::_1)); + + request = request_; + connectionStart = std::chrono::steady_clock::now(); + response.skip(skipBody); + resolve(url, port); + } + + void onConnect(const boost::system::error_code& error) override + { + if (!error) { + handshake(); + } else { + onError("Failed to connect to HTTP socket: " + error.message()); + } + } + + void handshake() + { + auto self(shared_from_this()); + sslStream.async_handshake(boost::asio::ssl::stream_base::client, + [&, self](const boost::system::error_code& error) { + if (!error) { + writeRequest(); + } else { + onError("Failed SSL handshake: " + error.message()); + } + }); + } + + void writeRequest() override + { + boost::beast::http::async_write( + sslStream, request, + std::bind(&HttpConnectionBase::onRequestWrite, this->shared_from_this(), std::placeholders::_1)); + } + + void readHeader() override + { + buffer.max_size(MAX_HEADER_CHUNCK_SIZE); + boost::beast::http::async_read_header( + sslStream, buffer, response, + std::bind(&HttpConnectionBase::onReadHeader, this->shared_from_this(), std::placeholders::_1)); + } + + void readBody() override + { + timer.expires_after(std::chrono::milliseconds(timeout)); + buffer.max_size(MAX_BODY_CHUNCK_SIZE); + boost::beast::http::async_read_some( + sslStream, buffer, response, + std::bind(&HttpConnectionBase::onReadBody, this->shared_from_this(), std::placeholders::_1)); + } + + void close() override + { + timer.cancel(); + sslStream.async_shutdown(std::bind(&HttpConnectionBase::onShutdown, this->shared_from_this())); + } + + void onShutdown() override + { + boost::system::error_code ec; + stream.close(ec); + } + +private: + boost::asio::ssl::stream sslStream; +}; + +class Request +{ +public: + Request() : context(), guard(boost::asio::make_work_guard(context)), requestId(1) + { + thread = std::thread([this]() { context.run(); }); + } + + Request(const HttpResponse_cb& responseCallback) : + context(), guard(boost::asio::make_work_guard(context)), requestId(1), responseCallback(responseCallback) + { + thread = std::thread([this]() { context.run(); }); + } + + Request(const HttpResponse_cb& responseCallback, const HttpFailure_cb& failureCallback) : + context(), + guard(boost::asio::make_work_guard(context)), + requestId(1), + responseCallback(responseCallback), + failureCallback(failureCallback) + { + thread = std::thread([this]() { context.run(); }); + } + + ~Request() + { + context.stop(); + + guard.reset(); + if (thread.joinable()) { + thread.join(); + } + } + + inline uint32_t getRequestId() { return requestId; } + + inline void setTimeout(uint32_t timeout) { requestTimeoutMs = timeout; } + + bool connect(const std::string& url) { return connect(url, emptyFields); } + + bool connect(const std::string& url, std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request CONNECT: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::connect); + + doRequest(httpUrl, request); + } catch (std::exception& e) { + onError("error during HTTP request CONNECT (" + url + "): " + e.what()); + return false; + } + + return true; + } + + bool trace(const std::string& url) { return trace(url, emptyFields); } + + bool trace(const std::string& url, std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request TRACE: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::trace); + + doRequest(httpUrl, request); + } catch (std::exception& e) { + onError("error during HTTP request TRACE (" + url + "): " + e.what()); + return false; + } + + return true; + } + + bool options(const std::string& url) { return options(url, emptyFields); } + + bool options(const std::string& url, std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request OPTIONS: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::options); + + doRequest(httpUrl, request); + } catch (std::exception& e) { + onError("error during HTTP request OPTIONS (" + url + "): " + e.what()); + return false; + } + + return true; + } + + bool head(const std::string& url) { return head(url, emptyFields); } + + bool head(const std::string& url, std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request HEAD: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::head); + const bool skipBody = true; + + doRequest(httpUrl, request, skipBody); + } catch (std::exception& e) { + onError("error during HTTP request HEAD (" + url + "): " + e.what()); + return false; + } + + return true; + } + + bool delete_(const std::string& url) { return delete_(url, emptyFields); } + + bool delete_(const std::string& url, std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request DELETE: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::delete_); + + doRequest(httpUrl, request); + } catch (std::exception& e) { + onError("error during HTTP request DELETE (" + url + "): " + e.what()); + return false; + } + + return true; + } + + bool get(const std::string& url) { return get(url, emptyFields); } + + bool get(const std::string& url, std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request GET: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::get); + + doRequest(httpUrl, request); + } catch (std::exception& e) { + onError("error during HTTP request GET (" + url + "): " + e.what()); + return false; + } + + return true; + } + + bool post(const std::string& url, const std::string& postData) { return post(url, postData, emptyFields); } + + bool post(const std::string& url, const std::string& postData, std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request POST: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::post); + request.body() = postData; + request.prepare_payload(); + + doRequest(httpUrl, request); + } catch (std::exception& e) { + onError("error during HTTP request POST (" + url + "): " + e.what()); + return false; + } + + return true; + } + + bool patch(const std::string& url, const std::string& patchData) { return patch(url, patchData, emptyFields); } + + bool patch(const std::string& url, const std::string& patchData, + std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request PATCH: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::patch); + request.body() = patchData; + request.prepare_payload(); + + doRequest(httpUrl, request); + } catch (std::exception& e) { + onError("error during HTTP request PATCH (" + url + "): " + e.what()); + return false; + } + + return true; + } + + bool put(const std::string& url, const std::string& putData) { return put(url, putData, emptyFields); } + + bool put(const std::string& url, const std::string& putData, std::unordered_map& fields) + { + HttpUrl httpUrl(url); + + if (!httpUrl.isValid()) { + onError("error during HTTP request PUT: invalid URL: " + url); + return false; + } + + try { + boost::beast::http::request request = buildBasicRequest(httpUrl, fields); + setUniqueRequestId(); + request.method(boost::beast::http::verb::put); + request.body() = putData; + request.prepare_payload(); + + doRequest(httpUrl, request); + } catch (std::exception& e) { + onError("error during HTTP request PUT (" + url + "): " + e.what()); + return false; + } + + return true; + } + +private: + std::thread thread; + boost::asio::io_context context; + boost::asio::executor_work_guard guard; + + std::unordered_map emptyFields; + + uint32_t requestId; + uint32_t requestTimeoutMs = 0; + + HttpResponse_cb responseCallback; + HttpFailure_cb failureCallback; + + boost::beast::http::request buildBasicRequest( + const HttpUrl& httpUrl, std::unordered_map& fields) + { + boost::beast::http::request request; + + request.version(11); + request.prepare_payload(); + request.keep_alive(false); + + request.set(boost::beast::http::field::host, httpUrl.host); + + for (auto& field : fields) { + request.insert(field.first, field.second); + } + + request.target(httpUrl.target); + + return request; + } + + void setUniqueRequestId() { requestId++; } + + void doRequest(const HttpUrl& httpUrl, boost::beast::http::request& request, + bool skipBody = false) + { + std::shared_ptr httpConnection; + + if (httpUrl.isProtocolSecure()) { + boost::asio::ssl::context sslContext{boost::asio::ssl::context::tlsv12_client}; + sslContext.set_default_verify_paths(); + + httpConnection = std::make_shared( + context, requestId, sslContext, + std::bind(&Request::requestSuccessCallback, this, std::placeholders::_1), + std::bind(&Request::requestFailureCallback, this, std::placeholders::_1)); + } else { + httpConnection = std::make_shared( + context, requestId, std::bind(&Request::requestSuccessCallback, this, std::placeholders::_1), + std::bind(&Request::requestFailureCallback, this, std::placeholders::_1)); + } + + if (requestTimeoutMs > 0) { + httpConnection->setTimeout(requestTimeoutMs); + } + + httpConnection->create(request, httpUrl.host, httpUrl.port, skipBody); + } + + void requestSuccessCallback(HttpResponse_ptr response) + { + if (responseCallback) { + responseCallback(response); + } else { + std::cout << "HTTP response received (" << response->responseTimeMs + << "ms) but Request has no responseCallback" << std::endl; + } + } + + void requestFailureCallback(HttpResponse_ptr response) + { + if (failureCallback) { + failureCallback(response); + } else { + std::cout << "HTTP failure but Request has no failureCallback. Failure reason: " << response->errorMessage + << std::endl; + } + } + + void onError(const std::string& reason) { std::cout << "Could not complete HTTP request: " << reason << std::endl; } +}; + +} // namespace HttpClientLib + +// Restore the warning settings to their previous state +#if defined(__clang__) || defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // HTTPCLIENT_H diff --git a/src/luascript.cpp b/src/luascript.cpp index b790107acc..3d611f2875 100644 --- a/src/luascript.cpp +++ b/src/luascript.cpp @@ -15,6 +15,7 @@ #include "game.h" #include "globalevent.h" #include "housetile.h" +#include "httpclient.h" #include "inbox.h" #include "iologindata.h" #include "iomapserialize.h" @@ -3368,6 +3369,21 @@ void LuaScriptInterface::registerFunctions() registerMethod("XMLNode", "name", LuaScriptInterface::luaXmlNodeName); registerMethod("XMLNode", "firstChild", LuaScriptInterface::luaXmlNodeFirstChild); registerMethod("XMLNode", "nextSibling", LuaScriptInterface::luaXmlNodeNextSibling); + + // HttpClientRequest + registerClass("HttpClientRequest", "", LuaScriptInterface::luaCreateHttpClientRequest); + registerMetaMethod("HttpClientRequest", "__eq", LuaScriptInterface::luaUserdataCompare); + registerMetaMethod("HttpClientRequest", "__gc", LuaScriptInterface::luaDeleteHttpClientRequest); + registerMethod("HttpClientRequest", "setTimeout", LuaScriptInterface::luaHttpClientRequestSetTimeout); + registerMethod("HttpClientRequest", "connect", LuaScriptInterface::luaHttpClientRequestConnect); + registerMethod("HttpClientRequest", "trace", LuaScriptInterface::luaHttpClientRequestTrace); + registerMethod("HttpClientRequest", "options", LuaScriptInterface::luaHttpClientRequestOptions); + registerMethod("HttpClientRequest", "head", LuaScriptInterface::luaHttpClientRequestHead); + registerMethod("HttpClientRequest", "delete", LuaScriptInterface::luaHttpClientRequestDelete); + registerMethod("HttpClientRequest", "get", LuaScriptInterface::luaHttpClientRequestGet); + registerMethod("HttpClientRequest", "post", LuaScriptInterface::luaHttpClientRequestPost); + registerMethod("HttpClientRequest", "patch", LuaScriptInterface::luaHttpClientRequestPatch); + registerMethod("HttpClientRequest", "put", LuaScriptInterface::luaHttpClientRequestPut); } #undef registerEnum @@ -18290,6 +18306,288 @@ int LuaScriptInterface::luaXmlNodeNextSibling(lua_State* L) return 1; } +// HttpClient +int LuaScriptInterface::luaCreateHttpClientRequest(lua_State* L) +{ + // HttpClientRequest() + HttpClientLib::HttpRequest_ptr httpRequest = std::make_shared(); + + httpRequest->callbackData.scriptInterface = getScriptEnv()->getScriptInterface(); + pushSharedPtr(L, httpRequest); + setMetatable(L, -1, "HttpClientRequest"); + return 1; +} + +int LuaScriptInterface::luaDeleteHttpClientRequest(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (httpRequest) { + httpRequest.reset(); + } + return 0; +} + +int LuaScriptInterface::luaHttpClientRequestSetTimeout(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + httpRequest->timeout = getNumber(L, -1); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestConnect(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_CONNECT; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestTrace(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_TRACE; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestOptions(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_OPTIONS; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestHead(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_HEAD; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestDelete(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_DELETE; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestGet(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_GET; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestPost(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_POST; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestPatch(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_PATCH; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +int LuaScriptInterface::luaHttpClientRequestPut(lua_State* L) +{ + HttpClientLib::HttpRequest_ptr& httpRequest = getSharedPtr(L, 1); + + if (!httpRequest) { + lua_pushnil(L); + return 1; + } + + luaHttpClientBuildRequest(L, httpRequest); + httpRequest->method = HttpClientLib::HTTP_PUT; + + g_http.addRequest(httpRequest); + + pushBoolean(L, true); + return 1; +} + +void LuaScriptInterface::luaHttpClientBuildRequest(lua_State* L, HttpClientLib::HttpRequest_ptr& httpRequest) +{ + std::string url; + int32_t callbackId = -1; + std::unordered_map headerFields; + std::string data; + + luaHttpClientRetrieveParameters(L, url, callbackId, headerFields, data); + + httpRequest->url = url; + httpRequest->fields = headerFields; + httpRequest->data = data; + + httpRequest->callbackData.scriptId = getScriptEnv()->getScriptId(); + httpRequest->callbackData.callbackId = callbackId; +} + +bool LuaScriptInterface::luaHttpClientRetrieveParameters(lua_State* L, std::string& url, int32_t& callbackId, + std::unordered_map& headerFields, + std::string& data) +{ + int parameters = lua_gettop(L); + if (parameters < 2) { + reportErrorFunc(L, "httpClient: expecting at least two arguments: url, callback"); + pushBoolean(L, false); + return false; + } + + if (!isString(L, 2)) { + reportErrorFunc(L, "httpClient: url parameter should be a string."); + pushBoolean(L, false); + return false; + } + + if (!isFunction(L, 3) && !isNil(L, 3)) { + reportErrorFunc(L, "httpClient: callback parameter should be a function or a nil value."); + pushBoolean(L, false); + return false; + } + + if (parameters >= 5) { + if (!isString(L, 5)) { + reportErrorFunc(L, "httpClient: data parameter should be a string."); + pushBoolean(L, false); + return false; + } + + data = getString(L, 5); + lua_pop(L, 1); + } + + if (parameters >= 4) { + if (!isTable(L, 4)) { + reportErrorFunc(L, "httpClient: Invalid fields table."); + pushBoolean(L, false); + return false; + } + + lua_pushnil(L); + while (lua_next(L, 4) != 0) { + if (lua_isstring(L, -2) && lua_isstring(L, -1)) { + std::string key = getString(L, -2); + std::string value = getString(L, -1); + headerFields[key] = value; + + // Removes the value, keeps the key for the next iteration + lua_pop(L, 1); + } + } + lua_pop(L, 1); + } + + callbackId = luaL_ref(L, LUA_REGISTRYINDEX); + + url = getString(L, 2); + lua_pop(L, 1); + + // pop HttpClientRequest class + lua_pop(L, 1); + + return true; +} + // LuaEnvironment::LuaEnvironment() : LuaScriptInterface("Main Interface") {} diff --git a/src/luascript.h b/src/luascript.h index 1c82bcf38f..c648074c7d 100644 --- a/src/luascript.h +++ b/src/luascript.h @@ -37,6 +37,11 @@ struct Outfit; using Combat_ptr = std::shared_ptr; +namespace HttpClientLib { +class HttpRequest; +using HttpRequest_ptr = std::shared_ptr; +} // namespace HttpClientLib + enum { EVENT_ID_LOADING = 1, @@ -349,6 +354,7 @@ class LuaScriptInterface static bool isTable(lua_State* L, int32_t arg) { return lua_istable(L, arg); } static bool isFunction(lua_State* L, int32_t arg) { return lua_isfunction(L, arg); } static bool isUserdata(lua_State* L, int32_t arg) { return lua_isuserdata(L, arg) != 0; } + static bool isNil(lua_State* L, int32_t arg) { return lua_isnil(L, arg) != 0; } // Push static void pushBoolean(lua_State* L, bool value); @@ -1592,6 +1598,24 @@ class LuaScriptInterface static int luaXmlNodeFirstChild(lua_State* L); static int luaXmlNodeNextSibling(lua_State* L); + // http client + static int luaCreateHttpClientRequest(lua_State* L); + static int luaDeleteHttpClientRequest(lua_State* L); + static void luaHttpClientBuildRequest(lua_State* L, HttpClientLib::HttpRequest_ptr& httpRequest); + static bool luaHttpClientRetrieveParameters(lua_State* L, std::string& url, int32_t& callbackId, + std::unordered_map& headerFields, + std::string& data); + static int luaHttpClientRequestSetTimeout(lua_State* L); + static int luaHttpClientRequestConnect(lua_State* L); + static int luaHttpClientRequestTrace(lua_State* L); + static int luaHttpClientRequestOptions(lua_State* L); + static int luaHttpClientRequestHead(lua_State* L); + static int luaHttpClientRequestDelete(lua_State* L); + static int luaHttpClientRequestGet(lua_State* L); + static int luaHttpClientRequestPost(lua_State* L); + static int luaHttpClientRequestPatch(lua_State* L); + static int luaHttpClientRequestPut(lua_State* L); + // std::string lastLuaError; diff --git a/src/otserv.cpp b/src/otserv.cpp index 99851898e2..207f3fb833 100644 --- a/src/otserv.cpp +++ b/src/otserv.cpp @@ -7,6 +7,7 @@ #include "databasemanager.h" #include "databasetasks.h" #include "game.h" +#include "httpclient.h" #include "iomarket.h" #include "monsters.h" #include "outfit.h" @@ -28,6 +29,7 @@ DatabaseTasks g_databaseTasks; Dispatcher g_dispatcher; Scheduler g_scheduler; +HttpClient g_http; Game g_game; ConfigManager g_config; @@ -71,6 +73,7 @@ int main(int argc, char* argv[]) g_dispatcher.start(); g_scheduler.start(); + g_http.start(); g_dispatcher.addTask([=, services = &serviceManager]() { mainLoader(argc, argv, services); }); @@ -85,11 +88,13 @@ int main(int argc, char* argv[]) g_scheduler.shutdown(); g_databaseTasks.shutdown(); g_dispatcher.shutdown(); + g_http.shutdown(); } g_scheduler.join(); g_databaseTasks.join(); g_dispatcher.join(); + g_http.join(); return 0; } diff --git a/src/signals.cpp b/src/signals.cpp index 754b73a108..b8a1d78a55 100644 --- a/src/signals.cpp +++ b/src/signals.cpp @@ -11,6 +11,7 @@ #include "events.h" #include "game.h" #include "globalevent.h" +#include "httpclient.h" #include "monsters.h" #include "mounts.h" #include "movement.h" @@ -41,6 +42,7 @@ extern GlobalEvents* g_globalEvents; extern Events* g_events; extern Chat* g_chat; extern LuaEnvironment g_luaEnvironment; +extern HttpClient g_http; namespace { @@ -158,6 +160,7 @@ void dispatchSignalHandler(int signal) g_scheduler.join(); g_databaseTasks.join(); g_dispatcher.join(); + g_http.join(); break; #endif default: diff --git a/vc17/theforgottenserver.vcxproj b/vc17/theforgottenserver.vcxproj index 0d040e5932..b0cad0f169 100644 --- a/vc17/theforgottenserver.vcxproj +++ b/vc17/theforgottenserver.vcxproj @@ -185,6 +185,7 @@ + @@ -274,6 +275,8 @@ + + diff --git a/vcpkg.json b/vcpkg.json index 7870b2da47..1fdea837c0 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -6,6 +6,7 @@ "boost-lockfree", "boost-system", "boost-variant", + "boost-beast", "cryptopp", "fmt", "libmariadb",