Skip to content

Commit 51b9016

Browse files
committed
Fix a null ACK grouping tracker can be accessed after consumer is closed
Fixes #516
1 parent 3be5267 commit 51b9016

File tree

9 files changed

+175
-94
lines changed

9 files changed

+175
-94
lines changed

lib/AckGroupingTracker.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ DECLARE_LOG_OBJECT();
3636

3737
void AckGroupingTracker::doImmediateAck(const MessageId& msgId, const ResultCallback& callback,
3838
CommandAck_AckType ackType) const {
39-
const auto cnx = connectionSupplier_();
39+
const auto cnx = getConnection();
4040
if (!cnx) {
4141
LOG_DEBUG("Connection is not ready, ACK failed for " << msgId);
4242
if (callback) {
@@ -89,7 +89,7 @@ static std::ostream& operator<<(std::ostream& os, const std::set<MessageId>& msg
8989

9090
void AckGroupingTracker::doImmediateAck(const std::set<MessageId>& msgIds,
9191
const ResultCallback& callback) const {
92-
const auto cnx = connectionSupplier_();
92+
const auto cnx = getConnection();
9393
if (!cnx) {
9494
LOG_DEBUG("Connection is not ready, ACK failed for " << msgIds);
9595
if (callback) {

lib/AckGroupingTracker.h

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
#include <set>
2828

2929
#include "ProtoApiEnums.h"
30+
#include "lib/HandlerBase.h"
3031

3132
namespace pulsar {
3233

3334
class ClientConnection;
3435
using ClientConnectionPtr = std::shared_ptr<ClientConnection>;
3536
using ClientConnectionWeakPtr = std::weak_ptr<ClientConnection>;
3637
using ResultCallback = std::function<void(Result)>;
38+
using HandlerBaseWeakPtr = std::weak_ptr<HandlerBase>;
3739

3840
/**
3941
* @class AckGroupingTracker
@@ -42,19 +44,19 @@ using ResultCallback = std::function<void(Result)>;
4244
*/
4345
class AckGroupingTracker : public std::enable_shared_from_this<AckGroupingTracker> {
4446
public:
45-
AckGroupingTracker(std::function<ClientConnectionPtr()> connectionSupplier,
46-
std::function<uint64_t()> requestIdSupplier, uint64_t consumerId, bool waitResponse)
47-
: connectionSupplier_(std::move(connectionSupplier)),
48-
requestIdSupplier_(std::move(requestIdSupplier)),
47+
AckGroupingTracker(std::function<uint64_t()> requestIdSupplier, uint64_t consumerId, bool waitResponse)
48+
: requestIdSupplier_(std::move(requestIdSupplier)),
4949
consumerId_(consumerId),
5050
waitResponse_(waitResponse) {}
5151

5252
virtual ~AckGroupingTracker() = default;
5353

5454
/**
5555
* Start tracking the ACK requests.
56+
*
57+
* @param[in] handler the handler to get a ClientConnection for sending the ACK requests.
5658
*/
57-
virtual void start() {}
59+
virtual void start(const HandlerBaseWeakPtr& handler) { handler_ = handler; }
5860

5961
/**
6062
* Since ACK requests are grouped and delayed, we need to do some best-effort duplicate check to
@@ -99,15 +101,39 @@ class AckGroupingTracker : public std::enable_shared_from_this<AckGroupingTracke
99101
*/
100102
virtual void flushAndClean() {}
101103

104+
/**
105+
* Close the ACK grouping tracker, which will prevent further ACK requests being sent.
106+
*/
107+
virtual void close() { isClosed_.store(true, std::memory_order_relaxed); }
108+
102109
protected:
103110
void doImmediateAck(const MessageId& msgId, const ResultCallback& callback,
104111
CommandAck_AckType ackType) const;
105112
void doImmediateAck(const std::set<MessageId>& msgIds, const ResultCallback& callback) const;
113+
bool isClosed() const noexcept { return isClosed_.load(std::memory_order_relaxed); }
114+
bool validateClosed(const ResultCallback& callback) const {
115+
if (isClosed()) {
116+
if (callback) {
117+
callback(ResultAlreadyClosed);
118+
}
119+
return true;
120+
}
121+
return false;
122+
}
106123

107124
private:
108-
const std::function<ClientConnectionPtr()> connectionSupplier_;
125+
std::weak_ptr<HandlerBase> handler_;
109126
const std::function<uint64_t()> requestIdSupplier_;
110127
const uint64_t consumerId_;
128+
std::atomic_bool isClosed_{false};
129+
130+
ClientConnectionPtr getConnection() const {
131+
auto handler = handler_.lock();
132+
if (!handler) {
133+
return nullptr;
134+
}
135+
return handler->getCnx().lock();
136+
}
111137

112138
protected:
113139
const bool waitResponse_;

lib/AckGroupingTrackerEnabled.cc

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ static int compare(const MessageId& lhs, const MessageId& rhs) {
4545
}
4646
}
4747

48-
void AckGroupingTrackerEnabled::start() { this->scheduleTimer(); }
48+
void AckGroupingTrackerEnabled::start(const HandlerBaseWeakPtr& handler) {
49+
AckGroupingTracker::start(handler);
50+
this->scheduleTimer();
51+
}
4952

5053
bool AckGroupingTrackerEnabled::isDuplicate(const MessageId& msgId) {
5154
{
@@ -62,6 +65,9 @@ bool AckGroupingTrackerEnabled::isDuplicate(const MessageId& msgId) {
6265
}
6366

6467
void AckGroupingTrackerEnabled::addAcknowledge(const MessageId& msgId, const ResultCallback& callback) {
68+
if (validateClosed(callback)) {
69+
return;
70+
}
6571
std::lock_guard<std::recursive_mutex> lock(this->rmutexPendingIndAcks_);
6672
this->pendingIndividualAcks_.insert(msgId);
6773
if (waitResponse_) {
@@ -76,6 +82,9 @@ void AckGroupingTrackerEnabled::addAcknowledge(const MessageId& msgId, const Res
7682

7783
void AckGroupingTrackerEnabled::addAcknowledgeList(const MessageIdList& msgIds,
7884
const ResultCallback& callback) {
85+
if (validateClosed(callback)) {
86+
return;
87+
}
7988
std::lock_guard<std::recursive_mutex> lock(this->rmutexPendingIndAcks_);
8089
for (const auto& msgId : msgIds) {
8190
this->pendingIndividualAcks_.emplace(msgId);
@@ -92,6 +101,9 @@ void AckGroupingTrackerEnabled::addAcknowledgeList(const MessageIdList& msgIds,
92101

93102
void AckGroupingTrackerEnabled::addAcknowledgeCumulative(const MessageId& msgId,
94103
const ResultCallback& callback) {
104+
if (validateClosed(callback)) {
105+
return;
106+
}
95107
std::unique_lock<std::mutex> lock(this->mutexCumulativeAckMsgId_);
96108
bool completeCallback = true;
97109
if (compare(msgId, this->nextCumulativeAckMsgId_) > 0) {
@@ -115,10 +127,15 @@ void AckGroupingTrackerEnabled::addAcknowledgeCumulative(const MessageId& msgId,
115127
callback(ResultOk);
116128
}
117129
}
118-
119130
AckGroupingTrackerEnabled::~AckGroupingTrackerEnabled() {
120-
isClosed_ = true;
121-
this->flush();
131+
std::lock_guard<std::mutex> lock(this->mutexTimer_);
132+
if (this->timer_) {
133+
cancelTimer(*this->timer_);
134+
}
135+
}
136+
137+
void AckGroupingTrackerEnabled::close() {
138+
AckGroupingTracker::close();
122139
std::lock_guard<std::mutex> lock(this->mutexTimer_);
123140
if (this->timer_) {
124141
cancelTimer(*this->timer_);
@@ -165,7 +182,7 @@ void AckGroupingTrackerEnabled::flushAndClean() {
165182
}
166183

167184
void AckGroupingTrackerEnabled::scheduleTimer() {
168-
if (isClosed_) {
185+
if (isClosed()) {
169186
return;
170187
}
171188

lib/AckGroupingTrackerEnabled.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,10 @@ using HandlerBaseWeakPtr = std::weak_ptr<HandlerBase>;
4545
*/
4646
class AckGroupingTrackerEnabled : public AckGroupingTracker {
4747
public:
48-
AckGroupingTrackerEnabled(const std::function<ClientConnectionPtr()>& connectionSupplier,
49-
const std::function<uint64_t()>& requestIdSupplier, uint64_t consumerId,
48+
AckGroupingTrackerEnabled(const std::function<uint64_t()>& requestIdSupplier, uint64_t consumerId,
5049
bool waitResponse, long ackGroupingTimeMs, long ackGroupingMaxSize,
5150
const ExecutorServicePtr& executor)
52-
: AckGroupingTracker(connectionSupplier, requestIdSupplier, consumerId, waitResponse),
51+
: AckGroupingTracker(requestIdSupplier, consumerId, waitResponse),
5352
ackGroupingTimeMs_(ackGroupingTimeMs),
5453
ackGroupingMaxSize_(ackGroupingMaxSize),
5554
executor_(executor) {
@@ -58,21 +57,20 @@ class AckGroupingTrackerEnabled : public AckGroupingTracker {
5857

5958
~AckGroupingTrackerEnabled();
6059

61-
void start() override;
60+
void start(const HandlerBaseWeakPtr& handler) override;
6261
bool isDuplicate(const MessageId& msgId) override;
6362
void addAcknowledge(const MessageId& msgId, const ResultCallback& callback) override;
6463
void addAcknowledgeList(const MessageIdList& msgIds, const ResultCallback& callback) override;
6564
void addAcknowledgeCumulative(const MessageId& msgId, const ResultCallback& callback) override;
66-
void flush();
6765
void flushAndClean() override;
66+
void close() override;
67+
68+
private:
69+
void flush();
6870

6971
protected:
70-
//! Method for scheduling grouping timer.
7172
void scheduleTimer();
7273

73-
//! State
74-
std::atomic_bool isClosed_{false};
75-
7674
//! Next message ID to be cumulatively cumulatively.
7775
MessageId nextCumulativeAckMsgId_{MessageId::earliest()};
7876
bool requireCumulativeAck_{false};

lib/ConsumerImpl.cc

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,25 @@ static boost::optional<MessageId> getStartMessageId(const boost::optional<Messag
7474
return startMessageId;
7575
}
7676

77+
static AckGroupingTracker* newAckGroupingTracker(const ClientImplPtr& client, const std::string& topic,
78+
uint64_t consumerId, const ConsumerConfiguration& config) {
79+
const auto requestIdGenerator = client->getRequestIdGenerator();
80+
const auto requestIdSupplier = [requestIdGenerator] { return (*requestIdGenerator)++; };
81+
82+
if (TopicName::get(topic)->isPersistent()) {
83+
if (config.getAckGroupingTimeMs() > 0) {
84+
return new AckGroupingTrackerEnabled(
85+
requestIdSupplier, consumerId, config.isAckReceiptEnabled(), config.getAckGroupingTimeMs(),
86+
config.getAckGroupingMaxSize(), client->getIOExecutorProvider()->get());
87+
} else {
88+
return new AckGroupingTrackerDisabled(requestIdSupplier, consumerId,
89+
config.isAckReceiptEnabled());
90+
}
91+
} else {
92+
return new AckGroupingTracker(requestIdSupplier, consumerId, config.isAckReceiptEnabled());
93+
}
94+
}
95+
7796
ConsumerImpl::ConsumerImpl(const ClientImplPtr& client, const std::string& topic,
7897
const std::string& subscriptionName, const ConsumerConfiguration& conf,
7998
bool isPersistent, const ConsumerInterceptorsPtr& interceptors,
@@ -105,6 +124,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr& client, const std::string& topic
105124
consumerStr_("[" + topic + ", " + subscriptionName + ", " + std::to_string(consumerId_) + "] "),
106125
messageListenerRunning_(!conf.isStartPaused()),
107126
negativeAcksTracker_(std::make_shared<NegativeAcksTracker>(client, *this, conf)),
127+
ackGroupingTrackerPtr_(newAckGroupingTracker(client, topic, consumerId_, conf)),
108128
readCompacted_(conf.isReadCompacted()),
109129
startMessageId_(getStartMessageId(startMessageId, conf.isStartMessageIdInclusive())),
110130
maxPendingChunkedMessage_(conf.getMaxPendingChunkedMessage()),
@@ -198,38 +218,7 @@ const std::string& ConsumerImpl::getTopic() const { return topic(); }
198218

199219
void ConsumerImpl::start() {
200220
HandlerBase::start();
201-
202-
std::weak_ptr<ConsumerImpl> weakSelf{get_shared_this_ptr()};
203-
auto connectionSupplier = [weakSelf]() -> ClientConnectionPtr {
204-
auto self = weakSelf.lock();
205-
if (!self) {
206-
return nullptr;
207-
}
208-
return self->getCnx().lock();
209-
};
210-
211-
// NOTE: start() is always called in `ClientImpl`'s method, so lock() returns not null
212-
const auto requestIdGenerator = client_.lock()->getRequestIdGenerator();
213-
const auto requestIdSupplier = [requestIdGenerator] { return (*requestIdGenerator)++; };
214-
215-
// Initialize ackGroupingTrackerPtr_ here because the get_shared_this_ptr() was not initialized until the
216-
// constructor completed.
217-
if (TopicName::get(topic())->isPersistent()) {
218-
if (config_.getAckGroupingTimeMs() > 0) {
219-
ackGroupingTrackerPtr_.reset(new AckGroupingTrackerEnabled(
220-
connectionSupplier, requestIdSupplier, consumerId_, config_.isAckReceiptEnabled(),
221-
config_.getAckGroupingTimeMs(), config_.getAckGroupingMaxSize(),
222-
client_.lock()->getIOExecutorProvider()->get()));
223-
} else {
224-
ackGroupingTrackerPtr_.reset(new AckGroupingTrackerDisabled(
225-
connectionSupplier, requestIdSupplier, consumerId_, config_.isAckReceiptEnabled()));
226-
}
227-
} else {
228-
LOG_INFO(getName() << "ACK will NOT be sent to broker for this non-persistent topic.");
229-
ackGroupingTrackerPtr_.reset(new AckGroupingTracker(connectionSupplier, requestIdSupplier,
230-
consumerId_, config_.isAckReceiptEnabled()));
231-
}
232-
ackGroupingTrackerPtr_->start();
221+
ackGroupingTrackerPtr_->start(std::static_pointer_cast<HandlerBase>(shared_from_this()));
233222
}
234223

235224
void ConsumerImpl::beforeConnectionChange(ClientConnection& cnx) { cnx.removeConsumer(consumerId_); }
@@ -591,17 +580,16 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
591580
LOG_DEBUG(getName() << " metadata.has_num_messages_in_batch() = "
592581
<< metadata.has_num_messages_in_batch());
593582

594-
uint32_t numOfMessageReceived = m.impl_->metadata.num_messages_in_batch();
595-
auto ackGroupingTrackerPtr = ackGroupingTrackerPtr_;
596-
if (ackGroupingTrackerPtr == nullptr) { // The consumer is closing
583+
const auto state = state_.load(std::memory_order_relaxed);
584+
if (state == Closing || state == Closed) {
597585
return;
598586
}
599-
if (ackGroupingTrackerPtr->isDuplicate(m.getMessageId())) {
587+
uint32_t numOfMessageReceived = m.impl_->metadata.num_messages_in_batch();
588+
if (ackGroupingTrackerPtr_->isDuplicate(m.getMessageId())) {
600589
LOG_DEBUG(getName() << " Ignoring message as it was ACKed earlier by same consumer.");
601590
increaseAvailablePermits(cnx, numOfMessageReceived);
602591
return;
603592
}
604-
ackGroupingTrackerPtr.reset();
605593

606594
if (metadata.has_num_messages_in_batch()) {
607595
BitSet::Data words(msg.ack_set_size());
@@ -1340,12 +1328,8 @@ void ConsumerImpl::closeAsync(const ResultCallback& originalCallback) {
13401328
incomingMessages_.close();
13411329

13421330
// Flush pending grouped ACK requests.
1343-
if (ackGroupingTrackerPtr_.use_count() != 1) {
1344-
LOG_ERROR("AckGroupingTracker is shared by other "
1345-
<< (ackGroupingTrackerPtr_.use_count() - 1)
1346-
<< " threads, which will prevent flushing the ACKs");
1347-
}
1348-
ackGroupingTrackerPtr_.reset();
1331+
ackGroupingTrackerPtr_->flushAndClean();
1332+
ackGroupingTrackerPtr_->close();
13491333
negativeAcksTracker_->close();
13501334

13511335
ClientConnectionPtr cnx = getCnx().lock();
@@ -1369,13 +1353,12 @@ void ConsumerImpl::closeAsync(const ResultCallback& originalCallback) {
13691353
cnx->sendRequestWithId(Commands::newCloseConsumer(consumerId_, requestId), requestId)
13701354
.addListener([self, callback](Result result, const ResponseData&) { callback(result); });
13711355
}
1372-
13731356
const std::string& ConsumerImpl::getName() const { return consumerStr_; }
13741357

13751358
void ConsumerImpl::shutdown() { internalShutdown(); }
13761359

13771360
void ConsumerImpl::internalShutdown() {
1378-
ackGroupingTrackerPtr_.reset();
1361+
ackGroupingTrackerPtr_->close();
13791362
incomingMessages_.clear();
13801363
possibleSendToDeadLetterTopicMessages_.clear();
13811364
resetCnx();

lib/ConsumerImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ class ConsumerImpl : public ConsumerImplBase {
246246
UnAckedMessageTrackerPtr unAckedMessageTrackerPtr_;
247247
BrokerConsumerStatsImpl brokerConsumerStats_;
248248
std::shared_ptr<NegativeAcksTracker> negativeAcksTracker_;
249-
AckGroupingTrackerPtr ackGroupingTrackerPtr_;
249+
const AckGroupingTrackerPtr ackGroupingTrackerPtr_;
250250

251251
MessageCryptoPtr msgCrypto_;
252252
const bool readCompacted_;

lib/OnceUniquePtr.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/**
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#pragma once
20+
21+
#include <memory>
22+
23+
namespace pulsar {
24+
25+
// A wrapper around std::unique_ptr that allows initialization only once.
26+
// This wrapper is used when the initialization needs to be deferred.
27+
template <typename T>
28+
class OnceUniquePtr {
29+
public:
30+
void init(T* ptr) {
31+
if (ptr && !ptr_) {
32+
ptr_.reset(ptr);
33+
}
34+
}
35+
36+
const T& operator*() const noexcept { return *ptr_; }
37+
T& operator*() noexcept { return *ptr_; }
38+
39+
const T* operator->() const noexcept { return ptr_.get(); }
40+
T* operator->() noexcept { return ptr_.get(); }
41+
42+
private:
43+
std::unique_ptr<T> ptr_;
44+
};
45+
46+
} // namespace pulsar

0 commit comments

Comments
 (0)