Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions runtime-light/stdlib/rpc/rpc-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ inline bool f$store_string(const string& v) noexcept {
}

inline bool f$store_string2(const string& v) noexcept {
tl::string{.value = {v.c_str(), v.size()}}.store2(RpcServerInstanceState::get().tl_storer);
tl2::string{.value = {v.c_str(), v.size()}}.store(RpcServerInstanceState::get().tl_storer);
return true;
}

Expand Down Expand Up @@ -175,7 +175,7 @@ inline string f$fetch_string() noexcept {
}

inline string f$fetch_string2() noexcept {
if (tl::string val{}; val.fetch2(RpcServerInstanceState::get().tl_fetcher)) [[likely]] {
if (tl2::string val{}; val.fetch(RpcServerInstanceState::get().tl_fetcher)) [[likely]] {
return {val.value.data(), static_cast<string::size_type>(val.value.size())};
}
THROW_EXCEPTION(kphp::rpc::exception::cant_fetch_string::make());
Expand Down
24 changes: 12 additions & 12 deletions runtime-light/tl/tl-functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ bool K2InvokeJobWorker::fetch(tl::fetcher& tlf) noexcept {
tl::magic magic{};
tl::mask flags{};
bool ok{magic.fetch(tlf) && magic.expect(K2_INVOKE_JOB_WORKER_MAGIC)};
ok &= flags.fetch(tlf);
ok &= image_id.fetch(tlf);
ok &= job_id.fetch(tlf);
ok &= timeout_ns.fetch(tlf);
ok &= body.fetch(tlf);
ok = ok && flags.fetch(tlf);
ok = ok && image_id.fetch(tlf);
ok = ok && job_id.fetch(tlf);
ok = ok && timeout_ns.fetch(tlf);
ok = ok && body.fetch(tlf);
ignore_answer = static_cast<bool>(flags.value & IGNORE_ANSWER_FLAG);
return ok;
}
Expand All @@ -42,14 +42,14 @@ bool K2InvokeHttp::fetch(tl::fetcher& tlf) noexcept {
tl::magic magic{};
tl::mask flags{};
bool ok{magic.fetch(tlf) && magic.expect(K2_INVOKE_HTTP_MAGIC)};
ok &= flags.fetch(tlf);
ok &= connection.fetch(tlf);
ok &= version.fetch(tlf);
ok &= method.fetch(tlf);
ok &= uri.fetch(tlf);
ok &= headers.fetch(tlf);
ok = ok && flags.fetch(tlf);
ok = ok && connection.fetch(tlf);
ok = ok && version.fetch(tlf);
ok = ok && method.fetch(tlf);
ok = ok && uri.fetch(tlf);
ok = ok && headers.fetch(tlf);
const auto opt_body{tlf.fetch_bytes(tlf.remaining())};
ok &= opt_body.has_value();
ok = ok && opt_body.has_value();

body = opt_body.value_or(std::span<const std::byte>{});

Expand Down
10 changes: 5 additions & 5 deletions runtime-light/tl/tl-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,14 @@ class K2InvokeRpc final {
bool fetch(tl::fetcher& tlf) noexcept {
tl::magic magic{};
bool ok{magic.fetch(tlf) && magic.expect(K2_INVOKE_RPC_MAGIC)};
ok &= flags.fetch(tlf);
ok &= query_id.fetch(tlf);
ok &= net_pid.fetch(tlf);
ok = ok && flags.fetch(tlf);
ok = ok && query_id.fetch(tlf);
ok = ok && net_pid.fetch(tlf);
if (static_cast<bool>(flags.value & ACTOR_ID_FLAG)) {
ok &= opt_actor_id.emplace().fetch(tlf);
ok = ok && opt_actor_id.emplace().fetch(tlf);
}
if (static_cast<bool>(flags.value & EXTRA_FLAG)) {
ok &= opt_extra.emplace().fetch(tlf);
ok = ok && opt_extra.emplace().fetch(tlf);
}
const auto opt_query{tlf.fetch_bytes(tlf.remaining())};
query = opt_query.value_or(std::span<const std::byte>{});
Expand Down
214 changes: 97 additions & 117 deletions runtime-light/tl/tl-types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <array>
#include <cstddef>
#include <cstdint>
#include <span>
#include <utility>

#include "runtime-light/stdlib/diagnostics/logs.h"
Expand All @@ -25,22 +26,19 @@ bool string::fetch(tl::fetcher& tlf) noexcept {
uint8_t size_len{};
uint64_t string_len{};
switch (first_byte) {
case LARGE_STRING_MAGIC: {
if (tlf.remaining() < LARGE_STRING_SIZE_LEN) [[unlikely]] {
case HUGE_STRING_MAGIC: {
if (tlf.remaining() < HUGE_STRING_SIZE_LEN) [[unlikely]] {
return false;
}
size_len = LARGE_STRING_SIZE_LEN + 1;
const auto first{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>())};
const auto second{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 8};
const auto third{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 16};
const auto fourth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 24};
const auto fifth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 32};
const auto sixth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 40};
const auto seventh{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 48};
string_len = first | second | third | fourth | fifth | sixth | seventh;
size_len = HUGE_STRING_SIZE_LEN + 1;
auto len_bytes{*tlf.fetch_bytes(HUGE_STRING_SIZE_LEN)};
string_len = static_cast<uint64_t>(len_bytes[0]) | (static_cast<uint64_t>(len_bytes[1]) << 8) | (static_cast<uint64_t>(len_bytes[2]) << 16) |
(static_cast<uint64_t>(len_bytes[3]) << 24) | (static_cast<uint64_t>(len_bytes[4]) << 32) | (static_cast<uint64_t>(len_bytes[5]) << 40) |
(static_cast<uint64_t>(len_bytes[6]) << 48);

if (string_len <= MEDIUM_STRING_MAX_LEN) [[unlikely]] {
kphp::log::warning("large string's length is less than (1 << 24) - 1 (length = {})", string_len);
return false;
}
break;
}
Expand All @@ -49,18 +47,17 @@ bool string::fetch(tl::fetcher& tlf) noexcept {
return false;
}
size_len = MEDIUM_STRING_SIZE_LEN + 1;
const auto first{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>())};
const auto second{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 8};
const auto third{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 16};
string_len = first | second | third;
auto len_bytes{*tlf.fetch_bytes(MEDIUM_STRING_SIZE_LEN)};
string_len = static_cast<uint64_t>(len_bytes[0]) | (static_cast<uint64_t>(len_bytes[1]) << 8) | (static_cast<uint64_t>(len_bytes[2]) << 16);

if (string_len <= SMALL_STRING_MAX_LEN) [[unlikely]] {
if (string_len <= TINY_STRING_MAX_LEN) [[unlikely]] {
kphp::log::warning("long string's length is less than 254 (length = {})", string_len);
return false;
}
break;
}
default: {
size_len = SMALL_STRING_SIZE_LEN;
size_len = TINY_STRING_SIZE_LEN;
string_len = static_cast<uint64_t>(first_byte);
break;
}
Expand All @@ -83,20 +80,27 @@ void string::store(tl::storer& tls) const noexcept {
const char* str_buf{value.data()};
size_t str_len{value.size()};
uint8_t size_len{};
if (str_len <= SMALL_STRING_MAX_LEN) {
size_len = SMALL_STRING_SIZE_LEN;
if (str_len <= TINY_STRING_MAX_LEN) {
size_len = TINY_STRING_SIZE_LEN;
tls.store_trivial<uint8_t>(str_len);
} else if (str_len <= MEDIUM_STRING_MAX_LEN) {
size_len = MEDIUM_STRING_SIZE_LEN + 1;
tls.store_trivial<uint8_t>(MEDIUM_STRING_MAGIC);
tls.store_trivial<uint8_t>(str_len & 0xff);
tls.store_trivial<uint8_t>((str_len >> 8) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 16) & 0xff);
std::array<std::byte, MEDIUM_STRING_SIZE_LEN> len_bytes{static_cast<std::byte>(str_len & 0xff), static_cast<std::byte>((str_len >> 8) & 0xff),
static_cast<std::byte>((str_len >> 16) & 0xff)};
tls.store_bytes(len_bytes);
} else if (str_len <= HUGE_STRING_MAX_LEN) {
size_len = HUGE_STRING_SIZE_LEN + 1;
tls.store_trivial<uint8_t>(HUGE_STRING_MAGIC);
std::array<std::byte, HUGE_STRING_SIZE_LEN> len_bytes{static_cast<std::byte>(str_len & 0xff), static_cast<std::byte>((str_len >> 8) & 0xff),
static_cast<std::byte>((str_len >> 16) & 0xff), static_cast<std::byte>((str_len >> 24) & 0xff),
static_cast<std::byte>((str_len >> 32) & 0xff), static_cast<std::byte>((str_len >> 40) & 0xff),
static_cast<std::byte>((str_len >> 48) & 0xff)};
tls.store_bytes(len_bytes);
} else {
kphp::log::warning("large strings aren't supported");
size_len = SMALL_STRING_SIZE_LEN;
kphp::log::warning("string length exceeds maximum allowed length: max allowed -> {}, actual -> {}", HUGE_STRING_MAX_LEN, str_len);
size_len = 0;
str_len = 0;
tls.store_trivial<uint8_t>(str_len);
}
tls.store_bytes({reinterpret_cast<const std::byte*>(str_buf), str_len});

Expand All @@ -108,90 +112,6 @@ void string::store(tl::storer& tls) const noexcept {
tls.store_bytes({reinterpret_cast<const std::byte*>(padding_array.data()), padding});
}

bool string::fetch2_len(tl::fetcher& tlf, uint64_t& string_len) noexcept {
uint8_t first_byte{};
if (const auto opt_first_byte{tlf.fetch_trivial<uint8_t>()}; opt_first_byte) [[likely]] {
first_byte = *opt_first_byte;
} else {
return false;
}

switch (first_byte) {
case LARGE_STRING_MAGIC: {
if (tlf.remaining() < 8) [[unlikely]] {
return false;
}
const auto first{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>())};
const auto second{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 8};
const auto third{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 16};
const auto fourth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 24};
const auto fifth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 32};
const auto sixth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 40};
const auto seventh{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 48};
const auto eighth{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 56};
string_len = first | second | third | fourth | fifth | sixth | seventh | eighth;
// we allow non-canonical length to speed up some rare implementations
return true;
}
case MEDIUM_STRING_MAGIC: {
if (tlf.remaining() < 2) [[unlikely]] {
return false;
}
const auto first{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>())};
const auto second{static_cast<uint64_t>(*tlf.fetch_trivial<uint8_t>()) << 8};
string_len = MEDIUM_STRING_MAGIC + (first | second);
return true;
}
default: {
string_len = static_cast<uint64_t>(first_byte);
return true;
}
}
}

