Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/shared/protocols/protocols.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum class Protocol {
kKafka = 10,
kMux = 11,
kAMQP = 12,
kTLS = 13,
};

} // namespace protocols
Expand Down
2 changes: 1 addition & 1 deletion src/stirling/binaries/stirling_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ DEFINE_string(trace, "",
"Dynamic trace to deploy. Either (1) the path to a file containing PxL or IR trace "
"spec, or (2) <path to object file>:<symbol_name> for full-function tracing.");
DEFINE_string(print_record_batches,
"http_events,mysql_events,pgsql_events,redis_events,cql_events,dns_events",
"http_events,mysql_events,pgsql_events,redis_events,cql_events,dns_events,tls_events",
"Comma-separated list of tables to print.");
DEFINE_bool(init_only, false, "If true, only runs the init phase and exits. For testing.");
DEFINE_int32(timeout_secs, -1,
Expand Down
21 changes: 21 additions & 0 deletions src/stirling/source_connectors/socket_tracer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,27 @@ pl_cc_bpf_test(
],
)

pl_cc_bpf_test(
name = "tls_trace_bpf_test",
timeout = "long",
srcs = ["tls_trace_bpf_test.cc"],
flaky = True,
shard_count = 2,
tags = [
"cpu:16",
"no_asan",
"requires_bpf",
],
deps = [
":cc_library",
"//src/common/testing/test_utils:cc_library",
"//src/stirling/source_connectors/socket_tracer/testing:cc_library",
"//src/stirling/source_connectors/socket_tracer/testing/container_images:curl_container",
"//src/stirling/source_connectors/socket_tracer/testing/container_images:nginx_openssl_3_0_8_container",
"//src/stirling/testing:cc_library",
],
)

