|
| 1 | +/* |
| 2 | + * SPDX-License-Identifier: Apache-2.0 |
| 3 | + * |
| 4 | + * The OpenSearch Contributors require contributions made to |
| 5 | + * this file be licensed under the Apache-2.0 license or a |
| 6 | + * compatible open source license. |
| 7 | + */ |
| 8 | + |
| 9 | +package org.opensearch.http.reactor.netty4; |
| 10 | + |
| 11 | +import org.opensearch.OpenSearchReactorNetty4IntegTestCase; |
| 12 | +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; |
| 13 | +import org.opensearch.cluster.node.DiscoveryNodes; |
| 14 | +import org.opensearch.common.collect.Tuple; |
| 15 | +import org.opensearch.common.lease.Releasable; |
| 16 | +import org.opensearch.common.settings.ClusterSettings; |
| 17 | +import org.opensearch.common.settings.IndexScopedSettings; |
| 18 | +import org.opensearch.common.settings.Settings; |
| 19 | +import org.opensearch.common.settings.SettingsFilter; |
| 20 | +import org.opensearch.common.util.concurrent.ThreadContext; |
| 21 | +import org.opensearch.core.common.bytes.BytesReference; |
| 22 | +import org.opensearch.core.common.transport.TransportAddress; |
| 23 | +import org.opensearch.core.rest.RestStatus; |
| 24 | +import org.opensearch.http.HttpChunk; |
| 25 | +import org.opensearch.http.HttpServerTransport; |
| 26 | +import org.opensearch.plugins.ActionPlugin; |
| 27 | +import org.opensearch.plugins.Plugin; |
| 28 | +import org.opensearch.rest.BaseRestHandler; |
| 29 | +import org.opensearch.rest.RestController; |
| 30 | +import org.opensearch.rest.RestHandler; |
| 31 | +import org.opensearch.rest.RestRequest; |
| 32 | +import org.opensearch.rest.StreamingRestChannel; |
| 33 | +import org.opensearch.tasks.Task; |
| 34 | +import org.opensearch.test.OpenSearchIntegTestCase; |
| 35 | +import org.opensearch.transport.Netty4ModulePlugin; |
| 36 | +import org.opensearch.transport.client.node.NodeClient; |
| 37 | +import org.opensearch.transport.reactor.ReactorNetty4Plugin; |
| 38 | +import org.junit.Assert; |
| 39 | + |
| 40 | +import java.io.IOException; |
| 41 | +import java.nio.ByteBuffer; |
| 42 | +import java.nio.charset.StandardCharsets; |
| 43 | +import java.time.Duration; |
| 44 | +import java.util.ArrayList; |
| 45 | +import java.util.Collection; |
| 46 | +import java.util.List; |
| 47 | +import java.util.Map; |
| 48 | +import java.util.concurrent.CompletableFuture; |
| 49 | +import java.util.function.Supplier; |
| 50 | + |
| 51 | +import io.netty.handler.codec.http.FullHttpResponse; |
| 52 | +import io.netty.util.CharsetUtil; |
| 53 | +import io.netty.util.ReferenceCounted; |
| 54 | +import reactor.core.publisher.Flux; |
| 55 | +import reactor.core.publisher.Mono; |
| 56 | + |
| 57 | +import static org.opensearch.rest.RestRequest.Method.POST; |
| 58 | + |
| 59 | +/** |
| 60 | + * Integration tests for streaming REST channels with tracing support. |
| 61 | + * Tests thread context restoration and proper handling of streaming responses |
| 62 | + * with the tracing infrastructure enabled. |
| 63 | + */ |
| 64 | +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1) |
| 65 | +public class ReactorNetty4StreamingTracingIT extends OpenSearchReactorNetty4IntegTestCase { |
| 66 | + |
| 67 | + @Override |
| 68 | + protected boolean addMockHttpTransport() { |
| 69 | + return false; // enable http |
| 70 | + } |
| 71 | + |
| 72 | + @Override |
| 73 | + protected Collection<Class<? extends Plugin>> nodePlugins() { |
| 74 | + return List.of(ReactorNetty4Plugin.class, Netty4ModulePlugin.class, MockStreamingPlugin.class); |
| 75 | + } |
| 76 | + |
| 77 | + public static class MockStreamingPlugin extends Plugin implements ActionPlugin { |
| 78 | + @Override |
| 79 | + public List<RestHandler> getRestHandlers( |
| 80 | + Settings settings, |
| 81 | + RestController restController, |
| 82 | + ClusterSettings clusterSettings, |
| 83 | + IndexScopedSettings indexScopedSettings, |
| 84 | + SettingsFilter settingsFilter, |
| 85 | + IndexNameExpressionResolver indexNameExpressionResolver, |
| 86 | + Supplier<DiscoveryNodes> nodesInCluster |
| 87 | + ) { |
| 88 | + return List.of(new MockStreamingRestHandler()); |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + public static class MockStreamingRestHandler extends BaseRestHandler { |
| 93 | + @Override |
| 94 | + public String getName() { |
| 95 | + return "mock_streaming_tracing_handler"; |
| 96 | + } |
| 97 | + |
| 98 | + @Override |
| 99 | + public List<Route> routes() { |
| 100 | + return List.of(new Route(POST, "/test/_stream")); |
| 101 | + } |
| 102 | + |
| 103 | + @Override |
| 104 | + public boolean supportsStreaming() { |
| 105 | + return true; |
| 106 | + } |
| 107 | + |
| 108 | + @Override |
| 109 | + public boolean supportsContentStream() { |
| 110 | + return true; |
| 111 | + } |
| 112 | + |
| 113 | + @Override |
| 114 | + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { |
| 115 | + return channel -> { |
| 116 | + if (channel instanceof StreamingRestChannel) { |
| 117 | + StreamingRestChannel streamingChannel = (StreamingRestChannel) channel; |
| 118 | + |
| 119 | + Supplier<ThreadContext.StoredContext> supplier = client.threadPool().getThreadContext().newRestorableContext(true); |
| 120 | + |
| 121 | + Map<String, List<String>> headers = Map.of( |
| 122 | + "Content-Type", |
| 123 | + List.of("text/event-stream"), |
| 124 | + "Cache-Control", |
| 125 | + List.of("no-cache"), |
| 126 | + "Connection", |
| 127 | + List.of("keep-alive") |
| 128 | + ); |
| 129 | + streamingChannel.prepareResponse(RestStatus.OK, headers); |
| 130 | + |
| 131 | + Flux.from(streamingChannel).ofType(HttpChunk.class).collectList().flatMap(chunks -> { |
| 132 | + try (ThreadContext.StoredContext ignored = supplier.get()) { |
| 133 | + |
| 134 | + String opaqueId = request.header(Task.X_OPAQUE_ID); |
| 135 | + streamingChannel.sendChunk( |
| 136 | + createHttpChunk("data: {\"status\":\"streaming\",\"opaque_id\":\"" + opaqueId + "\"}\n\n", false) |
| 137 | + ); |
| 138 | + |
| 139 | + final CompletableFuture<HttpChunk> future = new CompletableFuture<>(); |
| 140 | + |
| 141 | + Flux.just( |
| 142 | + createHttpChunk("data: {\"content\":\"test chunk 1\"}\n\n", false), |
| 143 | + createHttpChunk("data: {\"content\":\"test chunk 2\"}\n\n", false), |
| 144 | + createHttpChunk("data: {\"content\":\"final chunk\",\"is_last\":true}\n\n", true) |
| 145 | + ) |
| 146 | + .delayElements(Duration.ofMillis(100)) |
| 147 | + .doOnNext(streamingChannel::sendChunk) |
| 148 | + .doOnComplete(() -> future.complete(createHttpChunk("", true))) |
| 149 | + .doOnError(future::completeExceptionally) |
| 150 | + .subscribe(); // Simulate streaming delay |
| 151 | + |
| 152 | + return Mono.fromCompletionStage(future); |
| 153 | + } catch (Exception e) { |
| 154 | + return Mono.error(e); |
| 155 | + } |
| 156 | + }).doOnNext(streamingChannel::sendChunk).onErrorResume(ex -> { |
| 157 | + try { |
| 158 | + HttpChunk errorChunk = createHttpChunk("data: {\"error\":\"" + ex.getMessage() + "\"}\n\n", true); |
| 159 | + streamingChannel.sendChunk(errorChunk); |
| 160 | + } catch (Exception e) { |
| 161 | + // Log error |
| 162 | + } |
| 163 | + return Mono.empty(); |
| 164 | + }).subscribe(); |
| 165 | + } |
| 166 | + }; |
| 167 | + } |
| 168 | + |
| 169 | + private HttpChunk createHttpChunk(String sseData, boolean isLast) { |
| 170 | + BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes(StandardCharsets.UTF_8))); |
| 171 | + return new HttpChunk() { |
| 172 | + @Override |
| 173 | + public void close() { |
| 174 | + if (bytesRef instanceof Releasable) { |
| 175 | + ((Releasable) bytesRef).close(); |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | + @Override |
| 180 | + public boolean isLast() { |
| 181 | + return isLast; |
| 182 | + } |
| 183 | + |
| 184 | + @Override |
| 185 | + public BytesReference content() { |
| 186 | + return bytesRef; |
| 187 | + } |
| 188 | + }; |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + public void testStreamingWithTraceEnabled() throws Exception { |
| 193 | + ensureGreen(); |
| 194 | + |
| 195 | + HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class); |
| 196 | + TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses(); |
| 197 | + TransportAddress transportAddress = randomFrom(boundAddresses); |
| 198 | + |
| 199 | + List<Tuple<String, CharSequence>> requests = new ArrayList<>(); |
| 200 | + requests.add(Tuple.tuple("/test/_stream", "dummy request body")); |
| 201 | + |
| 202 | + try (ReactorHttpClient nettyHttpClient = ReactorHttpClient.create(Settings.EMPTY)) { |
| 203 | + Collection<FullHttpResponse> singleResponse = nettyHttpClient.post(transportAddress.address(), requests); |
| 204 | + try { |
| 205 | + Assert.assertEquals(1, singleResponse.size()); |
| 206 | + FullHttpResponse response = singleResponse.iterator().next(); |
| 207 | + String responseBody = response.content().toString(CharsetUtil.UTF_8); |
| 208 | + Assert.assertEquals(200, response.status().code()); |
| 209 | + } finally { |
| 210 | + singleResponse.forEach(ReferenceCounted::release); |
| 211 | + } |
| 212 | + } |
| 213 | + } |
| 214 | +} |
0 commit comments