bool string::fetch2(tl::fetcher& tlf) noexcept {
uint64_t string_len{};
if (!string::fetch2_len(tlf, string_len)) {
return false;
}
if (tlf.remaining() < string_len) [[unlikely]] {
return false;
}

value = {reinterpret_cast<const char*>(std::next(tlf.view().data(), tlf.pos())), static_cast<size_t>(string_len)};
tlf.adjust(string_len);
return true;
}

void string::store2_len(tl::storer& tls, uint64_t str_len) noexcept {
if (str_len < MEDIUM_STRING_MAGIC) {
tls.store_trivial<uint8_t>(str_len);
return;
}
if (str_len < MEDIUM_STRING_MAGIC + static_cast<uint64_t>(1 << 16)) {
str_len -= MEDIUM_STRING_MAGIC;
tls.store_trivial<uint8_t>(MEDIUM_STRING_MAGIC);
tls.store_trivial<uint8_t>(str_len & 0xff);
tls.store_trivial<uint8_t>((str_len >> 8) & 0xff);
return;
}
tls.store_trivial<uint8_t>(LARGE_STRING_MAGIC);
tls.store_trivial<uint8_t>(str_len & 0xff);
tls.store_trivial<uint8_t>((str_len >> 8) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 16) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 24) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 32) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 40) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 48) & 0xff);
tls.store_trivial<uint8_t>((str_len >> 56) & 0xff);
}

