Skip to content

Commit ecaa0f2

Browse files
authored
[IOTDB-4343] Fix session manager in MQTT module. (#7247)
1 parent 326fc81 commit ecaa0f2

File tree

3 files changed

+51
-31
lines changed

3 files changed

+51
-31
lines changed

server/src/main/java/org/apache/iotdb/db/protocol/mqtt/MPPPublishHandler.java

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
import java.time.ZoneId;
5151
import java.util.List;
52+
import java.util.concurrent.ConcurrentHashMap;
5253

5354
/** PublishHandler handle the messages from MQTT clients. */
5455
public class MPPPublishHandler extends AbstractInterceptHandler {
@@ -57,7 +58,7 @@ public class MPPPublishHandler extends AbstractInterceptHandler {
5758

5859
private static final IoTDBConfig config = IoTDBDescriptor.getInstance().getConfig();
5960
private final SessionManager SESSION_MANAGER = SessionManager.getInstance();
60-
private long sessionId;
61+
private final ConcurrentHashMap<String, Long> clientIdToSessionIdMap = new ConcurrentHashMap<>();
6162
private final PayloadFormatter payloadFormat;
6263
private final IPartitionFetcher partitionFetcher;
6364
private final ISchemaFetcher schemaFetcher;
@@ -75,32 +76,41 @@ public MPPPublishHandler(IoTDBConfig config) {
7576

7677
@Override
7778
public String getID() {
78-
return "iotdb-mqtt-broker-listener-" + sessionId;
79+
return "iotdb-mqtt-broker-listener";
7980
}
8081

8182
@Override
8283
public void onConnect(InterceptConnectMessage msg) {
83-
try {
84-
BasicOpenSessionResp basicOpenSessionResp =
85-
SESSION_MANAGER.openSession(
86-
msg.getUsername(),
87-
new String(msg.getPassword()),
88-
ZoneId.systemDefault().toString(),
89-
TSProtocolVersion.IOTDB_SERVICE_PROTOCOL_V3);
90-
sessionId = basicOpenSessionResp.getSessionId();
91-
} catch (TException e) {
92-
throw new RuntimeException(e);
84+
if (!clientIdToSessionIdMap.containsKey(msg.getClientID())) {
85+
try {
86+
BasicOpenSessionResp basicOpenSessionResp =
87+
SESSION_MANAGER.openSession(
88+
msg.getUsername(),
89+
new String(msg.getPassword()),
90+
ZoneId.systemDefault().toString(),
91+
TSProtocolVersion.IOTDB_SERVICE_PROTOCOL_V3);
92+
clientIdToSessionIdMap.put(msg.getClientID(), basicOpenSessionResp.getSessionId());
93+
} catch (TException e) {
94+
throw new RuntimeException(e);
95+
}
9396
}
9497
}
9598

9699
@Override
97100
public void onDisconnect(InterceptDisconnectMessage msg) {
98-
SESSION_MANAGER.closeSession(sessionId);
101+
Long sessionId = clientIdToSessionIdMap.remove(msg.getClientID());
102+
if (null != sessionId) {
103+
SESSION_MANAGER.closeSession(sessionId);
104+
}
99105
}
100106

101107
@Override
102108
public void onPublish(InterceptPublishMessage msg) {
103109
String clientId = msg.getClientID();
110+
if (!clientIdToSessionIdMap.containsKey(clientId)) {
111+
return;
112+
}
113+
long sessionId = clientIdToSessionIdMap.get(msg.getClientID());
104114
ByteBuf payload = msg.getPayload();
105115
String topic = msg.getTopicName();
106116
String username = msg.getUsername();

server/src/main/java/org/apache/iotdb/db/protocol/mqtt/PublishHandler.java

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@
3838

3939
import java.time.ZoneId;
4040
import java.util.List;
41+
import java.util.concurrent.ConcurrentHashMap;
4142

4243
/** PublishHandler handle the messages from MQTT clients. */
4344
public class PublishHandler extends AbstractInterceptHandler {
44-
45+
private static final Logger LOG = LoggerFactory.getLogger(PublishHandler.class);
4546
private final SessionManager SESSION_MANAGER = SessionManager.getInstance();
46-
private long sessionId;
4747

48-
private static final Logger LOG = LoggerFactory.getLogger(PublishHandler.class);
48+
private final ConcurrentHashMap<String, Long> clientIdToSessionIdMap = new ConcurrentHashMap<>();
4949

5050
private final PayloadFormatter payloadFormat;
5151

@@ -59,32 +59,41 @@ protected PublishHandler(PayloadFormatter payloadFormat) {
5959

6060
@Override
6161
public String getID() {
62-
return "iotdb-mqtt-broker-listener-" + sessionId;
62+
return "iotdb-mqtt-broker-listener";
6363
}
6464

6565
@Override
6666
public void onConnect(InterceptConnectMessage msg) {
67-
try {
68-
BasicOpenSessionResp basicOpenSessionResp =
69-
SESSION_MANAGER.openSession(
70-
msg.getUsername(),
71-
new String(msg.getPassword()),
72-
ZoneId.systemDefault().toString(),
73-
TSProtocolVersion.IOTDB_SERVICE_PROTOCOL_V3);
74-
sessionId = basicOpenSessionResp.getSessionId();
75-
} catch (TException e) {
76-
throw new RuntimeException(e);
67+
if (!clientIdToSessionIdMap.containsKey(msg.getClientID())) {
68+
try {
69+
BasicOpenSessionResp basicOpenSessionResp =
70+
SESSION_MANAGER.openSession(
71+
msg.getUsername(),
72+
new String(msg.getPassword()),
73+
ZoneId.systemDefault().toString(),
74+
TSProtocolVersion.IOTDB_SERVICE_PROTOCOL_V3);
75+
clientIdToSessionIdMap.put(msg.getClientID(), basicOpenSessionResp.getSessionId());
76+
} catch (TException e) {
77+
throw new RuntimeException(e);
78+
}
7779
}
7880
}
7981

8082
@Override
8183
public void onDisconnect(InterceptDisconnectMessage msg) {
82-
SESSION_MANAGER.closeSession(sessionId);
84+
Long sessionId = clientIdToSessionIdMap.remove(msg.getClientID());
85+
if (null != sessionId) {
86+
SESSION_MANAGER.closeSession(sessionId);
87+
}
8388
}
8489

8590
@Override
8691
public void onPublish(InterceptPublishMessage msg) {
8792
String clientId = msg.getClientID();
93+
if (!clientIdToSessionIdMap.containsKey(clientId)) {
94+
return;
95+
}
96+
long sessionId = clientIdToSessionIdMap.get(msg.getClientID());
8897
ByteBuf payload = msg.getPayload();
8998
String topic = msg.getTopicName();
9099
String username = msg.getUsername();

server/src/test/java/org/apache/iotdb/db/protocol/mqtt/PublishHandlerTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ public static void tearDown() throws Exception {
6363
public void onPublish() throws ClassNotFoundException {
6464
PayloadFormatter payloadFormat = PayloadFormatManager.getPayloadFormat("json");
6565
PublishHandler handler = new PublishHandler(payloadFormat);
66+
String clientId = "clientId";
6667

6768
String payload =
6869
"{\n"
@@ -77,7 +78,7 @@ public void onPublish() throws ClassNotFoundException {
7778
// connect
7879
MqttConnectPayload mqttConnectPayload =
7980
new MqttConnectPayload(
80-
null,
81+
clientId,
8182
null,
8283
"test".getBytes(StandardCharsets.UTF_8),
8384
"root",
@@ -92,12 +93,12 @@ public void onPublish() throws ClassNotFoundException {
9293
MqttFixedHeader fixedHeader =
9394
new MqttFixedHeader(MqttMessageType.PUBLISH, false, MqttQoS.AT_LEAST_ONCE, false, 1);
9495
MqttPublishMessage publishMessage = new MqttPublishMessage(fixedHeader, variableHeader, buf);
95-
InterceptPublishMessage message = new InterceptPublishMessage(publishMessage, null, null);
96+
InterceptPublishMessage message = new InterceptPublishMessage(publishMessage, clientId, null);
9697
handler.onPublish(message);
9798

9899
// disconnect
99100
InterceptDisconnectMessage interceptDisconnectMessage =
100-
new InterceptDisconnectMessage(null, null);
101+
new InterceptDisconnectMessage(clientId, null);
101102
handler.onDisconnect(interceptDisconnectMessage);
102103

103104
String[] retArray = new String[] {"1586076045524,0.530635,"};

0 commit comments

Comments
 (0)