pl_cc_bpf_test(
name = "dyn_lib_trace_bpf_test",
timeout = "moderate",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pl_cc_test(
"ENABLE_NATS_TRACING=true",
"ENABLE_MONGO_TRACING=true",
"ENABLE_AMQP_TRACING=true",
"ENABLE_TLS_TRACING=true",
],
deps = [
"//src/stirling/bpf_tools/bcc_bpf:headers",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,46 @@ static __inline enum message_type_t infer_http_message(const char* buf, size_t c
return kUnknown;
}

static __inline enum message_type_t infer_tls_message(const char* buf, size_t count) {
if (count < 6) {
return kUnknown;
}

uint8_t content_type = buf[0];
// TLS content types correspond to the following:
// 0x14: ChangeCipherSpec
// 0x15: Alert
// 0x16: Handshake
// 0x17: ApplicationData
// 0x18: Heartbeat
if (content_type != 0x16) {
return kUnknown;
}

uint16_t legacy_version = buf[1] << 8 | buf[2];
// TLS versions correspond to the following:
// 0x0300: SSL 3.0
// 0x0301: TLS 1.0
// 0x0302: TLS 1.1
// 0x0303: TLS 1.2
// 0x0304: TLS 1.3
if (legacy_version < 0x0300 || legacy_version > 0x0304) {
return kUnknown;
}

uint8_t handshake_type = buf[5];
// Check for ServerHello
if (handshake_type == 2) {
return kResponse;
}
// Check for ClientHello
if (handshake_type == 1) {
return kRequest;
}

return kUnknown;
}

// Cassandra frame:
// 0 8 16 24 32 40
// +---------+---------+---------+---------+---------+
Expand Down Expand Up @@ -699,7 +739,16 @@ static __inline struct protocol_message_t infer_protocol(const char* buf, size_t
// role by considering which side called accept() vs connect(). Once the clean-up
// above is done, the code below can be turned into a chained ternary.
// PROTOCOL_LIST: Requires update on new protocols.
if (ENABLE_HTTP_TRACING && (inferred_message.type = infer_http_message(buf, count)) != kUnknown) {
//
// TODO(ddelnano): TLS tracing should be handled differently in the future as we want to be able
// to trace the handshake and the application data separately (gh#2095). The current connection
// tracker model only works with one or the other, meaning if TLS tracing is enabled, tracing the
// plaintext within an encrypted conn will not work. ENABLE_TLS_TRACING will default to false
// until this is revisted.
if (ENABLE_TLS_TRACING && (inferred_message.type = infer_tls_message(buf, count)) != kUnknown) {
inferred_message.protocol = kProtocolTLS;
} else if (ENABLE_HTTP_TRACING &&
(inferred_message.type = infer_http_message(buf, count)) != kUnknown) {
inferred_message.protocol = kProtocolHTTP;
} else if (ENABLE_CQL_TRACING &&
(inferred_message.type = infer_cql_message(buf, count)) != kUnknown) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,27 @@ TEST(ProtocolInferenceTest, AMQPResponse) {
EXPECT_EQ(protocol_message.protocol, kProtocolAMQP);
EXPECT_EQ(protocol_message.type, kResponse);
}

TEST(ProtocolInferenceTest, TLSRequest) {
struct conn_info_t conn_info = {};
// TLS Client Hello
constexpr uint8_t kReqFrame[] = {
0x16, 0x03, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0xfc, 0x03, 0x03, 0x7b, 0x7b, 0x7b,
};
auto protocol_message =
infer_protocol(reinterpret_cast<const char*>(kReqFrame), sizeof(kReqFrame), &conn_info);
EXPECT_EQ(protocol_message.protocol, kProtocolTLS);
EXPECT_EQ(protocol_message.type, kRequest);
}

TEST(ProtocolInferenceTest, TLSResponse) {
struct conn_info_t conn_info = {};
// TLS Server Hello
constexpr uint8_t kRespFrame[] = {
0x16, 0x03, 0x01, 0x00, 0x01, 0x02, 0x00, 0x00, 0xfc, 0x03, 0x03, 0x7b, 0x7b, 0x7b,
};
auto protocol_message =
infer_protocol(reinterpret_cast<const char*>(kRespFrame), sizeof(kRespFrame), &conn_info);
EXPECT_EQ(protocol_message.protocol, kProtocolTLS);
EXPECT_EQ(protocol_message.type, kResponse);
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ enum traffic_protocol_t {
kProtocolKafka = 10,
kProtocolMux = 11,
kProtocolAMQP = 12,
kProtocolTLS = 13,
// We use magic enum to iterate through protocols in C++ land,
// and don't want the C-enum-size trick to show up there.
#ifndef __cplusplus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ auto CreateTraceRoles() {
res.Set(kProtocolKafka, {kRoleServer});
res.Set(kProtocolMux, {kRoleServer});
res.Set(kProtocolAMQP, {kRoleServer});
res.Set(kProtocolTLS, {kRoleServer});

DCHECK(res.AreAllKeysSet());
return res;
Expand Down
4 changes: 4 additions & 0 deletions src/stirling/source_connectors/socket_tracer/data_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ template void DataStream::ProcessBytesToFrames<protocols::amqp::channel_id, prot
template void DataStream::ProcessBytesToFrames<
protocols::mongodb::stream_id_t, protocols::mongodb::Frame, protocols::mongodb::StateWrapper>(
message_type_t type, protocols::mongodb::StateWrapper* state);

template void DataStream::ProcessBytesToFrames<protocols::tls::stream_id_t, protocols::tls::Frame,
protocols::NoState>(message_type_t type,
protocols::NoState* state);
void DataStream::Reset() {
data_buffer_.Reset();
has_new_events_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@ pl_cc_library(
"//src/stirling/source_connectors/socket_tracer/protocols/nats:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/pgsql:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/redis:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/tls:cc_library",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@
#include "src/stirling/source_connectors/socket_tracer/protocols/nats/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/pgsql/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/redis/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/stitcher.h" // IWYU pragma: export
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ namespace stirling {
namespace protocols {
namespace tls {

using px::utils::JSONObjectBuilder;

constexpr size_t kTLSRecordHeaderLength = 5;
constexpr size_t kExtensionMinimumLength = 4;
constexpr size_t kSNIExtensionMinimumLength = 3;
Expand All @@ -39,11 +41,9 @@ constexpr size_t kSNIExtensionMinimumLength = 3;
// In TLS 1.2 and earlier, gmt_unix_time is 4 bytes and Random is 28 bytes.
constexpr size_t kRandomStructLength = 32;

StatusOr<ParseState> ExtractSNIExtension(std::map<std::string, std::string>* exts,
BinaryDecoder* decoder) {
StatusOr<ParseState> ExtractSNIExtension(ReqExtensions* exts, BinaryDecoder* decoder) {
PX_ASSIGN_OR(auto server_name_list_length, decoder->ExtractBEInt<uint16_t>(),
return ParseState::kInvalid);
std::vector<std::string> server_names;
while (server_name_list_length > 0) {
PX_ASSIGN_OR(auto server_name_type, decoder->ExtractBEInt<uint8_t>(),
return error::Internal("Failed to extract server name type"));
Expand All @@ -56,10 +56,9 @@ StatusOr<ParseState> ExtractSNIExtension(std::map<std::string, std::string>* ext
PX_ASSIGN_OR(auto server_name, decoder->ExtractString(server_name_length),
return error::Internal("Failed to extract server name"));

server_names.push_back(std::string(server_name));
exts->server_names.push_back(std::string(server_name));
server_name_list_length -= kSNIExtensionMinimumLength + server_name_length;
}
exts->insert({"server_name", ToJSONString(server_names)});
return ParseState::kSuccess;
}

Expand Down Expand Up @@ -162,6 +161,8 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {
return ParseState::kSuccess;
}

ReqExtensions req_extensions;
RespExtensions resp_extensions;
while (extensions_length > 0) {
PX_ASSIGN_OR(auto extension_type, decoder->ExtractBEInt<uint16_t>(),
return ParseState::kInvalid);
Expand All @@ -170,7 +171,7 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {

if (extension_length > 0) {
if (extension_type == 0x00) {
if (!ExtractSNIExtension(&frame->extensions, decoder).ok()) {
if (!ExtractSNIExtension(&req_extensions, decoder).ok()) {
return ParseState::kInvalid;
}
} else {
Expand All @@ -182,6 +183,13 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {

extensions_length -= kExtensionMinimumLength + extension_length;
}
JSONObjectBuilder req_body_builder;
req_body_builder.WriteKVRecursive("extensions", req_extensions);
frame->req_body = req_body_builder.GetString();

JSONObjectBuilder resp_body_builder;
resp_body_builder.WriteKVRecursive("extensions", resp_extensions);
frame->resp_body = resp_body_builder.GetString();

return ParseState::kSuccess;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,7 @@ TEST_F(TLSParserTest, ParseValidClientHello) {
ASSERT_GT(frame.session_id.size(), 0);

// Validate the SNI extension was parsed properly
ASSERT_EQ(frame.extensions.size(), 1);
ASSERT_EQ(frame.extensions["server_name"], "[\"argocd-cluster-repo-server\"]");
ASSERT_EQ(frame.req_body, R"({"extensions":{"server_name":["argocd-cluster-repo-server"]}})");
ASSERT_EQ(state, ParseState::kSuccess);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ namespace stirling {
namespace protocols {
namespace tls {

using ::px::utils::ToJSONString;

enum class ContentType : uint8_t {
kChangeCipherSpec = 0x14,
kAlert = 0x15,
Expand Down Expand Up @@ -186,6 +184,25 @@ enum class ExtensionType : uint16_t {
kRenegotiationInfo = 65281,
};

// Extensions that are common to both the client and server side
// of a TLS handshake
struct SharedExtensions {
void ToJSON(::px::utils::JSONObjectBuilder* /*builder*/) const {}
};

struct ReqExtensions : public SharedExtensions {
std::vector<std::string> server_names;

void ToJSON(::px::utils::JSONObjectBuilder* builder) const {
SharedExtensions::ToJSON(builder);
builder->WriteKV("server_name", server_names);
}
};

struct RespExtensions : public SharedExtensions {
void ToJSON(::px::utils::JSONObjectBuilder* builder) const { SharedExtensions::ToJSON(builder); }
};

struct Frame : public FrameBase {
ContentType content_type;

Expand All @@ -195,12 +212,13 @@ struct Frame : public FrameBase {

HandshakeType handshake_type;

uint24_t handshake_length;
uint24_t handshake_length = uint24_t(0);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GCC fails if this isn't initialized (buildbuddy failure)


LegacyVersion handshake_version;

std::string session_id;
std::map<std::string, std::string> extensions;
std::string req_body;
std::string resp_body;

bool consumed = false;

Expand All @@ -209,9 +227,9 @@ struct Frame : public FrameBase {
std::string ToString() const override {
return absl::Substitute(
"TLS Frame [len=$0 content_type=$1 legacy_version=$2 handshake_version=$3 "
"handshake_type=$4 extensions=$5]",
length, content_type, legacy_version, handshake_version, handshake_type,
ToJSONString(extensions));
"handshake_type=$4 req_body=$5 resp_body=$6]",
length, content_type, legacy_version, handshake_version, handshake_type, req_body,
resp_body);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "src/stirling/source_connectors/socket_tracer/protocols/nats/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/pgsql/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/redis/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/types.h"

namespace px {
namespace stirling {
Expand All @@ -53,7 +54,8 @@ using FrameDequeVariant = std::variant<std::monostate,
absl::flat_hash_map<kafka::correlation_id_t, std::deque<kafka::Packet>>,
absl::flat_hash_map<nats::stream_id_t, std::deque<nats::Message>>,
absl::flat_hash_map<amqp::channel_id, std::deque<amqp::Frame>>,
absl::flat_hash_map<mongodb::stream_id_t, std::deque<mongodb::Frame>>>;
absl::flat_hash_map<mongodb::stream_id_t, std::deque<mongodb::Frame>>,
absl::flat_hash_map<tls::stream_id_t, std::deque<tls::Frame>>>;
// clang-format off

} // namespace protocols
Expand Down
Loading
Loading