void string::store2(tl::storer& tls) const noexcept {
uint64_t str_len = value.size();
string::store2_len(tls, str_len);
tls.store_bytes({reinterpret_cast<const std::byte*>(value.data()), str_len});
}

bool CertInfoItem::fetch(tl::fetcher& tlf) noexcept {
tl::magic magic{};
if (!magic.fetch(tlf)) [[unlikely]] {
Expand Down Expand Up @@ -230,28 +150,28 @@ bool CertInfoItem::fetch(tl::fetcher& tlf) noexcept {
bool rpcInvokeReqExtra::fetch(tl::fetcher& tlf) noexcept {
bool ok{flags.fetch(tlf)};
if (ok && static_cast<bool>(flags.value & WAIT_BINLOG_POS_FLAG)) {
ok &= opt_wait_binlog_pos.emplace().fetch(tlf);
ok = ok && opt_wait_binlog_pos.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & STRING_FORWARD_KEYS_FLAG)) {
ok &= opt_string_forward_keys.emplace().fetch(tlf);
ok = ok && opt_string_forward_keys.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & INT_FORWARD_KEYS_FLAG)) {
ok &= opt_int_forward_keys.emplace().fetch(tlf);
ok = ok && opt_int_forward_keys.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & STRING_FORWARD_FLAG)) {
ok &= opt_string_forward.emplace().fetch(tlf);
ok = ok && opt_string_forward.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & INT_FORWARD_FLAG)) {
ok &= opt_int_forward.emplace().fetch(tlf);
ok = ok && opt_int_forward.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & CUSTOM_TIMEOUT_MS_FLAG)) {
ok &= opt_custom_timeout_ms.emplace().fetch(tlf);
ok = ok && opt_custom_timeout_ms.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & SUPPORTED_COMPRESSION_VERSION_FLAG)) {
ok &= opt_supported_compression_version.emplace().fetch(tlf);
ok = ok && opt_supported_compression_version.emplace().fetch(tlf);
}
if (ok && static_cast<bool>(flags.value & RANDOM_DELAY_FLAG)) {
ok &= opt_random_delay.emplace().fetch(tlf);
ok = ok && opt_random_delay.emplace().fetch(tlf);
}

