Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Delegate getMin/getMax methods for ExitableTerms ([#20775](https://github.com/opensearch-project/OpenSearch/pull/20775))
- Fix terms lookup subquery fetch limit reading from non-existent index setting instead of cluster `max_clause_count` ([#20823](https://github.com/opensearch-project/OpenSearch/pull/20823))
- Fix array_index_out_of_bounds_exception with wildcard and aggregations ([#20842](https://github.com/opensearch-project/OpenSearch/pull/20842))
- Ensure that transient ThreadContext headers with propagators survive restore ([#169373](https://github.com/opensearch-project/OpenSearch/pull/20854))

### Dependencies
- Bump shadow-gradle-plugin from 8.3.9 to 9.3.1 ([#20569](https://github.com/opensearch-project/OpenSearch/pull/20569))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,31 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio
final ThreadContextStruct newContext = threadLocal.get();

return () -> {
// Re-apply propagator-declared transients from the current context back into the
// snapshot being restored. This ensures that transients written after the snapshot
// was taken using newStoredContext (e.g. CURRENT_SPAN set by the tracing infrastructure) are not silently
// dropped when the security plugin (or any other caller) calls storedContext.restore().
// Without this, restore() would blindly overwrite the threadLocal with the original
// snapshot, losing any propagated transients that were set after newStoredContext() was called.
ThreadContextStruct current = threadLocal.get();
ThreadContextStruct restoredContext = originalContext;
final Map<String, Object> propagated = propagateTransients(current.transientHeaders, current.isSystemContext);
if (!propagated.isEmpty()) {
Map<String, Object> merged = new HashMap<>(originalContext.transientHeaders);
propagated.forEach(merged::putIfAbsent);
restoredContext = new ThreadContextStruct(
restoredContext.requestHeaders,
restoredContext.responseHeaders,
merged,
restoredContext.persistentHeaders,
restoredContext.isSystemContext,
restoredContext.warningHeadersSize
);
}
if (preserveResponseHeaders && threadLocal.get() != newContext) {
threadLocal.set(originalContext.putResponseHeaders(threadLocal.get().responseHeaders));
threadLocal.set(restoredContext.putResponseHeaders(threadLocal.get().responseHeaders));
} else {
threadLocal.set(originalContext);
threadLocal.set(restoredContext);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public Map<String, Object> transients(Map<String, Object> source) {
final Map<String, Object> transients = new HashMap<>();
if (source.containsKey(CURRENT_SPAN)) {
final SpanReference current = (SpanReference) source.get(CURRENT_SPAN);
if (current != null) {
if (current != null && current.getSpan() != null) {
transients.put(CURRENT_SPAN, new SpanReference(current.getSpan()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,4 +851,46 @@ protected void doRun() throws Exception {
}
};
}

// We are simulating behavior that happens in Netty4HttpRequestHeaderVerifier
// It take a snapshot of state and stores in CONTEXT_TO_RESTORE and
// later tries to restore the same in SecurityFilter. Any transients added in between are lost
public void testPropagatedTransientsAreRestored() {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final String PROPAGATED_KEY = "test_propagated_transient";
final Object PROPAGATED_VALUE = new Object();

// Register a propagator that declares PROPAGATED_KEY as a transient to carry across stashes.
threadContext.registerThreadContextStatePropagator(new ThreadContextStatePropagator() {
@Override
@SuppressWarnings("removal")
public Map<String, Object> transients(Map<String, Object> source) {
if (source.containsKey(PROPAGATED_KEY)) {
return Collections.singletonMap(PROPAGATED_KEY, source.get(PROPAGATED_KEY));
}
return Collections.emptyMap();
}

@Override
@SuppressWarnings("removal")
public Map<String, String> headers(Map<String, Object> source) {
return Collections.emptyMap();
}
});

ThreadContext.StoredContext storedContext = null;
try (ThreadContext.StoredContext sc = threadContext.newStoredContext(false, Collections.emptyList())) {
// now we add something to original thread
// Simulate the tracing infrastructure writing CURRENT_SPAN into the stashed context.
storedContext = sc;
threadContext.putTransient(PROPAGATED_KEY, PROPAGATED_VALUE);
} catch (Exception e) {
// unlikey to get exception, if we got one, test should fail
throw e;
}
// storedContext would have closed. Now we restore and after that, our original thread should have it
storedContext.restore();
// we should be able to find the key now
assertEquals(threadContext.getTransient(PROPAGATED_KEY), PROPAGATED_VALUE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.junit.After;
import org.junit.Before;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -269,4 +271,12 @@ public void testSpanNotPropagatedToChildSystemThreadContext() {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}

public void testNullSpanWithinSpanReference() {
SpanReference spanReference = new SpanReference(null);
Map<String, Object> source = new HashMap<>();
source.put(ThreadContextBasedTracerContextStorage.CURRENT_SPAN, spanReference);
ThreadContextBasedTracerContextStorage context = (ThreadContextBasedTracerContextStorage) threadContextStorage;
assertTrue(context.transients(source).isEmpty());
}
}
Loading