Skip to content

Commit b638556

Browse files
committed
Close currentChunkedWrite on client cancel (#105258)
If the client closes the channel while we're in the middle of a chunked write then today we don't complete the corresponding listener. This commit fixes the problem.
1 parent b20e2fc commit b638556

File tree

3 files changed

+283
-0
lines changed

3 files changed

+283
-0
lines changed

docs/changelog/105258.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 105258
2+
summary: Close `currentChunkedWrite` on client cancel
3+
area: Network
4+
type: bug
5+
issues: []
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.http.netty4;
10+
11+
import org.apache.lucene.util.BytesRef;
12+
import org.elasticsearch.ESNetty4IntegTestCase;
13+
import org.elasticsearch.client.Request;
14+
import org.elasticsearch.client.Response;
15+
import org.elasticsearch.client.ResponseListener;
16+
import org.elasticsearch.client.internal.node.NodeClient;
17+
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
18+
import org.elasticsearch.cluster.node.DiscoveryNodes;
19+
import org.elasticsearch.common.Strings;
20+
import org.elasticsearch.common.bytes.BytesArray;
21+
import org.elasticsearch.common.bytes.BytesReference;
22+
import org.elasticsearch.common.bytes.ReleasableBytesReference;
23+
import org.elasticsearch.common.collect.Iterators;
24+
import org.elasticsearch.common.io.Streams;
25+
import org.elasticsearch.common.recycler.Recycler;
26+
import org.elasticsearch.common.settings.ClusterSettings;
27+
import org.elasticsearch.common.settings.IndexScopedSettings;
28+
import org.elasticsearch.common.settings.Settings;
29+
import org.elasticsearch.common.settings.SettingsFilter;
30+
import org.elasticsearch.common.util.CollectionUtils;
31+
import org.elasticsearch.core.AbstractRefCounted;
32+
import org.elasticsearch.core.RefCounted;
33+
import org.elasticsearch.core.Releasable;
34+
import org.elasticsearch.plugins.ActionPlugin;
35+
import org.elasticsearch.plugins.Plugin;
36+
import org.elasticsearch.rest.BaseRestHandler;
37+
import org.elasticsearch.rest.ChunkedRestResponseBody;
38+
import org.elasticsearch.rest.RestChannel;
39+
import org.elasticsearch.rest.RestController;
40+
import org.elasticsearch.rest.RestHandler;
41+
import org.elasticsearch.rest.RestRequest;
42+
import org.elasticsearch.rest.RestResponse;
43+
import org.elasticsearch.rest.RestStatus;
44+
import org.elasticsearch.tasks.TaskCancelledException;
45+
46+
import java.io.IOException;
47+
import java.io.InputStreamReader;
48+
import java.nio.charset.StandardCharsets;
49+
import java.util.Collection;
50+
import java.util.Collections;
51+
import java.util.Iterator;
52+
import java.util.List;
53+
import java.util.concurrent.CancellationException;
54+
import java.util.concurrent.CountDownLatch;
55+
import java.util.function.Supplier;
56+
57+
import static org.elasticsearch.rest.RestRequest.Method.GET;
58+
import static org.elasticsearch.rest.RestResponse.TEXT_CONTENT_TYPE;
59+
import static org.hamcrest.Matchers.containsString;
60+
import static org.hamcrest.Matchers.instanceOf;
61+
62+
public class Netty4ChunkedEncodingIT extends ESNetty4IntegTestCase {
63+
64+
@Override
65+
protected Collection<Class<? extends Plugin>> nodePlugins() {
66+
return CollectionUtils.concatLists(List.of(YieldsChunksPlugin.class), super.nodePlugins());
67+
}
68+
69+
@Override
70+
protected boolean addMockHttpTransport() {
71+
return false; // enable http
72+
}
73+
74+
private static final String EXPECTED_NONEMPTY_BODY = """
75+
chunk-0
76+
chunk-1
77+
chunk-2
78+
""";
79+
80+
public void testNonemptyResponse() throws IOException {
81+
getAndCheckBodyContents(YieldsChunksPlugin.CHUNKS_ROUTE, EXPECTED_NONEMPTY_BODY);
82+
}
83+
84+
public void testEmptyResponse() throws IOException {
85+
getAndCheckBodyContents(YieldsChunksPlugin.EMPTY_ROUTE, "");
86+
}
87+
88+
private static void getAndCheckBodyContents(String route, String expectedBody) throws IOException {
89+
try (var ignored = withResourceTracker()) {
90+
final var response = getRestClient().performRequest(new Request("GET", route));
91+
assertEquals(200, response.getStatusLine().getStatusCode());
92+
assertThat(response.getEntity().getContentType().toString(), containsString(TEXT_CONTENT_TYPE));
93+
if (Strings.hasLength(expectedBody)) {
94+
assertTrue(response.getEntity().isChunked());
95+
} // else we might have no chunks to send which doesn't need chunked-encoding
96+
final String body;
97+
try (var reader = new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8)) {
98+
body = Streams.copyToString(reader);
99+
}
100+
assertEquals(expectedBody, body);
101+
}
102+
}
103+
104+
public void testClientCancellation() {
105+
try (var ignored = withResourceTracker()) {
106+
final var cancellable = getRestClient().performRequestAsync(
107+
new Request("GET", YieldsChunksPlugin.INFINITE_ROUTE),
108+
new ResponseListener() {
109+
@Override
110+
public void onSuccess(Response response) {
111+
fail("should not complete");
112+
}
113+
114+
@Override
115+
public void onFailure(Exception exception) {
116+
assertThat(exception, instanceOf(CancellationException.class));
117+
}
118+
}
119+
);
120+
if (randomBoolean()) {
121+
safeSleep(scaledRandomIntBetween(10, 500));
122+
}
123+
cancellable.cancel();
124+
}
125+
}
126+
127+
private static Releasable withResourceTracker() {
128+
assertNull(refs);
129+
final var latch = new CountDownLatch(1);
130+
refs = AbstractRefCounted.of(latch::countDown);
131+
return () -> {
132+
refs.decRef();
133+
try {
134+
safeAwait(latch);
135+
} finally {
136+
refs = null;
137+
}
138+
};
139+
}
140+
141+
private static volatile RefCounted refs = null;
142+
143+
public static class YieldsChunksPlugin extends Plugin implements ActionPlugin {
144+
static final String CHUNKS_ROUTE = "/_test/yields_chunks";
145+
static final String EMPTY_ROUTE = "/_test/yields_only_empty_chunks";
146+
static final String INFINITE_ROUTE = "/_test/yields_infinite_chunks";
147+
148+
private static Iterator<BytesReference> emptyChunks() {
149+
return Collections.emptyIterator(); // support for empty chunks added in #104837; this test suite backported without that
150+
}
151+
152+
@Override
153+
public List<RestHandler> getRestHandlers(
154+
Settings settings,
155+
RestController restController,
156+
ClusterSettings clusterSettings,
157+
IndexScopedSettings indexScopedSettings,
158+
SettingsFilter settingsFilter,
159+
IndexNameExpressionResolver indexNameExpressionResolver,
160+
Supplier<DiscoveryNodes> nodesInCluster
161+
) {
162+
return List.of(
163+
// 3 nonempty chunks, with some random empty chunks in between
164+
new BaseRestHandler() {
165+
@Override
166+
public String getName() {
167+
return CHUNKS_ROUTE;
168+
}
169+
170+
@Override
171+
public List<Route> routes() {
172+
return List.of(new Route(GET, CHUNKS_ROUTE));
173+
}
174+
175+
@Override
176+
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
177+
return channel -> sendChunksResponse(
178+
channel,
179+
Iterators.concat(
180+
emptyChunks(),
181+
Iterators.flatMap(
182+
Iterators.forRange(0, 3, i -> "chunk-" + i + '\n'),
183+
chunk -> Iterators.concat(Iterators.single(new BytesArray(chunk)), emptyChunks())
184+
)
185+
)
186+
);
187+
}
188+
},
189+
190+
// only a few random empty chunks
191+
new BaseRestHandler() {
192+
@Override
193+
public String getName() {
194+
return EMPTY_ROUTE;
195+
}
196+
197+
@Override
198+
public List<Route> routes() {
199+
return List.of(new Route(GET, EMPTY_ROUTE));
200+
}
201+
202+
@Override
203+
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
204+
return channel -> sendChunksResponse(channel, emptyChunks());
205+
}
206+
},
207+
208+
// keeps on emitting chunks until cancelled
209+
new BaseRestHandler() {
210+
@Override
211+
public String getName() {
212+
return INFINITE_ROUTE;
213+
}
214+
215+
@Override
216+
public List<Route> routes() {
217+
return List.of(new Route(GET, INFINITE_ROUTE));
218+
}
219+
220+
@Override
221+
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
222+
return channel -> sendChunksResponse(channel, new Iterator<>() {
223+
private static final BytesReference CHUNK = new BytesArray("CHUNK\n");
224+
225+
@Override
226+
public boolean hasNext() {
227+
return true;
228+
}
229+
230+
@Override
231+
public BytesReference next() {
232+
return CHUNK;
233+
}
234+
});
235+
}
236+
}
237+
);
238+
}
239+
240+
private static void sendChunksResponse(RestChannel channel, Iterator<BytesReference> chunkIterator) {
241+
final var localRefs = refs; // single volatile read
242+
if (localRefs != null && localRefs.tryIncRef()) {
243+
channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() {
244+
@Override
245+
public boolean isDone() {
246+
return chunkIterator.hasNext() == false;
247+
}
248+
249+
@Override
250+
public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) {
251+
localRefs.mustIncRef();
252+
return new ReleasableBytesReference(chunkIterator.next(), localRefs::decRef);
253+
}
254+
255+
@Override
256+
public String getResponseContentTypeString() {
257+
return TEXT_CONTENT_TYPE;
258+
}
259+
260+
@Override
261+
public void close() {
262+
localRefs.decRef();
263+
}
264+
}));
265+
} else {
266+
try {
267+
channel.sendResponse(new RestResponse(channel, new TaskCancelledException("task cancelled")));
268+
} catch (IOException e) {
269+
fail(e);
270+
}
271+
}
272+
}
273+
}
274+
}

modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ private void failQueuedWrites() {
329329
while ((queuedWrite = queuedWrites.poll()) != null) {
330330
queuedWrite.failAsClosedChannel();
331331
}
332+
if (currentChunkedWrite != null) {
333+
safeFailPromise(currentChunkedWrite.onDone, new ClosedChannelException());
334+
currentChunkedWrite = null;
335+
}
332336
}
333337

334338
@Override

0 commit comments

Comments
 (0)