return_binlog_pos = static_cast<bool>(flags.value & RETURN_BINLOG_POS_FLAG);
Expand Down Expand Up @@ -327,3 +247,63 @@ size_t rpcReqResultExtra::footprint() const noexcept {
}

} // namespace tl

namespace tl2 {

bool string::fetch(tl::fetcher& tlf) noexcept {
uint8_t first_byte{};
if (const auto opt_first_byte{tlf.fetch_trivial<uint8_t>()}; opt_first_byte) [[likely]] {
first_byte = *opt_first_byte;
} else {
return false;
}

uint64_t string_len{};
switch (first_byte) {
case HUGE_STRING_MAGIC: {
if (tlf.remaining() < HUGE_STRING_SIZE_LEN) [[unlikely]] {
return false;
}
// we allow non-canonical length to speed up some rare implementations
string_len = *tlf.fetch_trivial<uint64_t>();
break;
}
case MEDIUM_STRING_MAGIC: {
if (tlf.remaining() < MEDIUM_STRING_SIZE_LEN) [[unlikely]] {
return false;
}
string_len = MEDIUM_STRING_MAGIC + *tlf.fetch_trivial<uint16_t>();
break;
}
default: {
string_len = static_cast<uint64_t>(first_byte);
break;
}
}

if (auto remaining{tlf.remaining()}; remaining < string_len) [[unlikely]] {
kphp::log::warning("not enough space in buffer to fetch string: required {} bytes, remain {} bytes", string_len, remaining);
return false;
}

value = {reinterpret_cast<const char*>(std::next(tlf.view().data(), tlf.pos())), static_cast<size_t>(string_len)};
tlf.adjust(string_len);
return true;
}

void string::store(tl::storer& tls) const noexcept {
const size_t str_len{value.size()};

if (str_len <= TINY_STRING_MAX_LEN) {
tls.store_trivial<uint8_t>(str_len);
} else if (str_len <= MEDIUM_STRING_MAX_LEN) {
tls.store_trivial<uint8_t>(MEDIUM_STRING_MAGIC);
tls.store_trivial<uint16_t>(str_len - MEDIUM_STRING_MAGIC);
} else {
tls.store_trivial<uint8_t>(HUGE_STRING_MAGIC);
tls.store_trivial<uint64_t>(str_len);
}
tls.store_bytes({reinterpret_cast<const std::byte*>(value.data()), str_len});
}

} // namespace tl2
Loading
Loading