Skip to content

Commit d470862

Browse files
author
lhh
committed
Fix span lifecycle with smart pointers to prevent use-after-free in async RPC callbacks (#3068)
1 parent 9934c39 commit d470862

32 files changed

+683
-282
lines changed

src/brpc/builtin/rpcz_service.cpp

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,35 @@ static void PrintElapse(std::ostream& os, int64_t cur_time,
185185

186186
static void PrintAnnotations(
187187
std::ostream& os, int64_t cur_time, int64_t* last_time,
188-
SpanInfoExtractor** extractors, int num_extr) {
188+
SpanInfoExtractor** extractors, int num_extr, const RpczSpan* span) {
189189
int64_t anno_time;
190190
std::string a;
191+
const char* span_type_str = "Span";
192+
if (span) {
193+
switch (span->type()) {
194+
case SPAN_TYPE_SERVER:
195+
span_type_str = "ServerSpan";
196+
break;
197+
case SPAN_TYPE_CLIENT:
198+
span_type_str = "ClientSpan";
199+
break;
200+
case SPAN_TYPE_BTHREAD:
201+
span_type_str = "BthreadSpan";
202+
break;
203+
}
204+
}
205+
191206
// TODO: Going through all extractors is not strictly correct because
192207
// later extractors may have earlier annotations.
193208
for (int i = 0; i < num_extr; ++i) {
194209
while (extractors[i]->PopAnnotation(cur_time, &anno_time, &a)) {
195210
PrintRealTime(os, anno_time);
196211
PrintElapse(os, anno_time, last_time);
197-
os << ' ' << WebEscape(a);
212+
os << ' ';
213+
if (span) {
214+
os << '[' << span_type_str << ' ' << SPAN_ID_STR << '=' << Hex(span->span_id()) << "] ";
215+
}
216+
os << WebEscape(a);
198217
if (a.empty() || butil::back_char(a) != '\n') {
199218
os << '\n';
200219
}
@@ -204,12 +223,12 @@ static void PrintAnnotations(
204223

205224
static bool PrintAnnotationsAndRealTimeSpan(
206225
std::ostream& os, int64_t cur_time, int64_t* last_time,
207-
SpanInfoExtractor** extr, int num_extr) {
226+
SpanInfoExtractor** extr, int num_extr, const RpczSpan* span) {
208227
if (cur_time == 0) {
209228
// the field was not set.
210229
return false;
211230
}
212-
PrintAnnotations(os, cur_time, last_time, extr, num_extr);
231+
PrintAnnotations(os, cur_time, last_time, extr, num_extr, span);
213232
PrintRealTime(os, cur_time);
214233
PrintElapse(os, cur_time, last_time);
215234
return true;
@@ -239,9 +258,10 @@ static void PrintClientSpan(
239258
extr[num_extr++] = server_extr;
240259
}
241260
extr[num_extr++] = &client_extr;
242-
// start_send_us is always set for client spans.
243-
CHECK(PrintAnnotationsAndRealTimeSpan(os, span.start_send_real_us(),
244-
last_time, extr, num_extr));
261+
if (!PrintAnnotationsAndRealTimeSpan(os, span.start_send_real_us(),
262+
last_time, extr, num_extr, &span)) {
263+
os << " start_send_real_us:not-set";
264+
}
245265
const Protocol* protocol = FindProtocol(span.protocol());
246266
const char* protocol_name = (protocol ? protocol->name : "Unknown");
247267
const butil::EndPoint remote_side(butil::int2ip(span.remote_ip()), span.remote_port());
@@ -271,12 +291,12 @@ static void PrintClientSpan(
271291
os << std::endl;
272292

273293
if (PrintAnnotationsAndRealTimeSpan(os, span.sent_real_us(),
274-
last_time, extr, num_extr)) {
275-
os << " Requested(" << span.request_size() << ") [1]" << std::endl;
294+
last_time, extr, num_extr, &span)) {
295+
os << " [ClientSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Requested(" << span.request_size() << ") [1]" << std::endl;
276296
}
277297
if (PrintAnnotationsAndRealTimeSpan(os, span.received_real_us(),
278-
last_time, extr, num_extr)) {
279-
os << " Received response(" << span.response_size() << ")";
298+
last_time, extr, num_extr, &span)) {
299+
os << " [ClientSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Received response(" << span.response_size() << ")";
280300
if (span.base_cid() != 0 && span.ending_cid() != 0) {
281301
int64_t ver = span.ending_cid() - span.base_cid();
282302
if (ver >= 1) {
@@ -289,18 +309,18 @@ static void PrintClientSpan(
289309
}
290310

291311
if (PrintAnnotationsAndRealTimeSpan(os, span.start_parse_real_us(),
292-
last_time, extr, num_extr)) {
293-
os << " Processing the response in a new bthread" << std::endl;
312+
last_time, extr, num_extr, &span)) {
313+
os << " [ClientSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Processing the response in a new bthread" << std::endl;
294314
}
295315

296316
if (PrintAnnotationsAndRealTimeSpan(
297317
os, span.start_callback_real_us(),
298-
last_time, extr, num_extr)) {
299-
os << (span.async() ? " Enter user's done" : " Back to user's callsite") << std::endl;
318+
last_time, extr, num_extr, &span)) {
319+
os << " [ClientSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] " << (span.async() ? " Enter user's done" : " Back to user's callsite") << std::endl;
300320
}
301321

302322
PrintAnnotations(os, std::numeric_limits<int64_t>::max(),
303-
last_time, extr, num_extr);
323+
last_time, extr, num_extr, &span);
304324
}
305325

306326
static void PrintClientSpan(std::ostream& os,const RpczSpan& span,
@@ -318,7 +338,15 @@ static void PrintBthreadSpan(std::ostream& os, const RpczSpan& span, int64_t* la
318338
extr[num_extr++] = server_extr;
319339
}
320340
extr[num_extr++] = &client_extr;
321-
PrintAnnotations(os, std::numeric_limits<int64_t>::max(), last_time, extr, num_extr);
341+
342+
// Print span id for bthread span context identification
343+
os << " [BthreadSpan " << SPAN_ID_STR << '=' << Hex(span.span_id());
344+
if (span.parent_span_id() != 0) {
345+
os << " parent_span=" << Hex(span.parent_span_id());
346+
}
347+
os << "] ";
348+
349+
PrintAnnotations(os, std::numeric_limits<int64_t>::max(), last_time, extr, num_extr, &span);
322350
}
323351

324352
static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
@@ -348,16 +376,16 @@ static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
348376
os << std::endl;
349377
if (PrintAnnotationsAndRealTimeSpan(
350378
os, span.start_parse_real_us(),
351-
&last_time, extr, ARRAY_SIZE(extr))) {
352-
os << " Processing the request in a new bthread" << std::endl;
379+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
380+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Processing the request in a new bthread" << std::endl;
353381
}
354382

355383
bool entered_user_method = false;
356384
if (PrintAnnotationsAndRealTimeSpan(
357385
os, span.start_callback_real_us(),
358-
&last_time, extr, ARRAY_SIZE(extr))) {
386+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
359387
entered_user_method = true;
360-
os << " Enter " << WebEscape(span.full_method_name()) << std::endl;
388+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Enter " << WebEscape(span.full_method_name()) << std::endl;
361389
}
362390

363391
const int nclient = span.client_spans_size();
@@ -372,22 +400,22 @@ static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
372400

373401
if (PrintAnnotationsAndRealTimeSpan(
374402
os, span.start_send_real_us(),
375-
&last_time, extr, ARRAY_SIZE(extr))) {
403+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
376404
if (entered_user_method) {
377-
os << " Leave " << WebEscape(span.full_method_name()) << std::endl;
405+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Leave " << WebEscape(span.full_method_name()) << std::endl;
378406
} else {
379-
os << " Responding" << std::endl;
407+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Responding" << std::endl;
380408
}
381409
}
382410

383411
if (PrintAnnotationsAndRealTimeSpan(
384412
os, span.sent_real_us(),
385-
&last_time, extr, ARRAY_SIZE(extr))) {
386-
os << " Responded(" << span.response_size() << ')' << std::endl;
413+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
414+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Responded(" << span.response_size() << ')' << std::endl;
387415
}
388416

389417
PrintAnnotations(os, std::numeric_limits<int64_t>::max(),
390-
&last_time, extr, ARRAY_SIZE(extr));
418+
&last_time, extr, ARRAY_SIZE(extr), &span);
391419
}
392420

393421
class RpczSpanFilter : public SpanFilter {

src/brpc/channel.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "brpc/details/usercode_backup_pool.h" // TooManyUserCode
3838
#include "brpc/rdma/rdma_helper.h"
3939
#include "brpc/policy/esp_authenticator.h"
40+
#include "brpc/details/controller_private_accessor.h"
4041

4142
namespace brpc {
4243

@@ -490,7 +491,7 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
490491
}
491492
cntl->set_used_by_rpc();
492493

493-
if (cntl->_sender == NULL && IsTraceable(Span::tls_parent())) {
494+
if (cntl->_sender == NULL && IsTraceable(Span::tls_parent().get())) {
494495
const int64_t start_send_us = butil::cpuwide_time_us();
495496
const std::string* method_name = NULL;
496497
if (_get_method_name) {
@@ -501,13 +502,16 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
501502
const static std::string NULL_METHOD_STR = "null-method";
502503
method_name = &NULL_METHOD_STR;
503504
}
504-
Span* span = Span::CreateClientSpan(
505+
std::shared_ptr<Span> span = Span::CreateClientSpan(
505506
*method_name, start_send_real_us - start_send_us);
506-
span->set_log_id(cntl->log_id());
507-
span->set_base_cid(correlation_id);
508-
span->set_protocol(_options.protocol);
509-
span->set_start_send_us(start_send_us);
510-
cntl->_span = span;
507+
if (span) {
508+
ControllerPrivateAccessor accessor(cntl);
509+
span->set_log_id(cntl->log_id());
510+
span->set_base_cid(correlation_id);
511+
span->set_protocol(_options.protocol);
512+
span->set_start_send_us(start_send_us);
513+
accessor.set_span(span);
514+
}
511515
}
512516
// Override some options if they haven't been set by Controller
513517
if (cntl->timeout_ms() == UNSET_MAGIC_NUM) {
@@ -608,9 +612,7 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
608612
// be woken up by callback when RPC finishes (succeeds or still
609613
// fails after retry)
610614
Join(correlation_id);
611-
if (cntl->_span) {
612-
cntl->SubmitSpan();
613-
}
615+
cntl->SubmitSpan();
614616
cntl->OnRPCEnd(butil::gettimeofday_us());
615617
}
616618
}

src/brpc/controller.cpp

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ static void CreateIgnoreAllRead() { s_ignore_all_read = new IgnoreAllRead; }
183183
// you don't have to set the fields to initial state after deletion since
184184
// they'll be set uniformly after this method is called.
185185
void Controller::ResetNonPods() {
186-
if (_span) {
187-
Span::Submit(_span, butil::cpuwide_time_us());
186+
if (auto span = _span.lock()) {
187+
Span::Submit(span, butil::cpuwide_time_us());
188188
}
189189
_error_text.clear();
190190
_remote_side = butil::EndPoint();
@@ -240,7 +240,7 @@ void Controller::ResetNonPods() {
240240
void Controller::ResetPods() {
241241
// NOTE: Make the sequence of assignments same with the order that they're
242242
// defined in header. Better for cpu cache and faster for lookup.
243-
_span = NULL;
243+
_span.reset();
244244
_flags = 0;
245245
#ifndef BAIDU_INTERNAL
246246
set_pb_bytes_to_base64(true);
@@ -450,9 +450,9 @@ void Controller::SetFailed(const std::string& reason) {
450450
AppendServerIdentiy();
451451
}
452452
_error_text.append(reason);
453-
if (_span) {
454-
_span->set_error_code(_error_code);
455-
_span->Annotate(reason);
453+
if (auto span = _span.lock()) {
454+
span->set_error_code(_error_code);
455+
span->Annotate(reason);
456456
}
457457
UpdateResponseHeader(this);
458458
}
@@ -479,9 +479,9 @@ void Controller::SetFailed(int error_code, const char* reason_fmt, ...) {
479479
va_start(ap, reason_fmt);
480480
butil::string_vappendf(&_error_text, reason_fmt, ap);
481481
va_end(ap);
482-
if (_span) {
483-
_span->set_error_code(_error_code);
484-
_span->AnnotateCStr(_error_text.c_str() + old_size, 0);
482+
if (auto span = _span.lock()) {
483+
span->set_error_code(_error_code);
484+
span->AnnotateCStr(_error_text.c_str() + old_size, 0);
485485
}
486486
UpdateResponseHeader(this);
487487
}
@@ -507,9 +507,9 @@ void Controller::CloseConnection(const char* reason_fmt, ...) {
507507
va_start(ap, reason_fmt);
508508
butil::string_vappendf(&_error_text, reason_fmt, ap);
509509
va_end(ap);
510-
if (_span) {
511-
_span->set_error_code(_error_code);
512-
_span->AnnotateCStr(_error_text.c_str() + old_size, 0);
510+
if (auto span = _span.lock()) {
511+
span->set_error_code(_error_code);
512+
span->AnnotateCStr(_error_text.c_str() + old_size, 0);
513513
}
514514
UpdateResponseHeader(this);
515515
}
@@ -943,9 +943,9 @@ void Controller::EndRPC(const CompletionInfo& info) {
943943
}
944944
// RPC finished, now it's safe to release `LoadBalancerWithNaming'
945945
_lb.reset();
946-
if (_span) {
947-
_span->set_ending_cid(info.id);
948-
_span->set_async(_done);
946+
if (auto span = _span.lock()) {
947+
span->set_ending_cid(info.id);
948+
span->set_async(_done);
949949
// Submit the span if we're in async RPC. For sync RPC, the span
950950
// is submitted after Join() to get a more accurate resuming timestamp.
951951
if (_done) {
@@ -1019,12 +1019,16 @@ void Controller::DoneInBackupThread() {
10191019

10201020
void Controller::SubmitSpan() {
10211021
const int64_t now = butil::cpuwide_time_us();
1022-
_span->set_start_callback_us(now);
1023-
if (_span->local_parent()) {
1024-
_span->local_parent()->AsParent();
1022+
if (auto span = _span.lock()) {
1023+
span->set_start_callback_us(now);
1024+
if (auto parent_span = span->local_parent().lock()) {
1025+
if (parent_span->is_active()) {
1026+
parent_span->AsParent();
1027+
}
1028+
}
1029+
Span::Submit(span, now);
1030+
_span.reset();
10251031
}
1026-
Span::Submit(_span, now);
1027-
_span = NULL;
10281032
}
10291033

10301034
void Controller::HandleSendFailed() {
@@ -1122,8 +1126,7 @@ void Controller::IssueRPC(int64_t start_realtime_us) {
11221126
CHECK_EQ(_remote_side, tmp_sock->remote_side());
11231127
}
11241128

1125-
Span* span = _span;
1126-
if (span) {
1129+
if (auto span = _span.lock()) {
11271130
if (_current_call.nretry == 0) {
11281131
span->set_remote_side(_remote_side);
11291132
} else {
@@ -1235,15 +1238,15 @@ void Controller::IssueRPC(int64_t start_realtime_us) {
12351238
int rc;
12361239
size_t packet_size = 0;
12371240
if (user_packet_guard) {
1238-
if (span) {
1241+
if (auto span = _span.lock()) {
12391242
packet_size = user_packet_guard->EstimatedByteSize();
12401243
}
12411244
rc = _current_call.sending_sock->Write(user_packet_guard, &wopt);
12421245
} else {
12431246
packet_size = packet.size();
12441247
rc = _current_call.sending_sock->Write(&packet, &wopt);
12451248
}
1246-
if (span) {
1249+
if (auto span = _span.lock()) {
12471250
if (_current_call.nretry == 0) {
12481251
span->set_sent_us(butil::cpuwide_time_us());
12491252
span->set_request_size(packet_size);
@@ -1387,8 +1390,18 @@ const Controller* Controller::sub(int index) const {
13871390
return NULL;
13881391
}
13891392

1390-
uint64_t Controller::trace_id() const { return _span ? _span->trace_id() : 0; }
1391-
uint64_t Controller::span_id() const { return _span ? _span->span_id() : 0; }
1393+
uint64_t Controller::trace_id() const {
1394+
if (auto span = _span.lock()) {
1395+
return span->trace_id();
1396+
}
1397+
return 0;
1398+
}
1399+
uint64_t Controller::span_id() const {
1400+
if (auto span = _span.lock()) {
1401+
return span->span_id();
1402+
}
1403+
return 0;
1404+
}
13921405

13931406
void* Controller::session_local_data() {
13941407
if (_session_local_data) {

src/brpc/controller.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <functional> // std::function
2626
#include <gflags/gflags.h> // Users often need gflags
2727
#include <string>
28+
#include <memory>
2829
#include "butil/intrusive_ptr.hpp" // butil::intrusive_ptr
2930
#include "bthread/errno.h" // Redefine errno
3031
#include "butil/endpoint.h" // butil::EndPoint
@@ -803,7 +804,7 @@ friend void policy::ProcessThriftRequest(InputMessageBase*);
803804
private:
804805
// NOTE: align and group fields to make Controller as compact as possible.
805806

806-
Span* _span;
807+
std::weak_ptr<Span> _span;
807808
uint32_t _flags; // all boolean fields inside Controller
808809
int32_t _error_code;
809810
std::string _error_text;

0 commit comments

Comments
 (0)