Skip to content

Commit 124d91f

Browse files
committed
Wait for remote close on disconnect
1 parent d60727f commit 124d91f

File tree

2 files changed

+70
-30
lines changed

2 files changed

+70
-30
lines changed

src/main/java/com/hivemq/client/internal/mqtt/handler/disconnect/MqttDisconnectHandler.java

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,16 @@
3636
import com.hivemq.client.mqtt.lifecycle.MqttDisconnectSource;
3737
import com.hivemq.client.mqtt.mqtt5.auth.Mqtt5EnhancedAuthMechanism;
3838
import com.hivemq.client.mqtt.mqtt5.exceptions.Mqtt5DisconnectException;
39+
import io.netty.channel.Channel;
3940
import io.netty.channel.ChannelHandlerContext;
4041
import io.netty.channel.EventLoop;
42+
import io.netty.channel.socket.DuplexChannel;
43+
import io.netty.util.concurrent.ScheduledFuture;
4144
import org.jetbrains.annotations.NotNull;
45+
import org.jetbrains.annotations.Nullable;
4246

4347
import javax.inject.Inject;
48+
import java.util.concurrent.TimeUnit;
4449

4550
import static com.hivemq.client.internal.mqtt.handler.disconnect.MqttDisconnectUtil.fireDisconnectEvent;
4651

