Skip to content

Commit 2ac20c0

Browse files
Add mechanism for early http header validation (#92220)
This introduces a way to validate HTTP headers prior to reading the request body. Co-authored-by: Albert Zaharovits <[email protected]>
1 parent d5e93a8 commit 2ac20c0

File tree

9 files changed

+890
-25
lines changed

9 files changed

+890
-25
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.http.netty4;
10+
11+
import io.netty.buffer.Unpooled;
12+
import io.netty.channel.Channel;
13+
import io.netty.channel.ChannelHandlerContext;
14+
import io.netty.channel.ChannelInboundHandlerAdapter;
15+
import io.netty.handler.codec.DecoderResult;
16+
import io.netty.handler.codec.http.HttpContent;
17+
import io.netty.handler.codec.http.HttpObject;
18+
import io.netty.handler.codec.http.HttpRequest;
19+
import io.netty.handler.codec.http.LastHttpContent;
20+
import io.netty.util.ReferenceCountUtil;
21+
22+
import org.elasticsearch.action.ActionListener;
23+
import org.elasticsearch.common.TriConsumer;
24+
25+
import java.util.ArrayDeque;
26+
27+
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.DROPPING_DATA_PERMANENTLY;
28+
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.DROPPING_DATA_UNTIL_NEXT_REQUEST;
29+
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.FORWARDING_DATA_UNTIL_NEXT_REQUEST;
30+
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.QUEUEING_DATA;
31+
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.WAITING_TO_START;
32+
33+
public class Netty4HttpHeaderValidator extends ChannelInboundHandlerAdapter {
34+
35+
public static final TriConsumer<HttpRequest, Channel, ActionListener<Void>> NOOP_VALIDATOR = ((
36+
httpRequest,
37+
channel,
38+
listener) -> listener.onResponse(null));
39+
40+
private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator;
41+
private ArrayDeque<HttpObject> pending = new ArrayDeque<>(4);
42+
private State state = WAITING_TO_START;
43+
44+
public Netty4HttpHeaderValidator(TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator) {
45+
this.validator = validator;
46+
}
47+
48+
State getState() {
49+
return state;
50+
}
51+
52+
@SuppressWarnings("fallthrough")
53+
@Override
54+
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
55+
assert msg instanceof HttpObject;
56+
final HttpObject httpObject = (HttpObject) msg;
57+
58+
switch (state) {
59+
case WAITING_TO_START:
60+
assert pending.isEmpty();
61+
pending.add(ReferenceCountUtil.retain(httpObject));
62+
requestStart(ctx);
63+
assert state == QUEUEING_DATA;
64+
break;
65+
case QUEUEING_DATA:
66+
pending.add(ReferenceCountUtil.retain(httpObject));
67+
break;
68+
case FORWARDING_DATA_UNTIL_NEXT_REQUEST:
69+
assert pending.isEmpty();
70+
if (httpObject instanceof LastHttpContent) {
71+
state = WAITING_TO_START;
72+
}
73+
ctx.fireChannelRead(httpObject);
74+
break;
75+
case DROPPING_DATA_UNTIL_NEXT_REQUEST:
76+
assert pending.isEmpty();
77+
if (httpObject instanceof LastHttpContent) {
78+
state = WAITING_TO_START;
79+
}
80+
// fall-through
81+
case DROPPING_DATA_PERMANENTLY:
82+
assert pending.isEmpty();
83+
ReferenceCountUtil.release(httpObject); // consume without enqueuing
84+
break;
85+
}
86+
87+
setAutoReadForState(ctx, state);
88+
}
89+
90+
private void requestStart(ChannelHandlerContext ctx) {
91+
assert state == WAITING_TO_START;
92+
93+
if (pending.isEmpty()) {
94+
return;
95+
}
96+
97+
final HttpObject httpObject = pending.getFirst();
98+
final HttpRequest httpRequest;
99+
if (httpObject instanceof HttpRequest && httpObject.decoderResult().isSuccess()) {
100+
// a properly decoded HTTP start message is expected to begin validation
101+
// anything else is probably an error that the downstream HTTP message aggregator will have to handle
102+
httpRequest = (HttpRequest) httpObject;
103+
} else {
104+
httpRequest = null;
105+
}
106+
107+
state = QUEUEING_DATA;
108+
109+
if (httpRequest == null) {
110+
// this looks like a malformed request and will forward without validation
111+
ctx.channel().eventLoop().submit(() -> forwardFullRequest(ctx));
112+
} else {
113+
validator.apply(httpRequest, ctx.channel(), new ActionListener<>() {
114+
@Override
115+
public void onResponse(Void unused) {
116+
// Always use "Submit" to prevent reentrancy concerns if we are still on event loop
117+
ctx.channel().eventLoop().submit(() -> forwardFullRequest(ctx));
118+
}
119+
120+
@Override
121+
public void onFailure(Exception e) {
122+
// Always use "Submit" to prevent reentrancy concerns if we are still on event loop
123+
ctx.channel().eventLoop().submit(() -> forwardRequestWithDecoderExceptionAndNoContent(ctx, e));
124+
}
125+
});
126+
}
127+
}
128+
129+
private void forwardFullRequest(ChannelHandlerContext ctx) {
130+
assert ctx.channel().eventLoop().inEventLoop();
131+
assert ctx.channel().config().isAutoRead() == false;
132+
assert state == QUEUEING_DATA;
133+
134+
boolean fullRequestForwarded = forwardData(ctx, pending);
135+
136+
assert fullRequestForwarded || pending.isEmpty();
137+
if (fullRequestForwarded) {
138+
state = WAITING_TO_START;
139+
requestStart(ctx);
140+
} else {
141+
state = FORWARDING_DATA_UNTIL_NEXT_REQUEST;
142+
}
143+
144+
assert state == WAITING_TO_START || state == QUEUEING_DATA || state == FORWARDING_DATA_UNTIL_NEXT_REQUEST;
145+
setAutoReadForState(ctx, state);
146+
}
147+
148+
private void forwardRequestWithDecoderExceptionAndNoContent(ChannelHandlerContext ctx, Exception e) {
149+
assert ctx.channel().eventLoop().inEventLoop();
150+
assert ctx.channel().config().isAutoRead() == false;
151+
assert state == QUEUEING_DATA;
152+
153+
HttpObject messageToForward = pending.getFirst();
154+
boolean fullRequestDropped = dropData(pending);
155+
if (messageToForward instanceof HttpContent toReplace) {
156+
// if the request to forward contained data (which got dropped), replace with empty data
157+
messageToForward = toReplace.replace(Unpooled.EMPTY_BUFFER);
158+
}
159+
messageToForward.setDecoderResult(DecoderResult.failure(e));
160+
ctx.fireChannelRead(messageToForward);
161+
162+
assert fullRequestDropped || pending.isEmpty();
163+
if (fullRequestDropped) {
164+
state = WAITING_TO_START;
165+
requestStart(ctx);
166+
} else {
167+
state = DROPPING_DATA_UNTIL_NEXT_REQUEST;
168+
}
169+
170+
assert state == WAITING_TO_START || state == QUEUEING_DATA || state == DROPPING_DATA_UNTIL_NEXT_REQUEST;
171+
setAutoReadForState(ctx, state);
172+
}
173+
174+
@Override
175+
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
176+
state = DROPPING_DATA_PERMANENTLY;
177+
while (true) {
178+
if (dropData(pending) == false) {
179+
break;
180+
}
181+
}
182+
super.channelInactive(ctx);
183+
}
184+
185+
private static boolean forwardData(ChannelHandlerContext ctx, ArrayDeque<HttpObject> pending) {
186+
final int pendingMessages = pending.size();
187+
try {
188+
HttpObject toForward;
189+
while ((toForward = pending.poll()) != null) {
190+
ctx.fireChannelRead(toForward);
191+
ReferenceCountUtil.release(toForward); // reference cnt incremented when enqueued
192+
if (toForward instanceof LastHttpContent) {
193+
return true;
194+
}
195+
}
196+
return false;
197+
} finally {
198+
maybeResizePendingDown(pendingMessages, pending);
199+
}
200+
}
201+
202+
private static boolean dropData(ArrayDeque<HttpObject> pending) {
203+
final int pendingMessages = pending.size();
204+
try {
205+
HttpObject toDrop;
206+
while ((toDrop = pending.poll()) != null) {
207+
ReferenceCountUtil.release(toDrop, 2); // 1 for enqueuing, 1 for consuming
208+
if (toDrop instanceof LastHttpContent) {
209+
return true;
210+
}
211+
}
212+
return false;
213+
} finally {
214+
maybeResizePendingDown(pendingMessages, pending);
215+
}
216+
}
217+
218+
private static void maybeResizePendingDown(int largeSize, ArrayDeque<HttpObject> pending) {
219+
if (pending.size() <= 4 && largeSize > 32) {
220+
// Prevent the ArrayDeque from becoming forever large due to a single large message.
221+
ArrayDeque<HttpObject> old = pending;
222+
pending = new ArrayDeque<>(4);
223+
pending.addAll(old);
224+
}
225+
}
226+
227+
private static void setAutoReadForState(ChannelHandlerContext ctx, State state) {
228+
ctx.channel().config().setAutoRead((state == QUEUEING_DATA || state == DROPPING_DATA_PERMANENTLY) == false);
229+
}
230+
231+
enum State {
232+
WAITING_TO_START,
233+
QUEUEING_DATA,
234+
FORWARDING_DATA_UNTIL_NEXT_REQUEST,
235+
DROPPING_DATA_UNTIL_NEXT_REQUEST,
236+
DROPPING_DATA_PERMANENTLY
237+
}
238+
}

modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.netty.handler.codec.http.HttpContentCompressor;
2424
import io.netty.handler.codec.http.HttpContentDecompressor;
2525
import io.netty.handler.codec.http.HttpObjectAggregator;
26+
import io.netty.handler.codec.http.HttpRequest;
2627
import io.netty.handler.codec.http.HttpRequestDecoder;
2728
import io.netty.handler.codec.http.HttpResponse;
2829
import io.netty.handler.codec.http.HttpResponseEncoder;
@@ -35,6 +36,8 @@
3536
import org.apache.logging.log4j.LogManager;
3637
import org.apache.logging.log4j.Logger;
3738
import org.elasticsearch.ExceptionsHelper;
39+
import org.elasticsearch.action.ActionListener;
40+
import org.elasticsearch.common.TriConsumer;
3841
import org.elasticsearch.common.network.CloseableChannel;
3942
import org.elasticsearch.common.network.NetworkService;
4043
import org.elasticsearch.common.settings.ClusterSettings;
@@ -143,6 +146,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
143146
private final RecvByteBufAllocator recvByteBufAllocator;
144147
private final TLSConfig tlsConfig;
145148
private final AcceptChannelHandler.AcceptPredicate acceptChannelPredicate;
149+
private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator;
146150
private final int readTimeoutMillis;
147151

