Skip to content

Commit ed684c0

Browse files
committed
Avoid stack overflow in SubscriptionFlowTree
1 parent 1f79cdb commit ed684c0

File tree

2 files changed

+127
-93
lines changed

2 files changed

+127
-93
lines changed

src/main/java/com/hivemq/client/internal/mqtt/datatypes/MqttTopicLevel.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ public class MqttTopicLevel extends ByteArray.Range {
4242

4343
public static @NotNull MqttTopicLevel root(final @NotNull MqttTopicFilterImpl topicFilter) {
4444
final byte[] binary = topicFilter.toBinary();
45-
final int start = topicFilter.getFilterByteStart();
46-
final int end = nextEnd(binary, start);
47-
return new MqttTopicLevel(binary, start, end);
45+
final int start = topicFilter.getFilterByteStart() - 1;
46+
return new MqttTopicLevel(binary, start, start);
4847
}
4948

5049
private static int nextEnd(final @NotNull byte[] array, final int start) {

src/main/java/com/hivemq/client/internal/mqtt/handler/publish/incoming/MqttSubscriptionFlowTree.java

Lines changed: 125 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737
@NotThreadSafe
3838
public class MqttSubscriptionFlowTree implements MqttSubscriptionFlows {
3939

40-
private static final @NotNull ByteArray ROOT_LEVEL = new ByteArray(new byte[0]);
41-
4240
private @Nullable TopicTreeNode rootNode;
4341

4442
@Inject
@@ -48,37 +46,63 @@ public class MqttSubscriptionFlowTree implements MqttSubscriptionFlows {
4846
public void subscribe(
4947
final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) {
5048

51-
final MqttTopicLevel level = MqttTopicLevel.root(topicFilter);
5249
final TopicTreeEntry entry = (flow == null) ? null : new TopicTreeEntry(flow, topicFilter);
5350
if (rootNode == null) {
54-
rootNode = new TopicTreeNode(ROOT_LEVEL, level, entry);
55-
} else {
56-
rootNode.subscribe(level, entry);
51+
rootNode = new TopicTreeNode(null, null);
52+
}
53+
final MqttTopicLevel level = MqttTopicLevel.root(topicFilter);
54+
TopicTreeNode node = rootNode;
55+
while (node != null) {
56+
node = node.subscribe(level.next(), entry);
5757
}
5858
}
5959

6060
@Override
6161
public void remove(final @NotNull MqttTopicFilterImpl topicFilter, final @Nullable MqttSubscribedPublishFlow flow) {
62-
if ((rootNode != null) && rootNode.remove(MqttTopicLevel.root(topicFilter), flow)) {
63-
rootNode = null;
62+
final MqttTopicLevel level = MqttTopicLevel.root(topicFilter);
63+
TopicTreeNode node = rootNode;
64+
TopicTreeNode lastNode = null;
65+
while (node != null) {
66+
lastNode = node;
67+
node = node.remove(level.next(), flow);
6468
}
69+
compact(lastNode);
6570
}
6671

6772
@Override
6873
public void unsubscribe(
6974
final @NotNull MqttTopicFilterImpl topicFilter,
7075
final @Nullable Consumer<MqttSubscribedPublishFlow> unsubscribedCallback) {
7176

72-
if ((rootNode != null) && rootNode.unsubscribe(MqttTopicLevel.root(topicFilter), unsubscribedCallback)) {
73-
rootNode = null;
77+
final MqttTopicLevel level = MqttTopicLevel.root(topicFilter);
78+
TopicTreeNode node = rootNode;
79+
TopicTreeNode lastNode = null;
80+
while (node != null) {
81+
lastNode = node;
82+
node = node.unsubscribe(level.next(), unsubscribedCallback);
83+
}
84+
compact(lastNode);
85+
}
86+
87+
private void compact(@Nullable TopicTreeNode lastNode) {
88+
while ((lastNode != null) && lastNode.isEmpty()) {
89+
final TopicTreeNode parentNode = lastNode.parentNode;
90+
if (parentNode == null) {
91+
rootNode = null;
92+
} else {
93+
parentNode.removeNext(lastNode);
94+
}
95+
lastNode = parentNode;
7496
}
7597
}
7698

7799
@Override
78100
public void cancel(final @NotNull MqttSubscribedPublishFlow flow) {
79-
if (rootNode != null) {
80-
for (final MqttTopicFilterImpl topicFilter : flow.getTopicFilters()) {
81-
rootNode.cancel(MqttTopicLevel.root(topicFilter), flow);
101+
for (final MqttTopicFilterImpl topicFilter : flow.getTopicFilters()) {
102+
final MqttTopicLevel level = MqttTopicLevel.root(topicFilter);
103+
TopicTreeNode node = rootNode;
104+
while (node != null) {
105+
node = node.cancel(level.next(), flow);
82106
}
83107
}
84108
}
@@ -92,10 +116,11 @@ public boolean findMatching(
92116

93117
@Override
94118
public void clear(final @NotNull Throwable cause) {
95-
if (rootNode != null) {
96-
rootNode.clear(cause);
97-
rootNode = null;
119+
TopicTreeNode node = rootNode;
120+
while (node != null) {
121+
node = node.clear(cause);
98122
}
123+
rootNode = null;
99124
}
100125

101126
private static class TopicTreeEntry {
@@ -111,23 +136,21 @@ private static class TopicTreeEntry {
111136

112137
private static class TopicTreeNode {
113138

114-
private final @NotNull ByteArray parentLevel;
139+
private final @Nullable TopicTreeNode parentNode;
140+
private final @Nullable ByteArray parentLevel;
115141
private @Nullable HashMap<ByteArray, TopicTreeNode> next;
116-
private @Nullable HandleList<TopicTreeEntry> entries;
142+
private @Nullable TopicTreeNode singleLevel;
117143
private @Nullable HandleList<TopicTreeEntry> multiLevelEntries;
144+
private @Nullable HandleList<TopicTreeEntry> entries;
118145
private int subscriptions;
119146
private int multiLevelSubscriptions;
120-
private boolean hasSingleLevelSubscription;
121-
122-
TopicTreeNode(
123-
final @NotNull ByteArray parentLevel, final @Nullable MqttTopicLevel level,
124-
final @Nullable TopicTreeEntry entry) {
125147

148+
TopicTreeNode(final @Nullable TopicTreeNode parentNode, final @Nullable ByteArray parentLevel) {
149+
this.parentNode = parentNode;
126150
this.parentLevel = parentLevel;
127-
subscribe(level, entry);
128151
}
129152

130-
void subscribe(final @Nullable MqttTopicLevel level, final @Nullable TopicTreeEntry entry) {
153+
@Nullable TopicTreeNode subscribe(final @Nullable MqttTopicLevel level, final @Nullable TopicTreeEntry entry) {
131154
if (level == null) {
132155
if (entry != null) {
133156
if (entries == null) {
@@ -136,56 +159,57 @@ void subscribe(final @Nullable MqttTopicLevel level, final @Nullable TopicTreeEn
136159
entries.add(entry);
137160
}
138161
subscriptions++;
139-
} else if (level.isMultiLevelWildcard()) {
162+
return null;
163+
}
164+
if (level.isMultiLevelWildcard()) {
140165
if (entry != null) {
141166
if (multiLevelEntries == null) {
142167
multiLevelEntries = new HandleList<>();
143168
}
144169
multiLevelEntries.add(entry);
145170
}
146171
multiLevelSubscriptions++;
147-
} else {
148-
final TopicTreeNode node;
149-
if (next == null) {
150-
next = new HashMap<>();
151-
node = null;
152-
} else {
153-
node = next.get(level);
154-
}
155-
if (node == null) {
156-
if (level.isSingleLevelWildcard()) {
157-
hasSingleLevelSubscription = true;
158-
}
159-
final ByteArray levelCopy = level.copy();
160-
next.put(levelCopy, new TopicTreeNode(levelCopy, level.next(), entry));
161-
} else {
162-
node.subscribe(level.next(), entry);
172+
return null;
173+
}
174+
if (level.isSingleLevelWildcard()) {
175+
if (singleLevel == null) {
176+
singleLevel = new TopicTreeNode(this, MqttTopicLevel.SINGLE_LEVEL_WILDCARD);
163177
}
178+
return singleLevel;
164179
}
180+
TopicTreeNode node;
181+
if (next == null) {
182+
next = new HashMap<>();
183+
node = null;
184+
} else {
185+
node = next.get(level);
186+
}
187+
if (node == null) {
188+
final ByteArray levelCopy = level.copy();
189+
node = new TopicTreeNode(this, levelCopy);
190+
next.put(levelCopy, node);
191+
}
192+
return node;
165193
}
166194

167-
boolean remove(final @Nullable MqttTopicLevel level, final @Nullable MqttSubscribedPublishFlow flow) {
195+
@Nullable TopicTreeNode remove(
196+
final @Nullable MqttTopicLevel level, final @Nullable MqttSubscribedPublishFlow flow) {
197+
168198
if (level == null) {
169199
if (remove(entries, flow)) {
170200
entries = null;
171201
}
172202
subscriptions--;
173-
return (subscriptions == 0) && (multiLevelSubscriptions == 0) && (next == null);
203+
return null;
174204
}
175205
if (level.isMultiLevelWildcard()) {
176206
if (remove(multiLevelEntries, flow)) {
177207
multiLevelEntries = null;
178208
}
179209
multiLevelSubscriptions--;
180-
return (subscriptions == 0) && (multiLevelSubscriptions == 0) && (next == null);
210+
return null;
181211
}
182-
if (next != null) {
183-
final TopicTreeNode node = next.get(level);
184-
if ((node != null) && node.remove(level.next(), flow)) {
185-
return removeNext(node);
186-
}
187-
}
188-
return false;
212+
return traverseNext(level);
189213
}
190214

191215
private static boolean remove(
@@ -205,29 +229,23 @@ private static boolean remove(
205229
return false;
206230
}
207231

208-
boolean unsubscribe(
232+
@Nullable TopicTreeNode unsubscribe(
209233
final @Nullable MqttTopicLevel level,
210234
final @Nullable Consumer<MqttSubscribedPublishFlow> unsubscribedCallback) {
211235

212236
if (level == null) {
213237
unsubscribe(entries, unsubscribedCallback);
214238
entries = null;
215239
subscriptions = 0;
216-
return (multiLevelSubscriptions == 0) && (next == null);
240+
return null;
217241
}
218242
if (level.isMultiLevelWildcard()) {
219243
unsubscribe(multiLevelEntries, unsubscribedCallback);
220244
multiLevelEntries = null;
221245
multiLevelSubscriptions = 0;
222-
return (subscriptions == 0) && (next == null);
246+
return null;
223247
}
224-
if (next != null) {
225-
final TopicTreeNode node = next.get(level);
226-
if ((node != null) && node.unsubscribe(level.next(), unsubscribedCallback)) {
227-
return removeNext(node);
228-
}
229-
}
230-
return false;
248+
return traverseNext(level);
231249
}
232250

233251
private static void unsubscribe(
@@ -248,34 +266,22 @@ private static void unsubscribe(
248266
}
249267
}
250268

251-
private boolean removeNext(final @NotNull TopicTreeNode node) {
252-
assert next != null;
253-
if (node.parentLevel == MqttTopicLevel.SINGLE_LEVEL_WILDCARD) {
254-
hasSingleLevelSubscription = false;
255-
}
256-
next.remove(node.parentLevel);
257-
if (next.isEmpty()) {
258-
next = null;
259-
return (subscriptions == 0) && (multiLevelSubscriptions == 0);
260-
}
261-
return false;
262-
}
269+
@Nullable TopicTreeNode cancel(
270+
final @Nullable MqttTopicLevel level, final @NotNull MqttSubscribedPublishFlow flow) {
263271

264-
void cancel(final @Nullable MqttTopicLevel level, final @NotNull MqttSubscribedPublishFlow flow) {
265272
if (level == null) {
266273
if (cancel(entries, flow)) {
267274
entries = null;
268275
}
269-
} else if (level.isMultiLevelWildcard()) {
276+
return null;
277+
}
278+
if (level.isMultiLevelWildcard()) {
270279
if (cancel(multiLevelEntries, flow)) {
271280
multiLevelEntries = null;
272281
}
273-
} else if (next != null) {
274-
final TopicTreeNode node = next.get(level);
275-
if (node != null) {
276-
node.cancel(level.next(), flow);
277-
}
282+
return null;
278283
}
284+
return traverseNext(level);
279285
}
280286

281287
private static boolean cancel(
@@ -306,9 +312,8 @@ boolean findMatching(
306312
add(matchingFlows, multiLevelEntries);
307313
boolean subscriptionFound = (multiLevelSubscriptions != 0);
308314
if (next != null) {
309-
if (hasSingleLevelSubscription) {
310-
final TopicTreeNode singleLevelNode = next.get(MqttTopicLevel.SINGLE_LEVEL_WILDCARD);
311-
subscriptionFound |= singleLevelNode.findMatching(level.fork().next(), matchingFlows);
315+
if (singleLevel != null) {
316+
subscriptionFound |= singleLevel.findMatching(level.fork().next(), matchingFlows);
312317
}
313318
final TopicTreeNode node = next.get(level);
314319
if (node != null) {
@@ -329,7 +334,13 @@ private static void add(
329334
}
330335
}
331336

332-
void clear(final @NotNull Throwable cause) {
337+
@Nullable TopicTreeNode clear(final @NotNull Throwable cause) {
338+
if (next != null) {
339+
return next.values().iterator().next();
340+
}
341+
if (singleLevel != null) {
342+
return singleLevel;
343+
}
333344
if (entries != null) {
334345
for (final TopicTreeEntry entry : entries) {
335346
entry.flow.onError(cause);
@@ -342,13 +353,37 @@ void clear(final @NotNull Throwable cause) {
342353
}
343354
multiLevelEntries = null;
344355
}
356+
if (parentNode != null) {
357+
parentNode.removeNext(this);
358+
}
359+
return parentNode;
360+
}
361+
362+
private @Nullable TopicTreeNode traverseNext(final @NotNull MqttTopicLevel level) {
363+
if (level.isSingleLevelWildcard()) {
364+
return singleLevel;
365+
}
345366
if (next != null) {
346-
next.values().forEach(node -> node.clear(cause));
347-
next = null;
367+
return next.get(level);
348368
}
349-
subscriptions = 0;
350-
multiLevelSubscriptions = 0;
351-
hasSingleLevelSubscription = false;
369+
return null;
370+
}
371+
372+
private void removeNext(final @NotNull TopicTreeNode node) {
373+
assert next != null;
374+
assert node.parentLevel != null;
375+
if (node.parentLevel == MqttTopicLevel.SINGLE_LEVEL_WILDCARD) {
376+
singleLevel = null;
377+
} else {
378+
next.remove(node.parentLevel);
379+
if (next.isEmpty()) {
380+
next = null;
381+
}
382+
}
383+
}
384+
385+
boolean isEmpty() {
386+
return (subscriptions == 0) && (multiLevelSubscriptions == 0) && (singleLevel == null) && (next == null);
352387
}
353388
}
354389
}

0 commit comments

Comments
 (0)