Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/shared/protocols/protocols.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum class Protocol {
kKafka = 10,
kMux = 11,
kAMQP = 12,
kTLS = 13,
};

} // namespace protocols
Expand Down
2 changes: 1 addition & 1 deletion src/stirling/binaries/stirling_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ DEFINE_string(trace, "",
"Dynamic trace to deploy. Either (1) the path to a file containing PxL or IR trace "
"spec, or (2) <path to object file>:<symbol_name> for full-function tracing.");
DEFINE_string(print_record_batches,
"http_events,mysql_events,pgsql_events,redis_events,cql_events,dns_events",
"http_events,mysql_events,pgsql_events,redis_events,cql_events,dns_events,tls_events",
"Comma-separated list of tables to print.");
DEFINE_bool(init_only, false, "If true, only runs the init phase and exits. For testing.");
DEFINE_int32(timeout_secs, -1,
Expand Down
21 changes: 21 additions & 0 deletions src/stirling/source_connectors/socket_tracer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,27 @@ pl_cc_bpf_test(
],
)

pl_cc_bpf_test(
name = "tls_trace_bpf_test",
timeout = "long",
srcs = ["tls_trace_bpf_test.cc"],
flaky = True,
shard_count = 2,
tags = [
"cpu:16",
"no_asan",
"requires_bpf",
],
deps = [
":cc_library",
"//src/common/testing/test_utils:cc_library",
"//src/stirling/source_connectors/socket_tracer/testing:cc_library",
"//src/stirling/source_connectors/socket_tracer/testing/container_images:curl_container",
"//src/stirling/source_connectors/socket_tracer/testing/container_images:nginx_openssl_3_0_8_container",
"//src/stirling/testing:cc_library",
],
)

