1515import io .netty .handler .codec .http .HttpContent ;
1616import io .netty .handler .codec .http .HttpObject ;
1717import io .netty .handler .codec .http .HttpRequest ;
18+ import io .netty .util .ReferenceCounted ;
1819
1920import org .elasticsearch .action .ActionListener ;
2021import org .elasticsearch .action .support .ContextPreservingActionListener ;
22+ import org .elasticsearch .action .support .SubscribableListener ;
2123import org .elasticsearch .common .util .concurrent .ThreadContext ;
22- import org .elasticsearch .core .Nullable ;
2324import org .elasticsearch .http .netty4 .internal .HttpValidator ;
2425import org .elasticsearch .transport .Transports ;
2526
27+ import java .util .ArrayDeque ;
28+
2629public class Netty4HttpHeaderValidator extends ChannelDuplexHandler {
2730
2831 private final HttpValidator validator ;
2932 private final ThreadContext threadContext ;
30- private State state ;
33+ private State state = State .PASSING ;
34+ private final ArrayDeque <Object > buffer = new ArrayDeque <>();
3135
3236 public Netty4HttpHeaderValidator (HttpValidator validator , ThreadContext threadContext ) {
3337 this .validator = validator ;
@@ -36,80 +40,125 @@ public Netty4HttpHeaderValidator(HttpValidator validator, ThreadContext threadCo
3640
3741 @ Override
3842 public void channelRead (ChannelHandlerContext ctx , Object msg ) throws Exception {
43+ if (state == State .VALIDATING || buffer .size () > 0 ) {
44+ // there's already some buffered messages that need to be processed before this one, so queue this one up behind them
45+ buffer .offerLast (msg );
46+ return ;
47+ }
48+
3949 assert msg instanceof HttpObject ;
40- var httpObject = (HttpObject ) msg ;
50+ final var httpObject = (HttpObject ) msg ;
4151 if (httpObject .decoderResult ().isFailure ()) {
4252 ctx .fireChannelRead (httpObject ); // pass-through for decoding failures
53+ } else if (msg instanceof HttpRequest httpRequest ) {
54+ validate (ctx , httpRequest );
55+ } else if (state == State .PASSING ) {
56+ assert msg instanceof HttpContent ;
57+ ctx .fireChannelRead (msg );
4358 } else {
44- if (msg instanceof HttpRequest request ) {
45- validate (ctx , request );
46- } else {
47- assert msg instanceof HttpContent ;
48- var content = (HttpContent ) msg ;
49- if (state == State .DROPPING ) {
50- content .release ();
51- ctx .read ();
52- } else {
53- assert state == State .PASSING : "unexpected content before validation completed" ;
54- ctx .fireChannelRead (content );
55- }
56- }
59+ assert state == State .DROPPING : state ;
60+ assert msg instanceof HttpContent ;
61+ final var httpContent = (HttpContent ) msg ;
62+ httpContent .release ();
63+ ctx .read ();
5764 }
5865 }
5966
6067 @ Override
61- public void read (ChannelHandlerContext ctx ) throws Exception {
62- // until validation is completed we can ignore read calls,
63- // once validation is finished HttpRequest will be fired and downstream can read from there
64- if (state != State .VALIDATING ) {
65- ctx .read ();
66- }
68+ public void channelReadComplete (ChannelHandlerContext ctx ) {
69+ if (buffer .size () == 0 ) {
70+ ctx .fireChannelReadComplete ();
71+ } // else we're buffering messages so will manage the read-complete messages ourselves
6772 }
6873
69- void validate (ChannelHandlerContext ctx , HttpRequest request ) {
70- assert Transports .assertDefaultThreadContext (threadContext );
71- state = State .VALIDATING ;
72- ActionListener .run (
73- // this prevents thread-context changes to propagate to the validation listener
74- // atm, the validation listener submits to the event loop executor, which doesn't know about the ES thread-context,
75- // so this is just a defensive play, in case the code inside the listener changes to not use the event loop executor
76- ActionListener .assertOnce (
77- new ContextPreservingActionListener <Void >(
78- threadContext .wrapRestorable (threadContext .newStoredContext ()),
79- new ActionListener <>() {
80- @ Override
81- public void onResponse (Void unused ) {
82- handleValidationResult (ctx , request , null );
83- }
84-
85- @ Override
86- public void onFailure (Exception e ) {
87- handleValidationResult (ctx , request , e );
88- }
74+ @ Override
75+ public void read (ChannelHandlerContext ctx ) throws Exception {
76+ assert ctx .channel ().eventLoop ().inEventLoop ();
77+ if (state != State .VALIDATING ) {
78+ if (buffer .size () > 0 ) {
79+ final var message = buffer .pollFirst ();
80+ if (message instanceof HttpRequest httpRequest ) {
81+ if (httpRequest .decoderResult ().isFailure ()) {
82+ ctx .fireChannelRead (message ); // pass-through for decoding failures
83+ ctx .fireChannelReadComplete (); // downstream will have to call read() again when it's ready
84+ } else {
85+ validate (ctx , httpRequest );
8986 }
90- )
91- ),
92- listener -> {
93- // this prevents thread-context changes to propagate beyond the validation, as netty worker threads are reused
94- try (ThreadContext .StoredContext ignore = threadContext .newStoredContext ()) {
95- validator .validate (request , ctx .channel (), listener );
87+ } else {
88+ assert message instanceof HttpContent ;
89+ assert state == State .PASSING : state ; // DROPPING releases any buffered chunks up-front
90+ ctx .fireChannelRead (message );
91+ ctx .fireChannelReadComplete (); // downstream will have to call read() again when it's ready
9692 }
93+ } else {
94+ ctx .read ();
9795 }
98- );
96+ }
9997 }
10098
101- void handleValidationResult (ChannelHandlerContext ctx , HttpRequest request , @ Nullable Exception validationError ) {
102- assert Transports .assertDefaultThreadContext (threadContext );
103- // Always explicitly dispatch back to the event loop to prevent reentrancy concerns if we are still on event loop
104- ctx .channel ().eventLoop ().execute (() -> {
105- if (validationError != null ) {
106- request .setDecoderResult (DecoderResult .failure (validationError ));
107- state = State .DROPPING ;
108- } else {
109- state = State .PASSING ;
99+ void validate (ChannelHandlerContext ctx , HttpRequest httpRequest ) {
100+ final var validationResultListener = new ValidationResultListener (ctx , httpRequest );
101+ SubscribableListener .newForked (validationResultListener ::doValidate )
102+ .addListener (
103+ validationResultListener ,
104+ // dispatch back to event loop unless validation completed already in which case we can just continue on this thread
105+ // straight away, avoiding the need to buffer any subsequent messages
106+ ctx .channel ().eventLoop (),
107+ null
108+ );
109+ }
110+
111+ private class ValidationResultListener implements ActionListener <Void > {
112+
113+ private final ChannelHandlerContext ctx ;
114+ private final HttpRequest httpRequest ;
115+
116+ ValidationResultListener (ChannelHandlerContext ctx , HttpRequest httpRequest ) {
117+ this .ctx = ctx ;
118+ this .httpRequest = httpRequest ;
119+ }
120+
121+ void doValidate (ActionListener <Void > listener ) {
122+ assert Transports .assertDefaultThreadContext (threadContext );
123+ assert ctx .channel ().eventLoop ().inEventLoop ();
124+ assert state == State .PASSING || state == State .DROPPING : state ;
125+ state = State .VALIDATING ;
126+ try (var ignore = threadContext .newEmptyContext ()) {
127+ validator .validate (
128+ httpRequest ,
129+ ctx .channel (),
130+ new ContextPreservingActionListener <>(threadContext ::newEmptyContext , listener )
131+ );
110132 }
111- ctx .fireChannelRead (request );
112- });
133+ }
134+
135+ @ Override
136+ public void onResponse (Void unused ) {
137+ assert Transports .assertDefaultThreadContext (threadContext );
138+ assert ctx .channel ().eventLoop ().inEventLoop ();
139+ assert state == State .VALIDATING : state ;
140+ state = State .PASSING ;
141+ fireChannelRead ();
142+ }
143+
144+ @ Override
145+ public void onFailure (Exception e ) {
146+ assert Transports .assertDefaultThreadContext (threadContext );
147+ assert ctx .channel ().eventLoop ().inEventLoop ();
148+ assert state == State .VALIDATING : state ;
149+ httpRequest .setDecoderResult (DecoderResult .failure (e ));
150+ state = State .DROPPING ;
151+ while (buffer .isEmpty () == false && buffer .peekFirst () instanceof HttpRequest == false ) {
152+ assert buffer .peekFirst () instanceof HttpContent ;
153+ ((ReferenceCounted ) buffer .pollFirst ()).release ();
154+ }
155+ fireChannelRead ();
156+ }
157+
158+ private void fireChannelRead () {
159+ ctx .fireChannelRead (httpRequest );
160+ ctx .fireChannelReadComplete (); // downstream needs to read() again
161+ }
113162 }
114163
115164 private enum State {
0 commit comments