diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index 4636edbfab1db..2ed347c226870 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -170,14 +170,20 @@ public Task register(String type, String action, TaskAwareRequest request, boole Task previousTask = tasks.put(task.getId(), task); assert previousTask == null; if (traceRequest) { - startTrace(threadContext, task); + maybeStartTrace(threadContext, task); } } return task; } - // package private for testing - void startTrace(ThreadContext threadContext, Task task) { + /** + * Start a new trace span if a parent trace context already exists. + * For REST actions this will be the case, otherwise {@link Tracer#startTrace} can be used. + */ + void maybeStartTrace(ThreadContext threadContext, Task task) { + if (threadContext.hasParentTraceContext() == false) { + return; + } TaskId parentTask = task.getParentTaskId(); Map attributes = parentTask.isSet() ? Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString()) @@ -185,6 +191,12 @@ void startTrace(ThreadContext threadContext, Task task) { tracer.startTrace(threadContext, task, task.getAction(), attributes); } + void maybeStopTrace(ThreadContext threadContext, Task task) { + if (threadContext.hasTraceContext()) { + tracer.stopTrace(task); + } + } + public Task registerAndExecute( String type, TransportAction action, @@ -247,7 +259,7 @@ private void registerCancellableTask(Task task, long requestId, boolean traceReq CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask); cancellableTasks.put(task, requestId, holder); if (traceRequest) { - startTrace(threadPool.getThreadContext(), task); + maybeStartTrace(threadPool.getThreadContext(), task); } // Check if this task was banned before we start it. if (task.getParentTaskId().isSet()) { @@ -346,7 +358,7 @@ public Task unregister(Task task) { return removedTask; } } finally { - tracer.stopTrace(task); + maybeStopTrace(threadPool.getThreadContext(), task); for (RemovedTaskListener listener : removedTaskListeners) { listener.onRemoved(task); } diff --git a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java index 32999d87534dd..86436a0852f58 100644 --- a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java @@ -68,6 +68,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; public class TaskManagerTests extends ESTestCase { @@ -281,13 +282,72 @@ public void testTaskAccounting() { /** * Check that registering a task also causes tracing to be started on that task. */ - public void testRegisterTaskStartsTracing() { + public void testRegisterTaskStartsTracingIfTraceParentExists() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + // fake a trace parent + threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); final boolean hasParentTask = randomBoolean(); final TaskId parentTask = hasParentTask ? new TaskId("parentNode", 1) : TaskId.EMPTY_TASK_ID; + try (var ignored = threadPool.getThreadContext().newTraceContext()) { + + final Task task = taskManager.register("testType", "testAction", new TaskAwareRequest() { + + @Override + public void setParentTask(TaskId taskId) {} + + @Override + public void setRequestId(long requestId) {} + + @Override + public TaskId getParentTask() { + return parentTask; + } + }); + + Map attributes = hasParentTask + ? Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString()) + : Map.of(Tracer.AttributeKeys.TASK_ID, task.getId()); + verify(mockTracer).startTrace(any(), eq(task), eq("testAction"), eq(attributes)); + } + } + + /** + * Check that registering a task also causes tracing to be started on that task. + */ + public void testRegisterTaskSkipsTracingIfTraceParentMissing() { + final Tracer mockTracer = mock(Tracer.class); + final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + + // no trace parent + try (var ignored = threadPool.getThreadContext().newTraceContext()) { + final Task task = taskManager.register("testType", "testAction", new TaskAwareRequest() { + + @Override + public void setParentTask(TaskId taskId) {} + + @Override + public void setRequestId(long requestId) {} + + @Override + public TaskId getParentTask() { + return TaskId.EMPTY_TASK_ID; + } + }); + } + + verifyNoInteractions(mockTracer); + } + + /** + * Check that unregistering a task also causes tracing to be stopped on that task. + */ + public void testUnregisterTaskStopsTracingIfTraceContextExists() { + final Tracer mockTracer = mock(Tracer.class); + final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + final Task task = taskManager.register("testType", "testAction", new TaskAwareRequest() { @Override @@ -298,20 +358,21 @@ public void setRequestId(long requestId) {} @Override public TaskId getParentTask() { - return parentTask; + return TaskId.EMPTY_TASK_ID; } }); - Map attributes = hasParentTask - ? Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString()) - : Map.of(Tracer.AttributeKeys.TASK_ID, task.getId()); - verify(mockTracer).startTrace(any(), eq(task), eq("testAction"), eq(attributes)); + // fake a trace context (trace parent) + threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); + + taskManager.unregister(task); + verify(mockTracer).stopTrace(task); } /** * Check that unregistering a task also causes tracing to be stopped on that task. */ - public void testUnregisterTaskStopsTracing() { + public void testUnregisterTaskStopsTracingIfTraceContextMissing() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); @@ -329,18 +390,22 @@ public TaskId getParentTask() { } }); - taskManager.unregister(task); + // no trace context - verify(mockTracer).stopTrace(task); + taskManager.unregister(task); + verifyNoInteractions(mockTracer); } /** - * Check that registering and executing a task also causes tracing to be started and stopped on that task. + * Check that registering and executing a task also causes tracing to be started if a trace parent exists. */ - public void testRegisterAndExecuteStartsAndStopsTracing() { + public void testRegisterAndExecuteStartsTracingIfTraceParentExists() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + // fake a trace parent + threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); + final Task task = taskManager.registerAndExecute( "testType", new TransportAction( @@ -375,25 +440,68 @@ public TaskId getParentTask() { verify(mockTracer).startTrace(any(), eq(task), eq("actionName"), anyMap()); } + /** + * Check that registering and executing a task skips tracing if trace parent does not exists. + */ + public void testRegisterAndExecuteSkipsTracingIfTraceParentMissing() { + final Tracer mockTracer = mock(Tracer.class); + final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); + + // clean thread context without trace parent + + final Task task = taskManager.registerAndExecute( + "testType", + new TransportAction( + "actionName", + new ActionFilters(Set.of()), + taskManager, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ) { + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + listener.onResponse(new ActionResponse() { + @Override + public void writeTo(StreamOutput out) {} + }); + } + }, + new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public TaskId getParentTask() { + return TaskId.EMPTY_TASK_ID; + } + }, + null, + ActionTestUtils.assertNoFailureListener(r -> {}) + ); + + verifyNoInteractions(mockTracer); + } + public void testRegisterWithEnabledDisabledTracing() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = spy(new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer)); taskManager.register("type", "action", makeTaskRequest(true, 123), false); - verify(taskManager, times(0)).startTrace(any(), any()); + verify(taskManager, times(0)).maybeStartTrace(any(), any()); taskManager.register("type", "action", makeTaskRequest(false, 234), false); - verify(taskManager, times(0)).startTrace(any(), any()); + verify(taskManager, times(0)).maybeStartTrace(any(), any()); clearInvocations(taskManager); taskManager.register("type", "action", makeTaskRequest(true, 345), true); - verify(taskManager, times(1)).startTrace(any(), any()); + verify(taskManager, times(1)).maybeStartTrace(any(), any()); clearInvocations(taskManager); taskManager.register("type", "action", makeTaskRequest(false, 456), true); - verify(taskManager, times(1)).startTrace(any(), any()); + verify(taskManager, times(1)).maybeStartTrace(any(), any()); } static class CancellableRequest extends AbstractTransportRequest {