diff --git a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/frame/FrameType.java b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/frame/FrameType.java index 2253e24ca..1d53e5b2e 100644 --- a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/frame/FrameType.java +++ b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/frame/FrameType.java @@ -43,7 +43,8 @@ public enum FrameType { GOAWAY(0x07), WINDOW_UPDATE(0x08), CONTINUATION(0x09), - PRIORITY_UPDATE(0x10); // 16 + PRIORITY_UPDATE(0x10), // 16 + ORIGIN(0x0c); // RFC 8336 final int value; diff --git a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java index 9b39d4b08..d133dc44c 100644 --- a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java +++ b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java @@ -27,16 +27,21 @@ package org.apache.hc.core5.http2.impl.nio; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.BufferOverflowException; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.Deque; +import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Queue; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentLinkedQueue; @@ -53,11 +58,13 @@ import org.apache.hc.core5.http.HttpConnection; import org.apache.hc.core5.http.HttpException; import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.HttpHost; import org.apache.hc.core5.http.HttpStreamResetException; import org.apache.hc.core5.http.HttpVersion; import org.apache.hc.core5.http.ProtocolException; import org.apache.hc.core5.http.ProtocolVersion; import org.apache.hc.core5.http.RequestNotExecutedException; +import org.apache.hc.core5.http.URIScheme; import org.apache.hc.core5.http.config.CharCodingConfig; import org.apache.hc.core5.http.impl.BasicEndpointDetails; import org.apache.hc.core5.http.impl.BasicHttpConnectionMetrics; @@ -144,6 +151,12 @@ enum SettingsHandshake { READY, TRANSMITTED, ACKED } private final Map priorities = new ConcurrentHashMap<>(); private volatile boolean peerNoRfc7540Priorities; + /** + * RFC 8336 Origin Set (client-side). + */ + private final Set originSet = Collections.newSetFromMap(new ConcurrentHashMap<>()); + private volatile boolean originInit; + AbstractH2StreamMultiplexer( final ProtocolIOSession ioSession, final FrameFactory frameFactory, @@ -729,6 +742,19 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio throw new H2ConnectionException(H2Error.PROTOCOL_ERROR, "CONTINUATION frame expected"); } switch (frameType) { + case ORIGIN: { // RFC 8336 + // Only valid on stream 0; ignore on h2c; ignore reserved incompatible flags (0x1|0x2|0x4|0x8) + if (streamId == 0 && getSSLSession() != null) { + final int flags = frame.getFlags(); + if ((flags & 0x0F) == 0) { + final ByteBuffer pl = frame.getPayloadContent(); + if (pl != null) { + processOriginPayload(pl); + } + } + } + } + break; case DATA: { if (streamId == 0) { throw new H2ConnectionException(H2Error.PROTOCOL_ERROR, "Illegal stream id: " + streamId); @@ -1584,4 +1610,117 @@ public String toString() { } + /** + * Initialize the Origin Set once per connection (RFC 8336 ยง2.3). + */ + private void ensureOriginInit() { + if (originInit) { + return; + } + // Initial origin: scheme "https", host = SNI (lowercased) if available, else remote IP; port = remote port. + final SSLSession ssl = getSSLSession(); + if (ssl == null) { + return; // ORIGIN is ignored on h2c; keep uninitialized + } + String host = null; + try { + // Best-effort SNI via session value used by our TLS strategy (if present) + final Object sni = ssl.getValue("HOSTNAME"); + if (sni instanceof String) { + host = ((String) sni).toLowerCase(Locale.ROOT); + } + } catch (final Exception ignore) { + } + if (host == null) { + final SocketAddress ra = getRemoteAddress(); + if (ra instanceof InetSocketAddress) { + host = ((InetSocketAddress) ra).getHostString().toLowerCase(Locale.ROOT); + } + } + int port = 0; + final SocketAddress ra = getRemoteAddress(); + if (ra instanceof java.net.InetSocketAddress) { + port = ((java.net.InetSocketAddress) ra).getPort(); + } + if (host != null && port > 0) { + originSet.add(new HttpHost("https", host, port)); + originInit = true; + } + } + + /** + * Parse and merge ORIGIN payload (list of Origin-Entry). + */ + private void processOriginPayload(final ByteBuffer pl) { + ensureOriginInit(); + while (pl.remaining() >= 2) { + final int len = Short.toUnsignedInt(pl.getShort()); + if (len == 0) { + // Empty Origin-Entry is allowed (server can signal "SNI-only"); no-op here. + continue; + } + if (pl.remaining() < len) { + break; // malformed; stop processing silently per robustness principle + } + final byte[] b = new byte[len]; + pl.get(b); + final String ascii = new String(b, java.nio.charset.StandardCharsets.US_ASCII); + final org.apache.hc.core5.http.HttpHost parsed = parseAsciiOrigin(ascii); + if (parsed != null) { + originSet.add(parsed); + } + } + } + + /** + * RFC 6454 ASCII origin parser (scheme/host/port only). Returns null if invalid. + */ + private HttpHost parseAsciiOrigin(final String s) { + try { + final java.net.URI u = java.net.URI.create(s); + if (u.getFragment() != null) return null; + final String scheme = u.getScheme(); + final String host = u.getHost(); + if (scheme == null || host == null) return null; + int port = u.getPort(); + if (port < 0) { + if (URIScheme.HTTPS.same(scheme)) { + port = 443; + } else if (URIScheme.HTTP.same(scheme)) { + port = 80; + } else { + return null; + } + } + return new HttpHost(scheme.toLowerCase(Locale.ROOT), host.toLowerCase(Locale.ROOT), port); + } catch (final IllegalArgumentException ex) { + return null; + } + } + + + protected final void commitConnFrame(final RawFrame frame) throws IOException { + Args.notNull(frame, "Frame"); + ioSession.getLock().lock(); + try { + commitFrameInternal(frame); + } finally { + ioSession.getLock().unlock(); + } + } + + Set getOriginSetSnapshot() { + return Collections.unmodifiableSet(new HashSet<>(originSet)); + } + + public void removeOrigin(final org.apache.hc.core5.http.HttpHost origin) { + if (origin != null) { + originSet.remove(origin); + } + } + + boolean isOriginAllowed(final HttpHost origin) { + return origin != null && originSet.contains(origin); + } + } \ No newline at end of file diff --git a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamHandler.java b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamHandler.java index 80e6db1d7..2d556d4da 100644 --- a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamHandler.java +++ b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamHandler.java @@ -29,6 +29,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; +import java.util.Locale; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -37,11 +38,13 @@ import org.apache.hc.core5.http.HeaderElements; import org.apache.hc.core5.http.HttpException; import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.HttpHost; import org.apache.hc.core5.http.HttpRequest; import org.apache.hc.core5.http.HttpResponse; import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.HttpVersion; import org.apache.hc.core5.http.ProtocolException; +import org.apache.hc.core5.http.URIScheme; import org.apache.hc.core5.http.impl.BasicHttpConnectionMetrics; import org.apache.hc.core5.http.impl.IncomingEntityDetails; import org.apache.hc.core5.http.impl.nio.MessageState; @@ -54,6 +57,7 @@ import org.apache.hc.core5.http.protocol.HttpProcessor; import org.apache.hc.core5.http2.H2ConnectionException; import org.apache.hc.core5.http2.H2Error; +import org.apache.hc.core5.http2.H2PseudoRequestHeaders; import org.apache.hc.core5.http2.impl.DefaultH2RequestConverter; import org.apache.hc.core5.http2.impl.DefaultH2ResponseConverter; @@ -72,13 +76,18 @@ class ClientH2StreamHandler implements H2StreamHandler { private final AtomicBoolean failed; private final AtomicBoolean done; + private final ClientH2StreamMultiplexer parent; + private volatile HttpHost lastRequestOrigin; + ClientH2StreamHandler( + final ClientH2StreamMultiplexer parent, final H2StreamChannel outputChannel, final HttpProcessor httpProcessor, final BasicHttpConnectionMetrics connMetrics, final AsyncClientExchangeHandler exchangeHandler, final HandlerFactory pushHandlerFactory, final HttpCoreContext context) { + this.parent = parent; this.outputChannel = outputChannel; this.dataChannel = new DataStreamChannel() { @@ -142,6 +151,35 @@ private void commitRequest(final HttpRequest request, final EntityDetails entity httpProcessor.process(request, entityDetails, context); final List
headers = DefaultH2RequestConverter.INSTANCE.convert(request); + String scheme = null; + String authority = null; + for (final Header h : headers) { + final String n = h.getName(); + if (H2PseudoRequestHeaders.SCHEME.equalsIgnoreCase(n)) { + scheme = h.getValue(); + } else if (H2PseudoRequestHeaders.AUTHORITY.equalsIgnoreCase(n)) { + authority = h.getValue(); + } + } + if (scheme != null && authority != null) { + String host = authority; + int port = -1; + final int colon = authority.lastIndexOf(':'); + if (colon > 0 && authority.indexOf(']') < 0) { + host = authority.substring(0, colon); + try { + port = Integer.parseInt(authority.substring(colon + 1)); + } catch (final NumberFormatException ignore) { + + } + } + if (port < 0) { + port = URIScheme.HTTPS.same(scheme) ? 443 : (URIScheme.HTTP.same(scheme) ? 80 : -1); + } + if (port > 0) { + lastRequestOrigin = new HttpHost(scheme.toLowerCase(Locale.ROOT), host.toLowerCase(Locale.ROOT),port); + } + } if (entityDetails == null) { requestState.set(MessageState.COMPLETE); outputChannel.submit(headers, true); @@ -197,6 +235,16 @@ public void consumeHeader(final List
headers, final boolean endStream) t if (status > HttpStatus.SC_CONTINUE && status < HttpStatus.SC_SUCCESS) { exchangeHandler.consumeInformation(response, context); } + if (status == HttpStatus.SC_MISDIRECTED_REQUEST /* 421 */ && lastRequestOrigin != null && parent != null) { + parent.removeOrigin(lastRequestOrigin); + } + if (lastRequestOrigin != null) { + // Only enforce after ORIGIN has initialized the set (i.e., it's not empty) + if (!parent.getOriginSetSnapshot().isEmpty() && !parent.isOriginAllowed(lastRequestOrigin)) { + throw new ProtocolException("Origin not allowed on this HTTP/2 connection: " + lastRequestOrigin); + } + } + if (requestState.get() == MessageState.ACK) { if (status == HttpStatus.SC_CONTINUE || status >= HttpStatus.SC_SUCCESS) { requestState.set(MessageState.BODY); diff --git a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamMultiplexer.java b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamMultiplexer.java index ab8ecc48a..8cda98e8d 100644 --- a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamMultiplexer.java +++ b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamMultiplexer.java @@ -131,7 +131,7 @@ H2StreamHandler outgoingRequest( final HttpCoreContext coreContext = HttpCoreContext.castOrCreate(context); coreContext.setSSLSession(getSSLSession()); coreContext.setEndpointDetails(getEndpointDetails()); - return new ClientH2StreamHandler(channel, getHttpProcessor(), getConnMetrics(), exchangeHandler, + return new ClientH2StreamHandler(this, channel, getHttpProcessor(), getConnMetrics(), exchangeHandler, pushHandlerFactory != null ? pushHandlerFactory : this.pushHandlerFactory, coreContext); } @@ -170,6 +170,5 @@ public String toString() { buf.append("]"); return buf.toString(); } - } diff --git a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ServerH2StreamMultiplexer.java b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ServerH2StreamMultiplexer.java index 046047398..ff4ac3617 100644 --- a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ServerH2StreamMultiplexer.java +++ b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/ServerH2StreamMultiplexer.java @@ -28,6 +28,9 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; import java.util.List; import org.apache.hc.core5.annotation.Internal; @@ -50,6 +53,8 @@ import org.apache.hc.core5.http2.config.H2Setting; import org.apache.hc.core5.http2.frame.DefaultFrameFactory; import org.apache.hc.core5.http2.frame.FrameFactory; +import org.apache.hc.core5.http2.frame.FrameType; +import org.apache.hc.core5.http2.frame.RawFrame; import org.apache.hc.core5.http2.frame.StreamIdGenerator; import org.apache.hc.core5.http2.hpack.HeaderListConstraintException; import org.apache.hc.core5.reactor.ProtocolIOSession; @@ -171,4 +176,38 @@ public String toString() { return buf.toString(); } + + public void sendOrigin(final Collection asciiOrigins) throws IOException { + if (asciiOrigins == null || asciiOrigins.isEmpty()) { + final ByteBuffer empty = ByteBuffer.allocate(0); + final RawFrame origin = new RawFrame(FrameType.ORIGIN.getValue(), 0, 0, empty); + commitConnFrame(origin); + return; + } + final ArrayList parts = new ArrayList<>(); + int total = 0; + for (final String s : asciiOrigins) { + if (s == null) { + continue; + } + final byte[] b = s.getBytes(StandardCharsets.US_ASCII); + if (b.length > 0xFFFF) { + continue; + } + parts.add(b); + total += 2 + b.length; + } + if (total == 0) { + return; + } + final ByteBuffer pl = ByteBuffer.allocate(total); + for (final byte[] b : parts) { + pl.putShort((short)(b.length & 0xFFFF)); + pl.put(b); + } + pl.flip(); + final RawFrame origin = new RawFrame(FrameType.ORIGIN.getValue(), 0, 0, pl); + commitConnFrame(origin); + } + } diff --git a/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamHandlerOriginTest.java b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamHandlerOriginTest.java new file mode 100644 index 000000000..3909abe59 --- /dev/null +++ b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/ClientH2StreamHandlerOriginTest.java @@ -0,0 +1,213 @@ +/* + * ==================================================================== + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * ==================================================================== + * + * This software consists of voluntary contributions made by many + * individuals on behalf of the Apache Software Foundation. For more + * information on the Apache Software Foundation, please see + * . + * + */ +package org.apache.hc.core5.http2.impl.nio; + + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyList; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.hc.core5.http.EntityDetails; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.HttpRequest; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.http.ProtocolException; +import org.apache.hc.core5.http.impl.BasicHttpConnectionMetrics; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.message.BasicHttpRequest; +import org.apache.hc.core5.http.nio.AsyncClientExchangeHandler; +import org.apache.hc.core5.http.nio.CapacityChannel; +import org.apache.hc.core5.http.nio.DataStreamChannel; +import org.apache.hc.core5.http.nio.RequestChannel; +import org.apache.hc.core5.http.protocol.HttpContext; +import org.apache.hc.core5.http.protocol.HttpCoreContext; +import org.apache.hc.core5.http.protocol.HttpProcessor; +import org.apache.hc.core5.http2.impl.BasicH2TransportMetrics; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +class ClientH2StreamHandlerOriginTest { + + @Test + void removesOriginOn421() throws Exception { + final ClientH2StreamMultiplexer parent = mock(ClientH2StreamMultiplexer.class); + final H2StreamChannel channel = mock(H2StreamChannel.class); + final HttpProcessor httpProcessor = mock(HttpProcessor.class); + final BasicHttpConnectionMetrics metrics = new BasicHttpConnectionMetrics(new BasicH2TransportMetrics(), new BasicH2TransportMetrics()); + final HttpCoreContext ctx = HttpCoreContext.create(); + + final HttpRequest req = new BasicHttpRequest("GET", URI.create("https://Example.com/")); + final AsyncClientExchangeHandler eh = new OneShot(req); + + final ClientH2StreamHandler h = + new ClientH2StreamHandler(parent, channel, httpProcessor, metrics, eh, null, ctx); + + h.produceOutput(); + verify(channel).submit(anyList(), eq(true)); + + when(parent.getOriginSetSnapshot()).thenReturn(Collections.emptySet()); + + h.consumeHeader(Collections.
singletonList(new BasicHeader(":status", "421")), true); + + final ArgumentCaptor cap = ArgumentCaptor.forClass(HttpHost.class); + verify(parent).removeOrigin(cap.capture()); + final HttpHost o = cap.getValue(); + assertEquals("https", o.getSchemeName()); + assertEquals("example.com", o.getHostName()); + assertEquals(443, o.getPort()); + } + + @Test + void throwsWhenOriginNotAllowed() throws Exception { + final ClientH2StreamMultiplexer parent = mock(ClientH2StreamMultiplexer.class); + final H2StreamChannel channel = mock(H2StreamChannel.class); + final HttpProcessor httpProcessor = mock(HttpProcessor.class); + final BasicHttpConnectionMetrics metrics = new BasicHttpConnectionMetrics(new BasicH2TransportMetrics(), new BasicH2TransportMetrics()); + final HttpCoreContext ctx = HttpCoreContext.create(); + + final HttpRequest req = new BasicHttpRequest("GET", URI.create("https://blocked.example/")); + final AsyncClientExchangeHandler eh = new OneShot(req); + + final ClientH2StreamHandler h = + new ClientH2StreamHandler(parent, channel, httpProcessor, metrics, eh, null, ctx); + + h.produceOutput(); + + final Set initialized = new HashSet(); + initialized.add(new HttpHost("https", "allowed.example", 443)); + when(parent.getOriginSetSnapshot()).thenReturn(initialized); + when(parent.isOriginAllowed(any(HttpHost.class))).thenReturn(false); + + final ProtocolException ex = assertThrows(ProtocolException.class, + new org.junit.jupiter.api.function.Executable() { + @Override + public void execute() throws Throwable { + h.consumeHeader(Collections.
singletonList(new BasicHeader(":status", "200")), true); + } + }); + assertTrue(ex.getMessage().contains("Origin not allowed")); + } + + @Test + void preservesExplicitPortInAuthority() throws Exception { + final ClientH2StreamMultiplexer parent = mock(ClientH2StreamMultiplexer.class); + final H2StreamChannel channel = mock(H2StreamChannel.class); + final HttpProcessor httpProcessor = mock(HttpProcessor.class); + final BasicHttpConnectionMetrics metrics = new BasicHttpConnectionMetrics(new BasicH2TransportMetrics(), new BasicH2TransportMetrics()); + final HttpCoreContext ctx = HttpCoreContext.create(); + + final HttpRequest req = new BasicHttpRequest("GET", URI.create("http://example.com:8080/")); + final AsyncClientExchangeHandler eh = new OneShot(req); + + final ClientH2StreamHandler h = + new ClientH2StreamHandler(parent, channel, httpProcessor, metrics, eh, null, ctx); + + h.produceOutput(); + verify(channel).submit(anyList(), eq(true)); + + when(parent.getOriginSetSnapshot()).thenReturn(Collections.emptySet()); + h.consumeHeader(Collections.
singletonList(new BasicHeader(":status", "421")), true); + + final ArgumentCaptor cap = ArgumentCaptor.forClass(HttpHost.class); + verify(parent).removeOrigin(cap.capture()); + final HttpHost o = cap.getValue(); + assertEquals("http", o.getSchemeName()); + assertEquals("example.com", o.getHostName()); + assertEquals(8080, o.getPort()); + } + + private static final class OneShot implements AsyncClientExchangeHandler { + private final HttpRequest request; + + OneShot(final HttpRequest request) { + this.request = request; + } + + @Override + public void produceRequest(final RequestChannel channel, final HttpContext context) throws HttpException, IOException { + channel.sendRequest(request, null, context); + } + + @Override + public int available() { + return 0; + } + + @Override + public void produce(final DataStreamChannel channel) { + } + + @Override + public void consumeInformation(final HttpResponse response, final HttpContext context) { + } + + @Override + public void cancel() { + + } + + @Override + public void consumeResponse(final HttpResponse response, final EntityDetails entityDetails, final HttpContext context) { + } + + @Override + public void updateCapacity(final CapacityChannel capacityChannel) { + } + + @Override + public void consume(final ByteBuffer src) { + } + + @Override + public void streamEnd(final List trailers) { + } + + @Override + public void failed(final Exception cause) { + } + + @Override + public void releaseResources() { + } + } +} diff --git a/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/ServerH2OriginRFC8336Test.java b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/ServerH2OriginRFC8336Test.java new file mode 100644 index 000000000..2bf04c0ea --- /dev/null +++ b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/ServerH2OriginRFC8336Test.java @@ -0,0 +1,204 @@ +/* + * ==================================================================== + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * ==================================================================== + * + * This software consists of voluntary contributions made by many + * individuals on behalf of the Apache Software Foundation. For more + * information on the Apache Software Foundation, please see + * . + * + */ +package org.apache.hc.core5.http2.impl.nio; + + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.concurrent.locks.ReentrantLock; + +import javax.net.ssl.SSLSession; + +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.config.CharCodingConfig; +import org.apache.hc.core5.http.nio.AsyncServerExchangeHandler; +import org.apache.hc.core5.http.nio.HandlerFactory; +import org.apache.hc.core5.http.protocol.HttpProcessor; +import org.apache.hc.core5.http2.config.H2Config; +import org.apache.hc.core5.http2.frame.DefaultFrameFactory; +import org.apache.hc.core5.http2.frame.FrameType; +import org.apache.hc.core5.reactor.ProtocolIOSession; +import org.apache.hc.core5.reactor.ssl.TlsDetails; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ServerH2OriginRFC8336Test { + + private ProtocolIOSession ioSession; + private HttpProcessor httpProcessor; + private HandlerFactory exchangeHandlerFactory; + + @BeforeEach + void setUp() { + ioSession = mock(ProtocolIOSession.class, RETURNS_DEEP_STUBS); + when(ioSession.getId()).thenReturn("test-conn"); + when(ioSession.getLock()).thenReturn(new ReentrantLock()); + when(ioSession.isOpen()).thenReturn(true); + when(ioSession.getRemoteAddress()).thenReturn(new InetSocketAddress("127.0.0.1", 443)); + httpProcessor = mock(HttpProcessor.class); + @SuppressWarnings("unchecked") final HandlerFactory f = (HandlerFactory) mock(HandlerFactory.class); + exchangeHandlerFactory = f; + } + + private ServerH2StreamMultiplexer newMuxWithTLS() { + final SSLSession sslSession = mock(SSLSession.class); + final TlsDetails tlsDetails = new TlsDetails(sslSession, "h2"); + when(ioSession.getTlsDetails()).thenReturn(tlsDetails); + + return new ServerH2StreamMultiplexer( + ioSession, + DefaultFrameFactory.INSTANCE, + httpProcessor, + exchangeHandlerFactory, + CharCodingConfig.DEFAULT, + H2Config.DEFAULT, + null + ); + } + + private ServerH2StreamMultiplexer newMuxWithoutTLS() { + when(ioSession.getTlsDetails()).thenReturn(null); + return new ServerH2StreamMultiplexer( + ioSession, + DefaultFrameFactory.INSTANCE, + httpProcessor, + exchangeHandlerFactory, + CharCodingConfig.DEFAULT, + H2Config.DEFAULT, + null + ); + } + + private static ByteBuffer makeOriginPayload(final List origins) { + int total = 0; + for (final String s : origins) { + final byte[] b = s.getBytes(java.nio.charset.StandardCharsets.US_ASCII); + total += 2 + b.length; + } + final ByteBuffer pl = ByteBuffer.allocate(total); + for (final String s : origins) { + final byte[] b = s.getBytes(java.nio.charset.StandardCharsets.US_ASCII); + pl.putShort((short) (b.length & 0xFFFF)); + pl.put(b); + } + pl.flip(); + return pl; + } + + private static ByteBuffer h2Frame(final byte type, final byte flags, final int streamId, final ByteBuffer payload) { + final int len = payload == null ? 0 : payload.remaining(); + final ByteBuffer buf = ByteBuffer.allocate(9 + len); + buf.put((byte) ((len >>> 16) & 0xFF)); + buf.put((byte) ((len >>> 8) & 0xFF)); + buf.put((byte) (len & 0xFF)); + buf.put(type); + buf.put(flags); + buf.putInt(streamId & 0x7FFFFFFF); + if (payload != null) { + buf.put(payload.slice()); + } + buf.flip(); + return buf; + } + + @Test + void originFrame_overTLS_parsesAndStoresHosts() throws Exception { + final ServerH2StreamMultiplexer mux = newMuxWithTLS(); + + final ByteBuffer payload = makeOriginPayload(Arrays.asList( + "https://a.example:443", + "https://b.example:8443", + "https://c.example" // no port -> default 443 + )); + final ByteBuffer frame = h2Frame((byte) FrameType.ORIGIN.getValue(), (byte) 0x00, 0, payload); + + mux.onInput(frame); + + final Set set = ((AbstractH2StreamMultiplexer) mux).getOriginSetSnapshot(); + assertTrue(set.contains(new HttpHost("https", "a.example", 443))); + assertTrue(set.contains(new HttpHost("https", "b.example", 8443))); + assertTrue(set.contains(new HttpHost("https", "c.example", 443))); + } + + @Test + void originFrame_withoutTLS_isIgnored() throws Exception { + final ServerH2StreamMultiplexer mux = newMuxWithoutTLS(); + + final ByteBuffer payload = makeOriginPayload(Arrays.asList("https://ignored.example:443")); + final ByteBuffer frame = h2Frame((byte) FrameType.ORIGIN.getValue(), (byte) 0x00, 0, payload); + + mux.onInput(frame); + + final Set set = ((AbstractH2StreamMultiplexer) mux).getOriginSetSnapshot(); + assertFalse(set.contains(new HttpHost("https", "ignored.example", 443))); + } + + @Test + void originFrame_withNonZeroLowerFlags_isIgnored() throws Exception { + final ServerH2StreamMultiplexer mux = newMuxWithTLS(); + + final ByteBuffer payload = makeOriginPayload(Collections.singletonList("https://flags.example:443")); + // lower 4 bits non-zero -> ignore per RFC 8336 + final ByteBuffer frame = h2Frame((byte) FrameType.ORIGIN.getValue(), (byte) 0x01, 0, payload); + + mux.onInput(frame); + + final Set set = ((AbstractH2StreamMultiplexer) mux).getOriginSetSnapshot(); + assertFalse(set.contains(new HttpHost("https", "flags.example", 443))); + } + + @Test + void sendOrigin_enqueuesAndSignalsWrite() throws Exception { + final ServerH2StreamMultiplexer mux = newMuxWithTLS(); + + mux.sendOrigin(Collections.singletonList("https://emit.example:443")); + + verify(ioSession, atLeastOnce()).setEvent(SelectionKey.OP_WRITE); + } + + @Test + void sendOrigin_emptyPayload_okAndSignalsWrite() throws Exception { + final ServerH2StreamMultiplexer mux = newMuxWithTLS(); + + mux.sendOrigin(java.util.Collections.emptyList()); + + verify(ioSession, atLeastOnce()).setEvent(SelectionKey.OP_WRITE); + } +}