Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,10 @@ private Context getParentContext(TraceContext traceContext) {
// You can just pass the Context object directly to another thread (it is immutable and thus thread-safe).

// Attempt to fetch a local parent context first, otherwise look for a remote parent
Context parentContext = traceContext.getTransient("parent_" + Task.APM_TRACE_CONTEXT);
Context parentContext = traceContext.getTransient(Task.PARENT_APM_TRACE_CONTEXT);
if (parentContext == null) {
final String traceParentHeader = traceContext.getTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER);
final String traceStateHeader = traceContext.getTransient("parent_" + Task.TRACE_STATE);
final String traceParentHeader = traceContext.getTransient(Task.PARENT_TRACE_PARENT_HEADER);
final String traceStateHeader = traceContext.getTransient(Task.PARENT_TRACE_STATE);

if (traceParentHeader != null) {
final Map<String, String> traceContextMap = Maps.newMapWithExpectedSize(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,10 @@ public void copyRequestHeadersToThreadContext(HttpPreRequest request, ThreadCont
Optional<String> traceId = RestUtils.extractTraceId(traceparent);
if (traceId.isPresent()) {
threadContext.putHeader(Task.TRACE_ID, traceId.get());
threadContext.putTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER, traceparent);
threadContext.putTransient(Task.PARENT_TRACE_PARENT_HEADER, traceparent);
}
} else if (name.equals(Task.TRACE_STATE)) {
threadContext.putTransient("parent_" + Task.TRACE_STATE, distinctHeaderValues.get(0));
threadContext.putTransient(Task.PARENT_TRACE_STATE, distinctHeaderValues.get(0));
} else {
threadContext.putHeader(name, String.join(",", distinctHeaderValues));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,33 +193,38 @@ public StoredContext newEmptySystemContext() {
*/
public StoredContext newTraceContext() {
final ThreadContextStruct originalContext = threadLocal.get();
final Map<String, String> newRequestHeaders = new HashMap<>(originalContext.requestHeaders);
final Map<String, Object> newTransientHeaders = new HashMap<>(originalContext.transientHeaders);
// this is the context when this method returns
final ThreadContextStruct newContext;
if (originalContext.hasTraceContext() == false) {
newContext = originalContext;
} else {
final Map<String, String> newRequestHeaders = new HashMap<>(originalContext.requestHeaders);
final Map<String, Object> newTransientHeaders = new HashMap<>(originalContext.transientHeaders);

final String previousTraceParent = newRequestHeaders.remove(Task.TRACE_PARENT_HTTP_HEADER);
if (previousTraceParent != null) {
newTransientHeaders.put("parent_" + Task.TRACE_PARENT_HTTP_HEADER, previousTraceParent);
}
final String previousTraceParent = newRequestHeaders.remove(Task.TRACE_PARENT_HTTP_HEADER);
if (previousTraceParent != null) {
newTransientHeaders.put(Task.PARENT_TRACE_PARENT_HEADER, previousTraceParent);
}

final String previousTraceState = newRequestHeaders.remove(Task.TRACE_STATE);
if (previousTraceState != null) {
newTransientHeaders.put("parent_" + Task.TRACE_STATE, previousTraceState);
}
final String previousTraceState = newRequestHeaders.remove(Task.TRACE_STATE);
if (previousTraceState != null) {
newTransientHeaders.put(Task.PARENT_TRACE_STATE, previousTraceState);
}

final Object previousTraceContext = newTransientHeaders.remove(Task.APM_TRACE_CONTEXT);
if (previousTraceContext != null) {
newTransientHeaders.put("parent_" + Task.APM_TRACE_CONTEXT, previousTraceContext);
}
final Object previousTraceContext = newTransientHeaders.remove(Task.APM_TRACE_CONTEXT);
if (previousTraceContext != null) {
newTransientHeaders.put(Task.PARENT_APM_TRACE_CONTEXT, previousTraceContext);
}

// this is the context when this method returns
final ThreadContextStruct newContext = new ThreadContextStruct(
newRequestHeaders,
originalContext.responseHeaders,
newTransientHeaders,
originalContext.isSystemContext,
originalContext.warningHeadersSize
);
threadLocal.set(newContext);
newContext = new ThreadContextStruct(
newRequestHeaders,
originalContext.responseHeaders,
newTransientHeaders,
originalContext.isSystemContext,
originalContext.warningHeadersSize
);
threadLocal.set(newContext);
}
// Tracing shouldn't interrupt the propagation of response headers, so in the same as
// #newStoredContextPreservingResponseHeaders(), pass on any potential changes to the response headers.
return () -> {
Expand All @@ -233,10 +238,11 @@ public StoredContext newTraceContext() {
}

public boolean hasTraceContext() {
final ThreadContextStruct context = threadLocal.get();
return context.requestHeaders.containsKey(Task.TRACE_PARENT_HTTP_HEADER)
|| context.requestHeaders.containsKey(Task.TRACE_STATE)
|| context.transientHeaders.containsKey(Task.APM_TRACE_CONTEXT);
return threadLocal.get().hasTraceContext();
}

public boolean hasParentTraceContext() {
return threadLocal.get().hasParentTraceContext();
}

/**
Expand All @@ -254,10 +260,10 @@ public StoredContext clearTraceContext() {
newRequestHeaders.remove(Task.TRACE_PARENT_HTTP_HEADER);
newRequestHeaders.remove(Task.TRACE_STATE);

newTransientHeaders.remove("parent_" + Task.TRACE_PARENT_HTTP_HEADER);
newTransientHeaders.remove("parent_" + Task.TRACE_STATE);
newTransientHeaders.remove(Task.PARENT_TRACE_PARENT_HEADER);
newTransientHeaders.remove(Task.PARENT_TRACE_STATE);
newTransientHeaders.remove(Task.APM_TRACE_CONTEXT);
newTransientHeaders.remove("parent_" + Task.APM_TRACE_CONTEXT);
newTransientHeaders.remove(Task.PARENT_APM_TRACE_CONTEXT);

threadLocal.set(
new ThreadContextStruct(
Expand Down Expand Up @@ -853,6 +859,18 @@ private ThreadContextStruct putResponseHeaders(Map<String, Set<String>> headers)
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext);
}

private boolean hasTraceContext() {
return requestHeaders.containsKey(Task.TRACE_PARENT_HTTP_HEADER)
|| requestHeaders.containsKey(Task.TRACE_STATE)
|| transientHeaders.containsKey(Task.APM_TRACE_CONTEXT);
}

private boolean hasParentTraceContext() {
return transientHeaders.containsKey(Task.PARENT_TRACE_PARENT_HEADER)
|| transientHeaders.containsKey(Task.PARENT_TRACE_STATE)
|| transientHeaders.containsKey(Task.PARENT_APM_TRACE_CONTEXT);
}

private void logWarningHeaderThresholdExceeded(long threshold, Setting<?> thresholdSetting) {
// If available, log some selected headers to help identifying the source of the request.
// Note: Only Task.HEADERS_TO_COPY are guaranteed to be preserved at this point.
Expand Down
30 changes: 18 additions & 12 deletions server/src/main/java/org/elasticsearch/tasks/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ public class Task implements Traceable {
*/
public static final String X_OPAQUE_ID_HTTP_HEADER = "X-Opaque-Id";

/**
* A request header that indicates the origin of the request from Elastic stack. The value will stored in ThreadContext
* and emitted to ES logs
*/
public static final String X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER = "X-elastic-product-origin";

public static final String X_ELASTIC_PROJECT_ID_HTTP_HEADER = "X-Elastic-Project-Id";

/**
* The request header which is contained in HTTP request. We parse trace.id from it and store it in thread context.
* TRACE_PARENT once parsed in RestController.tryAllHandler is not preserved
Expand All @@ -39,28 +47,26 @@ public class Task implements Traceable {
*/
public static final String TRACE_PARENT_HTTP_HEADER = "traceparent";

public static final String TRACE_STATE = "tracestate";

/**
* A request header that indicates the origin of the request from Elastic stack. The value will stored in ThreadContext
* and emitted to ES logs
* Parsed part of traceparent. It is stored in thread context and emitted in logs.
* Has to be declared as a header copied over for tasks.
*/
public static final String X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER = "X-elastic-product-origin";
public static final String TRACE_ID = "trace.id";

public static final String TRACE_STATE = "tracestate";
public static final String TRACE_START_TIME = "trace.starttime";

/**
* Used internally to pass the apm trace context between the nodes
*/
public static final String APM_TRACE_CONTEXT = "apm.local.context";

/**
* Parsed part of traceparent. It is stored in thread context and emitted in logs.
* Has to be declared as a header copied over for tasks.
*/
public static final String TRACE_ID = "trace.id";
public static final String PARENT_TRACE_PARENT_HEADER = "parent_" + Task.TRACE_PARENT_HTTP_HEADER;

public static final String TRACE_START_TIME = "trace.starttime";
public static final String TRACE_PARENT = "traceparent";
public static final String X_ELASTIC_PROJECT_ID_HTTP_HEADER = "X-Elastic-Project-Id";
public static final String PARENT_TRACE_STATE = "parent_" + Task.TRACE_STATE;

public static final String PARENT_APM_TRACE_CONTEXT = "parent_" + Task.APM_TRACE_CONTEXT;

public static final Set<String> HEADERS_TO_COPY = Set.of(
X_OPAQUE_ID_HTTP_HEADER,
Expand Down
28 changes: 18 additions & 10 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -161,24 +161,32 @@ public Task register(String type, String action, TaskAwareRequest request, boole
Task previousTask = tasks.put(task.getId(), task);
assert previousTask == null;
if (traceRequest) {
startTrace(threadContext, task);
maybeStartTrace(threadContext, task);
}
}
return task;
}

// Start a new trace span if a parent trace context already exists.
// For REST actions this will be the case, otherwise Tracer#startTrace can be used.
// package private for testing
void startTrace(ThreadContext threadContext, Task task) {
void maybeStartTrace(ThreadContext threadContext, Task task) {
if (threadContext.hasParentTraceContext() == false) {
return;
}
TaskId parentTask = task.getParentTaskId();
Map<String, Object> attributes = Map.of(
Tracer.AttributeKeys.TASK_ID,
task.getId(),
Tracer.AttributeKeys.PARENT_TASK_ID,
parentTask.toString()
);
Map<String, Object> attributes = parentTask.isSet()
? Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString())
: Map.of(Tracer.AttributeKeys.TASK_ID, task.getId());
tracer.startTrace(threadContext, task, task.getAction(), attributes);
}

void maybeStopTrace(ThreadContext threadContext, Task task) {
if (threadContext.hasTraceContext()) {
tracer.stopTrace(task);
}
}

public <Request extends ActionRequest, Response extends ActionResponse> Task registerAndExecute(
String type,
TransportAction<Request, Response> action,
Expand Down Expand Up @@ -241,7 +249,7 @@ private void registerCancellableTask(Task task, long requestId, boolean traceReq
CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask);
cancellableTasks.put(task, requestId, holder);
if (traceRequest) {
startTrace(threadPool.getThreadContext(), task);
maybeStartTrace(threadPool.getThreadContext(), task);
}
// Check if this task was banned before we start it.
if (task.getParentTaskId().isSet()) {
Expand Down Expand Up @@ -340,7 +348,7 @@ public Task unregister(Task task) {
return removedTask;
}
} finally {
tracer.stopTrace(task);
maybeStopTrace(threadPool.getThreadContext(), task);
for (RemovedTaskListener listener : removedTaskListeners) {
listener.onRemoved(task);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public enum Transports {
;
private static final Set<String> REQUEST_HEADERS_ALLOWED_ON_DEFAULT_THREAD_CONTEXT = Set.of(
Task.TRACE_ID,
Task.TRACE_PARENT,
Task.TRACE_PARENT_HTTP_HEADER,
Task.X_OPAQUE_ID_HTTP_HEADER,
Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER,
Task.X_ELASTIC_PROJECT_ID_HTTP_HEADER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@

import static com.carrotsearch.randomizedtesting.RandomizedTest.randomAsciiLettersOfLengthBetween;
import static org.elasticsearch.tasks.Task.HEADERS_TO_COPY;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
Expand Down Expand Up @@ -1161,6 +1163,75 @@ public void testNewEmptySystemContext() {
assertNotNull(threadContext.getHeader(header));
}

public void testNewTraceContext() {
final var threadContext = new ThreadContext(Settings.EMPTY);

var rootTraceContext = Map.of(Task.TRACE_PARENT_HTTP_HEADER, randomIdentifier(), Task.TRACE_STATE, randomIdentifier());
var apmTraceContext = new Object();
var responseKey = randomIdentifier();
var responseValue = randomAlphaOfLength(10);

threadContext.putHeader(rootTraceContext);
threadContext.putTransient(Task.APM_TRACE_CONTEXT, apmTraceContext);

assertThat(threadContext.hasTraceContext(), equalTo(true));
assertThat(threadContext.hasParentTraceContext(), equalTo(false));

try (var ignored = threadContext.newTraceContext()) {
assertThat(threadContext.hasTraceContext(), equalTo(false)); // no trace started yet
assertThat(threadContext.hasParentTraceContext(), equalTo(true));

assertThat(threadContext.getHeaders(), is(anEmptyMap()));
assertThat(
threadContext.getTransientHeaders(),
equalTo(
Map.of(
Task.PARENT_TRACE_PARENT_HEADER,
rootTraceContext.get(Task.TRACE_PARENT_HTTP_HEADER),
Task.PARENT_TRACE_STATE,
rootTraceContext.get(Task.TRACE_STATE),
Task.PARENT_APM_TRACE_CONTEXT,
apmTraceContext
)
)
);
// response headers shall be propagated
threadContext.addResponseHeader(responseKey, responseValue);
}

assertThat(threadContext.hasTraceContext(), equalTo(true));
assertThat(threadContext.hasParentTraceContext(), equalTo(false));

assertThat(threadContext.getHeaders(), equalTo(rootTraceContext));
assertThat(threadContext.getTransientHeaders(), equalTo(Map.of(Task.APM_TRACE_CONTEXT, apmTraceContext)));
assertThat(threadContext.getResponseHeaders(), equalTo(Map.of(responseKey, List.of(responseValue))));
}

public void testNewTraceContextWithoutParentTrace() {
final var threadContext = new ThreadContext(Settings.EMPTY);

var responseKey = randomIdentifier();
var responseValue = randomAlphaOfLength(10);

assertThat(threadContext.hasTraceContext(), equalTo(false));
assertThat(threadContext.hasParentTraceContext(), equalTo(false));

try (var ignored = threadContext.newTraceContext()) {
assertTrue(threadContext.isDefaultContext());
assertThat(threadContext.hasTraceContext(), equalTo(false));
assertThat(threadContext.hasParentTraceContext(), equalTo(false));

// discared, just making sure the context is isolated
threadContext.putTransient(randomIdentifier(), randomAlphaOfLength(10));
// response headers shall be propagated
threadContext.addResponseHeader(responseKey, responseValue);
}

assertThat(threadContext.getHeaders(), is(anEmptyMap()));
assertThat(threadContext.getTransientHeaders(), is(anEmptyMap()));
assertThat(threadContext.getResponseHeaders(), equalTo(Map.of(responseKey, List.of(responseValue))));
}

public void testRestoreExistingContext() {
final var threadContext = new ThreadContext(Settings.EMPTY);
final var header = randomIdentifier();
Expand Down
Loading