148152
private final int maxCompositeBufferComponents;
@@ -160,8 +164,8 @@ public Netty4HttpServerTransport(
160164
SharedGroupFactory sharedGroupFactory,
161165
Tracer tracer,
162166
TLSConfig tlsConfig,
163-
@Nullable AcceptChannelHandler.AcceptPredicate acceptChannelPredicate
164-
167+
@Nullable AcceptChannelHandler.AcceptPredicate acceptChannelPredicate,
168+
@Nullable TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator
165169
) {
166170
super(
167171
settings,
@@ -178,6 +182,7 @@ public Netty4HttpServerTransport(
178182
this.sharedGroupFactory = sharedGroupFactory;
179183
this.tlsConfig = tlsConfig;
180184
this.acceptChannelPredicate = acceptChannelPredicate;
185+
this.headerValidator = headerValidator;
181186

182187
this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings);
183188

@@ -323,7 +328,7 @@ public void onException(HttpChannel channel, Exception cause) {
323328
}
324329

325330
public ChannelHandler configureServerChannelHandler() {
326-
return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate);
331+
return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate, headerValidator);
327332
}
328333

329334
static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
@@ -335,17 +340,20 @@ protected static class HttpChannelHandler extends ChannelInitializer<Channel> {
335340
private final HttpHandlingSettings handlingSettings;
336341
private final TLSConfig tlsConfig;
337342
private final BiPredicate<String, InetSocketAddress> acceptChannelPredicate;
343+
private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator;
338344

339345
protected HttpChannelHandler(
340346
final Netty4HttpServerTransport transport,
341347
final HttpHandlingSettings handlingSettings,
342348
final TLSConfig tlsConfig,
343-
@Nullable final BiPredicate<String, InetSocketAddress> acceptChannelPredicate
349+
@Nullable final BiPredicate<String, InetSocketAddress> acceptChannelPredicate,
350+
@Nullable final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator
344351
) {
345352
this.transport = transport;
346353
this.handlingSettings = handlingSettings;
347354
this.tlsConfig = tlsConfig;
348355
this.acceptChannelPredicate = acceptChannelPredicate;
356+
this.headerValidator = headerValidator;
349357
}
350358

351359
@Override
@@ -374,11 +382,17 @@ protected void initChannel(Channel ch) throws Exception {
374382
handlingSettings.maxChunkSize()
375383
);
376384
decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
385+
ch.pipeline().addLast("decoder", decoder); // parses the HTTP bytes request into HTTP message pieces
386+
if (headerValidator != null) {
387+
// runs a validation function on the first HTTP message piece which contains all the headers
388+
// if validation passes, the pieces of that particular request are forwarded, otherwise they are discarded
389+
ch.pipeline().addLast("header_validator", new Netty4HttpHeaderValidator(headerValidator));
390+
}
391+
// combines the HTTP message pieces into a single full HTTP request (with headers and body)
377392
final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.maxContentLength());
378393
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
379394
ch.pipeline()
380-
.addLast("decoder", decoder)
381-
.addLast("decoder_compress", new HttpContentDecompressor())
395+
.addLast("decoder_compress", new HttpContentDecompressor()) // this handles request body decompression
382396
.addLast("encoder", new HttpResponseEncoder() {
383397
@Override
384398
protected boolean isContentAlwaysEmpty(HttpResponse msg) {

modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Plugin.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
114114
getSharedGroupFactory(settings),
115115
tracer,
116116
TLSConfig.noTLS(),
117+
null,
117118
null
118119
)
119120
);

modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4BadRequestTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext,
8787
new SharedGroupFactory(Settings.EMPTY),
8888
Tracer.NOOP,
8989
TLSConfig.noTLS(),
90-
null
90+
null,
91+
randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
9192
)
9293
) {
9394
httpServerTransport.start();

0 commit comments

Comments
 (0)