Skip to content

Commit e39a496

Browse files
Ensure that transient ThreadContext headers with propagators survive restore (#20854)
* Ensure that transient ThreadContext headers with propagators survive restore Signed-off-by: Deepti24 <chauhan.deepti24@gmail.com> Signed-off-by: Deepti Chauhan <dchauhan3@atlassian.com> Co-authored-by: Deepti Chauhan <dchauhan3@atlassian.com>
1 parent d0dddd9 commit e39a496

File tree

5 files changed

+100
-4
lines changed

5 files changed

+100
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
5858
- Delegate getMin/getMax methods for ExitableTerms ([#20775](https://github.com/opensearch-project/OpenSearch/pull/20775))
5959
- Fix terms lookup subquery fetch limit reading from non-existent index setting instead of cluster `max_clause_count` ([#20823](https://github.com/opensearch-project/OpenSearch/pull/20823))
6060
- Fix array_index_out_of_bounds_exception with wildcard and aggregations ([#20842](https://github.com/opensearch-project/OpenSearch/pull/20842))
61+
- Ensure that transient ThreadContext headers with propagators survive restore ([#169373](https://github.com/opensearch-project/OpenSearch/pull/20854))
6162
- Handle dependencies between analyzers ([#19248](https://github.com/opensearch-project/OpenSearch/pull/19248))
6263
- Fix `_field_caps` returning empty results and corrupted field names for `disable_objects: true` mappings ([#20800](https://github.com/opensearch-project/OpenSearch/pull/20800))
6364

server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,26 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders) {
252252
return newStoredContext(preserveResponseHeaders, Collections.emptyList());
253253
}
254254

255+
public StoredContext newStoredContext(boolean preserveResponseHeaders, boolean preserveTransients) {
256+
return newStoredContext(preserveResponseHeaders, preserveTransients, Collections.emptyList());
257+
}
258+
259+
public StoredContext newStoredContext(boolean preserveResponseHeaders, Collection<String> transientHeadersToClear) {
260+
return newStoredContext(preserveResponseHeaders, false, transientHeadersToClear);
261+
}
262+
255263
/**
256264
* Just like {@link #stashContext()} but no default context is set. Instead, the {@code transientHeadersToClear} argument can be used
257265
* to clear specific transient headers in the new context. All headers (with the possible exception of {@code responseHeaders}) are
258266
* restored by closing the returned {@link StoredContext}.
259267
*
260268
* @param preserveResponseHeaders if set to <code>true</code> the response headers of the restore thread will be preserved.
261269
*/
262-
public StoredContext newStoredContext(boolean preserveResponseHeaders, Collection<String> transientHeadersToClear) {
270+
public StoredContext newStoredContext(
271+
boolean preserveResponseHeaders,
272+
boolean preserveTransients,
273+
Collection<String> transientHeadersToClear
274+
) {
263275
final ThreadContextStruct originalContext = threadLocal.get();
264276
final Map<String, Object> newTransientHeaders = new HashMap<>(originalContext.transientHeaders);
265277

@@ -293,10 +305,24 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio
293305
final ThreadContextStruct newContext = threadLocal.get();
294306

295307
return () -> {
308+
// Re-apply propagator-declared transients from the current context back into the
309+
// snapshot being restored. This ensures that transients written after the snapshot
310+
// was taken using newStoredContext (e.g. CURRENT_SPAN set by the tracing infrastructure) are not silently
311+
// dropped when the security plugin (or any other caller) calls storedContext.restore().
312+
// Without this, restore() would blindly overwrite the threadLocal with the original
313+
// snapshot, losing any propagated transients that were set after newStoredContext() was called.
314+
ThreadContextStruct current = threadLocal.get();
315+
ThreadContextStruct restoredContext = originalContext;
316+
if (preserveTransients) {
317+
final Map<String, Object> propagated = propagateTransients(current.transientHeaders, current.isSystemContext);
318+
if (!propagated.isEmpty()) {
319+
restoredContext = originalContext.putTransientIfAbsent(propagated);
320+
}
321+
}
296322
if (preserveResponseHeaders && threadLocal.get() != newContext) {
297-
threadLocal.set(originalContext.putResponseHeaders(threadLocal.get().responseHeaders));
323+
threadLocal.set(restoredContext.putResponseHeaders(threadLocal.get().responseHeaders));
298324
} else {
299-
threadLocal.set(originalContext);
325+
threadLocal.set(restoredContext);
300326
}
301327
};
302328
}
@@ -864,6 +890,14 @@ private ThreadContextStruct putTransient(Map<String, Object> values) {
864890
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, persistentHeaders, isSystemContext);
865891
}
866892

893+
private ThreadContextStruct putTransientIfAbsent(Map<String, Object> values) {
894+
Map<String, Object> newTransient = new HashMap<>(this.transientHeaders);
895+
for (Map.Entry<String, Object> entry : values.entrySet()) {
896+
newTransient.putIfAbsent(entry.getKey(), entry.getValue());
897+
}
898+
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, persistentHeaders, isSystemContext);
899+
}
900+
867901
private ThreadContextStruct putTransient(String key, Object value) {
868902
Map<String, Object> newTransient = new HashMap<>(this.transientHeaders);
869903
putSingleHeader(key, value, newTransient);

server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public Map<String, Object> transients(Map<String, Object> source) {
5656
final Map<String, Object> transients = new HashMap<>();
5757
if (source.containsKey(CURRENT_SPAN)) {
5858
final SpanReference current = (SpanReference) source.get(CURRENT_SPAN);
59-
if (current != null) {
59+
if (current != null && current.getSpan() != null) {
6060
transients.put(CURRENT_SPAN, new SpanReference(current.getSpan()));
6161
}
6262
}

server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,4 +851,46 @@ protected void doRun() throws Exception {
851851
}
852852
};
853853
}
854+
855+
// We are simulating behavior that happens in Netty4HttpRequestHeaderVerifier
856+
// It take a snapshot of state and stores in CONTEXT_TO_RESTORE and
857+
// later tries to restore the same in SecurityFilter. Any transients added in between are lost
858+
public void testPropagatedTransientsAreRestored() {
859+
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
860+
final String PROPAGATED_KEY = "test_propagated_transient";
861+
final Object PROPAGATED_VALUE = new Object();
862+
863+
// Register a propagator that declares PROPAGATED_KEY as a transient to carry across stashes.
864+
threadContext.registerThreadContextStatePropagator(new ThreadContextStatePropagator() {
865+
@Override
866+
@SuppressWarnings("removal")
867+
public Map<String, Object> transients(Map<String, Object> source) {
868+
if (source.containsKey(PROPAGATED_KEY)) {
869+
return Collections.singletonMap(PROPAGATED_KEY, source.get(PROPAGATED_KEY));
870+
}
871+
return Collections.emptyMap();
872+
}
873+
874+
@Override
875+
@SuppressWarnings("removal")
876+
public Map<String, String> headers(Map<String, Object> source) {
877+
return Collections.emptyMap();
878+
}
879+
});
880+
881+
ThreadContext.StoredContext storedContext = null;
882+
try (ThreadContext.StoredContext sc = threadContext.newStoredContext(false, true)) {
883+
// now we add something to original thread
884+
// Simulate the tracing infrastructure writing CURRENT_SPAN into the stashed context.
885+
storedContext = sc;
886+
threadContext.putTransient(PROPAGATED_KEY, PROPAGATED_VALUE);
887+
} catch (Exception e) {
888+
// unlikey to get exception, if we got one, test should fail
889+
throw e;
890+
}
891+
// storedContext would have closed. Now we restore and after that, our original thread should have it
892+
storedContext.restore();
893+
// we should be able to find the key now
894+
assertEquals(threadContext.getTransient(PROPAGATED_KEY), PROPAGATED_VALUE);
895+
}
854896
}

server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import org.junit.After;
2323
import org.junit.Before;
2424

25+
import java.util.HashMap;
26+
import java.util.Map;
2527
import java.util.Optional;
2628
import java.util.Set;
2729
import java.util.concurrent.ExecutionException;
@@ -269,4 +271,21 @@ public void testSpanNotPropagatedToChildSystemThreadContext() {
269271
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
270272
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
271273
}
274+
275+
public void testNullSpanWithinSpanReference() {
276+
// invalid span, should not be present in final transients
277+
SpanReference spanReference = new SpanReference(null);
278+
Map<String, Object> source = new HashMap<>();
279+
source.put(ThreadContextBasedTracerContextStorage.CURRENT_SPAN, spanReference);
280+
ThreadContextBasedTracerContextStorage context = (ThreadContextBasedTracerContextStorage) threadContextStorage;
281+
assertTrue(context.transients(source).isEmpty());
282+
283+
// valid span, present in final transients
284+
final Span span = tracer.startSpan(SpanCreationContext.internal().name("test"));
285+
spanReference = new SpanReference(span);
286+
source = new HashMap<>();
287+
source.put(ThreadContextBasedTracerContextStorage.CURRENT_SPAN, spanReference);
288+
assertFalse(context.transients(source).isEmpty());
289+
assertEquals(span, ((SpanReference) context.transients(source).get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN)).getSpan());
290+
}
272291
}

0 commit comments

Comments
 (0)