@@ -56,12 +61,13 @@
5661
@ConnectionScope
5762
public class MqttDisconnectHandler extends MqttConnectionAwareHandler {
5863

59-
private static final @NotNull InternalLogger LOGGER = InternalLoggerFactory.getLogger(MqttDisconnectHandler.class);
6064
public static final @NotNull String NAME = "disconnect";
65+
private static final @NotNull InternalLogger LOGGER = InternalLoggerFactory.getLogger(MqttDisconnectHandler.class);
66+
private static final int DISCONNECT_TIMEOUT = 10; // TODO configurable
6167

6268
private final @NotNull MqttClientConfig clientConfig;
6369
private final @NotNull MqttSession session;
64-
private boolean once = true;
70+
private @Nullable State state = null;
6571

6672
@Inject
6773
MqttDisconnectHandler(final @NotNull MqttClientConfig clientConfig, final @NotNull MqttSession session) {
@@ -79,8 +85,8 @@ public void channelRead(final @NotNull ChannelHandlerContext ctx, final @NotNull
7985
}
8086

8187
private void readDisconnect(final @NotNull ChannelHandlerContext ctx, final @NotNull MqttDisconnect disconnect) {
82-
if (once) {
83-
once = false;
88+
if (state == null) {
89+
state = State.CLOSED;
8490
fireDisconnectEvent(ctx.channel(), new Mqtt5DisconnectException(disconnect, "Server sent DISCONNECT."),
8591
MqttDisconnectSource.SERVER);
8692
}
@@ -89,18 +95,23 @@ private void readDisconnect(final @NotNull ChannelHandlerContext ctx, final @Not
8995
@Override
9096
public void channelInactive(final @NotNull ChannelHandlerContext ctx) {
9197
ctx.fireChannelInactive();
92-
if (once) {
93-
once = false;
94-
fireDisconnectEvent(ctx.channel(),
95-
new ConnectionClosedException("Server closed connection without DISCONNECT."),
98+
if (state == null) {
99+
state = State.CLOSED;
100+
fireDisconnectEvent(ctx.channel(), new ConnectionClosedException("Server closed connection without DISCONNECT."),
96101
MqttDisconnectSource.SERVER);
102+
} else if (state instanceof DisconnectingState) {
103+
final DisconnectingState disconnectingState = (DisconnectingState) state;
104+
state = State.CLOSED;
105+
disconnectingState.timeoutFuture.cancel(false);
106+
disconnected(disconnectingState.channel, disconnectingState.disconnectEvent);
107+
disconnectingState.disconnectEvent.getFlow().onComplete();
97108
}
98109
}
99110

100111
@Override
101112
public void exceptionCaught(final @NotNull ChannelHandlerContext ctx, final @NotNull Throwable cause) {
102-
if (once) {
103-
once = false;
113+
if (state == null) {
114+
state = State.CLOSED;
104115
fireDisconnectEvent(ctx.channel(), new ConnectionClosedException(cause), MqttDisconnectSource.CLIENT);
105116
} else {
106117
LOGGER.error("Exception while disconnecting.", cause);
@@ -115,8 +126,8 @@ public void disconnect(final @NotNull MqttDisconnect disconnect, final @NotNull
115126

116127
private void writeDisconnect(final @NotNull MqttDisconnect disconnect, final @NotNull CompletableFlow flow) {
117128
final ChannelHandlerContext ctx = this.ctx;
118-
if ((ctx != null) && once) {
119-
once = false;
129+
if ((ctx != null) && (state == null)) {
130+
state = State.CLOSED;
120131
fireDisconnectEvent(ctx.channel(), new MqttDisconnectEvent.ByUser(disconnect, flow));
121132
} else {
122133
flow.onError(MqttClientStateExceptions.notConnected());
@@ -130,11 +141,13 @@ protected void onDisconnectEvent(final @NotNull MqttDisconnectEvent disconnectEv
130141
return;
131142
}
132143
super.onDisconnectEvent(disconnectEvent);
133-
once = false;
144+
state = State.CLOSED;
145+
146+
final Channel channel = ctx.channel();
134147

135148
if (disconnectEvent.getSource() == MqttDisconnectSource.SERVER) {
136-
disconnected(ctx, disconnectEvent);
137-
ctx.channel().close();
149+
disconnected(channel, disconnectEvent);
150+
channel.close();
138151
return;
139152
}
140153

@@ -156,36 +169,41 @@ protected void onDisconnectEvent(final @NotNull MqttDisconnectEvent disconnectEv
156169
}
157170

158171
if (disconnectEvent instanceof MqttDisconnectEvent.ByUser) {
159-
final CompletableFlow flow = ((MqttDisconnectEvent.ByUser) disconnectEvent).getFlow();
160-
ctx.writeAndFlush(disconnect).addListener(f -> ctx.channel().close().addListener(cf -> {
161-
disconnected(ctx, disconnectEvent);
172+
final MqttDisconnectEvent.ByUser disconnectEventByUser = (MqttDisconnectEvent.ByUser) disconnectEvent;
173+
ctx.writeAndFlush(disconnect).addListener(f -> {
162174
if (f.isSuccess()) {
163-
flow.onComplete();
175+
((DuplexChannel) channel).shutdownOutput().addListener(cf -> {
176+
if (cf.isSuccess()) {
177+
state = new DisconnectingState(channel, disconnectEventByUser);
178+
} else {
179+
disconnected(channel, disconnectEvent);
180+
disconnectEventByUser.getFlow().onError(new ConnectionClosedException(cf.cause()));
181+
}
182+
});
164183
} else {
165-
flow.onError(new ConnectionClosedException(f.cause()));
184+
disconnected(channel, disconnectEvent);
185+
disconnectEventByUser.getFlow().onError(new ConnectionClosedException(f.cause()));
166186
}
167-
}));
187+
});
168188

169189
} else if (clientConfig.getMqttVersion() == MqttVersion.MQTT_5_0) {
170190
ctx.writeAndFlush(disconnect)
171-
.addListener(f -> ctx.channel().close().addListener(cf -> disconnected(ctx, disconnectEvent)));
191+
.addListener(f -> channel.close().addListener(cf -> disconnected(channel, disconnectEvent)));
172192

173193
} else {
174-
ctx.channel().close().addListener(cf -> disconnected(ctx, disconnectEvent));
194+
channel.close().addListener(cf -> disconnected(channel, disconnectEvent));
175195
}
176196
} else {
177-
ctx.channel().close().addListener(cf -> disconnected(ctx, disconnectEvent));
197+
channel.close().addListener(cf -> disconnected(channel, disconnectEvent));
178198
}
179199
}
180200

181-
private void disconnected(
182-
final @NotNull ChannelHandlerContext ctx, final @NotNull MqttDisconnectEvent disconnectEvent) {
183-
201+
private void disconnected(final @NotNull Channel channel, final @NotNull MqttDisconnectEvent disconnectEvent) {
184202
final MqttClientConnectionConfig connectionConfig = clientConfig.getRawConnectionConfig();
185203
if (connectionConfig != null) {
186-
session.expire(disconnectEvent.getCause(), connectionConfig, ctx.channel().eventLoop());
204+
session.expire(disconnectEvent.getCause(), connectionConfig, channel.eventLoop());
187205

188-
reconnect(disconnectEvent, connectionConfig, ctx.channel().eventLoop());
206+
reconnect(disconnectEvent, connectionConfig, channel.eventLoop());
189207

190208
clientConfig.setConnectionConfig(null);
191209
}
@@ -231,4 +249,27 @@ public void channelUnregistered(final @NotNull ChannelHandlerContext ctx) {
231249
public boolean isSharable() {
232250
return false;
233251
}
252+
253+
private static class State {
254+
255+
static final @NotNull State CLOSED = new State();
256+
}
257+
258+
private static class DisconnectingState extends State implements Runnable {
259+
260+
private final @NotNull Channel channel;
261+
private final @NotNull MqttDisconnectEvent.ByUser disconnectEvent;
262+
private final @NotNull ScheduledFuture<?> timeoutFuture;
263+
264+
DisconnectingState(final @NotNull Channel channel, final @NotNull MqttDisconnectEvent.ByUser disconnectEvent) {
265+
this.channel = channel;
266+
this.disconnectEvent = disconnectEvent;
267+
timeoutFuture = channel.eventLoop().schedule(this, DISCONNECT_TIMEOUT, TimeUnit.SECONDS);
268+
}
269+
270+
@Override
271+
public void run() {
272+
channel.close();
273+
}
274+
}
234275
}

src/main/java/com/hivemq/client/internal/mqtt/handler/disconnect/MqttDisconnectUtil.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ static void fireDisconnectEvent(
9797
static void fireDisconnectEvent(
9898
final @NotNull Channel channel, final @NotNull MqttDisconnectEvent disconnectEvent) {
9999

100-
channel.config().setAutoRead(false);
101100
channel.pipeline().fireUserEventTriggered(disconnectEvent);
102101
}
103102

0 commit comments

Comments
 (0)