pl_cc_bpf_test(
name = "dyn_lib_trace_bpf_test",
timeout = "moderate",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pl_cc_test(
"ENABLE_NATS_TRACING=true",
"ENABLE_MONGO_TRACING=true",
"ENABLE_AMQP_TRACING=true",
"ENABLE_TLS_TRACING=true",
],
deps = [
"//src/stirling/bpf_tools/bcc_bpf:headers",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,46 @@ static __inline enum message_type_t infer_http_message(const char* buf, size_t c
return kUnknown;
}

static __inline enum message_type_t infer_tls_message(const char* buf, size_t count) {
if (count < 6) {
return kUnknown;
}

uint8_t content_type = buf[0];
// TLS content types correspond to the following:
// 0x14: ChangeCipherSpec
// 0x15: Alert
// 0x16: Handshake
// 0x17: ApplicationData
// 0x18: Heartbeat
if (content_type != 0x16) {
return kUnknown;
}

uint16_t legacy_version = buf[1] << 8 | buf[2];
// TLS versions correspond to the following:
// 0x0300: SSL 3.0
// 0x0301: TLS 1.0
// 0x0302: TLS 1.1
// 0x0303: TLS 1.2
// 0x0304: TLS 1.3
if (legacy_version < 0x0300 || legacy_version > 0x0304) {
return kUnknown;
}

uint8_t handshake_type = buf[5];
// Check for ServerHello
if (handshake_type == 2) {
return kResponse;
}
// Check for ClientHello
if (handshake_type == 1) {
return kRequest;
}

return kUnknown;
}

// Cassandra frame:
// 0 8 16 24 32 40
// +---------+---------+---------+---------+---------+
Expand Down Expand Up @@ -699,7 +739,16 @@ static __inline struct protocol_message_t infer_protocol(const char* buf, size_t
// role by considering which side called accept() vs connect(). Once the clean-up
// above is done, the code below can be turned into a chained ternary.
// PROTOCOL_LIST: Requires update on new protocols.
if (ENABLE_HTTP_TRACING && (inferred_message.type = infer_http_message(buf, count)) != kUnknown) {
//
// TODO(ddelnano): TLS tracing should be handled differently in the future as we want to be able
// to trace the handshake and the application data separately (gh#2095). The current connection
// tracker model only works with one or the other, meaning if TLS tracing is enabled, tracing the
// plaintext within an encrypted conn will not work. ENABLE_TLS_TRACING will default to false
// until this is revisted.
if (ENABLE_TLS_TRACING && (inferred_message.type = infer_tls_message(buf, count)) != kUnknown) {
inferred_message.protocol = kProtocolTLS;
} else if (ENABLE_HTTP_TRACING &&
(inferred_message.type = infer_http_message(buf, count)) != kUnknown) {
inferred_message.protocol = kProtocolHTTP;
} else if (ENABLE_CQL_TRACING &&
(inferred_message.type = infer_cql_message(buf, count)) != kUnknown) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,27 @@ TEST(ProtocolInferenceTest, AMQPResponse) {
EXPECT_EQ(protocol_message.protocol, kProtocolAMQP);
EXPECT_EQ(protocol_message.type, kResponse);
}

TEST(ProtocolInferenceTest, TLSRequest) {
struct conn_info_t conn_info = {};
// TLS Client Hello
constexpr uint8_t kReqFrame[] = {
0x16, 0x03, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0xfc, 0x03, 0x03, 0x7b, 0x7b, 0x7b,
};
auto protocol_message =
infer_protocol(reinterpret_cast<const char*>(kReqFrame), sizeof(kReqFrame), &conn_info);
EXPECT_EQ(protocol_message.protocol, kProtocolTLS);
EXPECT_EQ(protocol_message.type, kRequest);
}

TEST(ProtocolInferenceTest, TLSResponse) {
struct conn_info_t conn_info = {};
// TLS Server Hello
constexpr uint8_t kRespFrame[] = {
0x16, 0x03, 0x01, 0x00, 0x01, 0x02, 0x00, 0x00, 0xfc, 0x03, 0x03, 0x7b, 0x7b, 0x7b,
};
auto protocol_message =
infer_protocol(reinterpret_cast<const char*>(kRespFrame), sizeof(kRespFrame), &conn_info);
EXPECT_EQ(protocol_message.protocol, kProtocolTLS);
EXPECT_EQ(protocol_message.type, kResponse);
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ enum traffic_protocol_t {
kProtocolKafka = 10,
kProtocolMux = 11,
kProtocolAMQP = 12,
kProtocolTLS = 13,
// We use magic enum to iterate through protocols in C++ land,
// and don't want the C-enum-size trick to show up there.
#ifndef __cplusplus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ auto CreateTraceRoles() {
res.Set(kProtocolKafka, {kRoleServer});
res.Set(kProtocolMux, {kRoleServer});
res.Set(kProtocolAMQP, {kRoleServer});
res.Set(kProtocolTLS, {kRoleServer});

DCHECK(res.AreAllKeysSet());
return res;
Expand Down
4 changes: 4 additions & 0 deletions src/stirling/source_connectors/socket_tracer/data_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ template void DataStream::ProcessBytesToFrames<protocols::amqp::channel_id, prot
template void DataStream::ProcessBytesToFrames<
protocols::mongodb::stream_id_t, protocols::mongodb::Frame, protocols::mongodb::StateWrapper>(
message_type_t type, protocols::mongodb::StateWrapper* state);

template void DataStream::ProcessBytesToFrames<protocols::tls::stream_id_t, protocols::tls::Frame,
protocols::NoState>(message_type_t type,
protocols::NoState* state);
void DataStream::Reset() {
data_buffer_.Reset();
has_new_events_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@ pl_cc_library(
"//src/stirling/source_connectors/socket_tracer/protocols/nats:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/pgsql:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/redis:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/tls:cc_library",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@
#include "src/stirling/source_connectors/socket_tracer/protocols/nats/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/pgsql/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/redis/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/stitcher.h" // IWYU pragma: export
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/parse.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
Expand All @@ -31,6 +32,8 @@ namespace stirling {
namespace protocols {
namespace tls {

using px::utils::JSONObjectBuilder;

constexpr size_t kTLSRecordHeaderLength = 5;
constexpr size_t kExtensionMinimumLength = 4;
constexpr size_t kSNIExtensionMinimumLength = 3;
Expand All @@ -39,11 +42,9 @@ constexpr size_t kSNIExtensionMinimumLength = 3;
// In TLS 1.2 and earlier, gmt_unix_time is 4 bytes and Random is 28 bytes.
constexpr size_t kRandomStructLength = 32;

StatusOr<ParseState> ExtractSNIExtension(std::map<std::string, std::string>* exts,
BinaryDecoder* decoder) {
StatusOr<ParseState> ExtractSNIExtension(SharedExtensions* exts, BinaryDecoder* decoder) {
PX_ASSIGN_OR(auto server_name_list_length, decoder->ExtractBEInt<uint16_t>(),
return ParseState::kInvalid);
std::vector<std::string> server_names;
while (server_name_list_length > 0) {
PX_ASSIGN_OR(auto server_name_type, decoder->ExtractBEInt<uint8_t>(),
return error::Internal("Failed to extract server name type"));
Expand All @@ -56,10 +57,9 @@ StatusOr<ParseState> ExtractSNIExtension(std::map<std::string, std::string>* ext
PX_ASSIGN_OR(auto server_name, decoder->ExtractString(server_name_length),
return error::Internal("Failed to extract server name"));

server_names.push_back(std::string(server_name));
exts->server_names.push_back(std::string(server_name));
server_name_list_length -= kSNIExtensionMinimumLength + server_name_length;
}
exts->insert({"server_name", ToJSONString(server_names)});
return ParseState::kSuccess;
}

Expand All @@ -76,7 +76,7 @@ StatusOr<ParseState> ExtractSNIExtension(std::map<std::string, std::string>* ext
* diagram: https://en.wikipedia.org/wiki/Transport_Layer_Security#TLS_record
*/

ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {
ParseState ParseFullFrame(SharedExtensions* extensions, BinaryDecoder* decoder, Frame* frame) {
PX_ASSIGN_OR(auto raw_content_type, decoder->ExtractBEInt<uint8_t>(),
return ParseState::kInvalid);
auto content_type = magic_enum::enum_cast<tls::ContentType>(raw_content_type);
Expand Down Expand Up @@ -170,7 +170,7 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {

if (extension_length > 0) {
if (extension_type == 0x00) {
if (!ExtractSNIExtension(&frame->extensions, decoder).ok()) {
if (!ExtractSNIExtension(extensions, decoder).ok()) {
return ParseState::kInvalid;
}
} else {
Expand All @@ -182,14 +182,17 @@ ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame) {

extensions_length -= kExtensionMinimumLength + extension_length;
}
JSONObjectBuilder body_builder;
body_builder.WriteKVRecursive("extensions", *extensions);
frame->body = body_builder.GetString();

return ParseState::kSuccess;
}

} // namespace tls

template <>
ParseState ParseFrame(message_type_t, std::string_view* buf, tls::Frame* frame, NoState*) {
ParseState ParseFrame(message_type_t type, std::string_view* buf, tls::Frame* frame, NoState*) {
// TLS record header is 5 bytes. The size of the record is in bytes 4 and 5.
if (buf->length() < tls::kTLSRecordHeaderLength) {
return ParseState::kNeedsMoreData;
Expand All @@ -200,7 +203,13 @@ ParseState ParseFrame(message_type_t, std::string_view* buf, tls::Frame* frame,
}

BinaryDecoder decoder(*buf);
auto parse_result = tls::ParseFullFrame(&decoder, frame);
std::unique_ptr<tls::SharedExtensions> extensions;
if (type == kRequest) {
extensions = std::make_unique<tls::ReqExtensions>();
} else {
extensions = std::make_unique<tls::RespExtensions>();
}
auto parse_result = tls::ParseFullFrame(extensions.get(), &decoder, frame);
if (parse_result == ParseState::kSuccess) {
buf->remove_prefix(length + tls::kTLSRecordHeaderLength);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace stirling {
namespace protocols {
namespace tls {

ParseState ParseFullFrame(BinaryDecoder* decoder, Frame* frame);
ParseState ParseFullFrame(SharedExtensions* extensions, BinaryDecoder* decoder, Frame* frame);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,7 @@ TEST_F(TLSParserTest, ParseValidClientHello) {
ASSERT_GT(frame.session_id.size(), 0);

// Validate the SNI extension was parsed properly
ASSERT_EQ(frame.extensions.size(), 1);
ASSERT_EQ(frame.extensions["server_name"], "[\"argocd-cluster-repo-server\"]");
ASSERT_EQ(frame.body, R"({"extensions":{"server_name":["argocd-cluster-repo-server"]}})");
ASSERT_EQ(state, ParseState::kSuccess);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ namespace stirling {
namespace protocols {
namespace tls {

using ::px::utils::ToJSONString;

enum class ContentType : uint8_t {
kChangeCipherSpec = 0x14,
kAlert = 0x15,
Expand Down Expand Up @@ -186,6 +184,28 @@ enum class ExtensionType : uint16_t {
kRenegotiationInfo = 65281,
};

// Extensions that are common to both the client and server side
// of a TLS handshake
struct SharedExtensions {
std::vector<std::string> server_names;

virtual void ToJSON(::px::utils::JSONObjectBuilder* /*builder*/) const {}
virtual ~SharedExtensions() = default;
};

struct ReqExtensions : public SharedExtensions {
void ToJSON(::px::utils::JSONObjectBuilder* builder) const override {
SharedExtensions::ToJSON(builder);
builder->WriteKV("server_name", server_names);
}
};

struct RespExtensions : public SharedExtensions {
void ToJSON(::px::utils::JSONObjectBuilder* builder) const override {
SharedExtensions::ToJSON(builder);
}
};

struct Frame : public FrameBase {
ContentType content_type;

Expand All @@ -195,12 +215,12 @@ struct Frame : public FrameBase {

HandshakeType handshake_type;

uint24_t handshake_length;
uint24_t handshake_length = uint24_t(0);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GCC fails if this isn't initialized (buildbuddy failure)


LegacyVersion handshake_version;

std::string session_id;
std::map<std::string, std::string> extensions;
std::string body;

bool consumed = false;

Expand All @@ -209,9 +229,8 @@ struct Frame : public FrameBase {
std::string ToString() const override {
return absl::Substitute(
"TLS Frame [len=$0 content_type=$1 legacy_version=$2 handshake_version=$3 "
"handshake_type=$4 extensions=$5]",
length, content_type, legacy_version, handshake_version, handshake_type,
ToJSONString(extensions));
"handshake_type=$4 body=$5]",
length, content_type, legacy_version, handshake_version, handshake_type, body);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "src/stirling/source_connectors/socket_tracer/protocols/nats/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/pgsql/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/redis/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/types.h"

namespace px {
namespace stirling {
Expand All @@ -53,7 +54,8 @@ using FrameDequeVariant = std::variant<std::monostate,
absl::flat_hash_map<kafka::correlation_id_t, std::deque<kafka::Packet>>,
absl::flat_hash_map<nats::stream_id_t, std::deque<nats::Message>>,
absl::flat_hash_map<amqp::channel_id, std::deque<amqp::Frame>>,
absl::flat_hash_map<mongodb::stream_id_t, std::deque<mongodb::Frame>>>;
absl::flat_hash_map<mongodb::stream_id_t, std::deque<mongodb::Frame>>,
absl::flat_hash_map<tls::stream_id_t, std::deque<tls::Frame>>>;
// clang-format off

} // namespace protocols
Expand Down
Loading
Loading