Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,33 @@ public Task register(String type, String action, TaskAwareRequest request, boole
Task previousTask = tasks.put(task.getId(), task);
assert previousTask == null;
if (traceRequest) {
startTrace(threadContext, task);
maybeStartTrace(threadContext, task);
}
}
return task;
}

// package private for testing
void startTrace(ThreadContext threadContext, Task task) {
/**
* Start a new trace span if a parent trace context already exists.
* For REST actions this will be the case, otherwise {@link Tracer#startTrace} can be used.
*/
void maybeStartTrace(ThreadContext threadContext, Task task) {
if (threadContext.hasParentTraceContext() == false) {
return;
}
TaskId parentTask = task.getParentTaskId();
Map<String, Object> attributes = parentTask.isSet()
? Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString())
: Map.of(Tracer.AttributeKeys.TASK_ID, task.getId());
tracer.startTrace(threadContext, task, task.getAction(), attributes);
}

void maybeStopTrace(ThreadContext threadContext, Task task) {
if (threadContext.hasTraceContext()) {
tracer.stopTrace(task);
}
}

public <Request extends ActionRequest, Response extends ActionResponse> Task registerAndExecute(
String type,
TransportAction<Request, Response> action,
Expand Down Expand Up @@ -247,7 +259,7 @@ private void registerCancellableTask(Task task, long requestId, boolean traceReq
CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask);
cancellableTasks.put(task, requestId, holder);
if (traceRequest) {
startTrace(threadPool.getThreadContext(), task);
maybeStartTrace(threadPool.getThreadContext(), task);
}
// Check if this task was banned before we start it.
if (task.getParentTaskId().isSet()) {
Expand Down Expand Up @@ -346,7 +358,7 @@ public Task unregister(Task task) {
return removedTask;
}
} finally {
tracer.stopTrace(task);
maybeStopTrace(threadPool.getThreadContext(), task);
for (RemovedTaskListener listener : removedTaskListeners) {
listener.onRemoved(task);
}
Expand Down
138 changes: 123 additions & 15 deletions server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

public class TaskManagerTests extends ESTestCase {
Expand Down Expand Up @@ -281,13 +282,72 @@ public void testTaskAccounting() {
/**
* Check that registering a task also causes tracing to be started on that task.
*/
public void testRegisterTaskStartsTracing() {
public void testRegisterTaskStartsTracingIfTraceParentExists() {
final Tracer mockTracer = mock(Tracer.class);
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer);

// fake a trace parent
threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent");
final boolean hasParentTask = randomBoolean();
final TaskId parentTask = hasParentTask ? new TaskId("parentNode", 1) : TaskId.EMPTY_TASK_ID;

try (var ignored = threadPool.getThreadContext().newTraceContext()) {

final Task task = taskManager.register("testType", "testAction", new TaskAwareRequest() {

@Override
public void setParentTask(TaskId taskId) {}

@Override
public void setRequestId(long requestId) {}

@Override
public TaskId getParentTask() {
return parentTask;
}
});

Map<String, Object> attributes = hasParentTask
? Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString())
: Map.of(Tracer.AttributeKeys.TASK_ID, task.getId());
verify(mockTracer).startTrace(any(), eq(task), eq("testAction"), eq(attributes));
}
}

/**
* Check that registering a task also causes tracing to be started on that task.
*/
public void testRegisterTaskSkipsTracingIfTraceParentMissing() {
final Tracer mockTracer = mock(Tracer.class);
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer);

// no trace parent
try (var ignored = threadPool.getThreadContext().newTraceContext()) {
final Task task = taskManager.register("testType", "testAction", new TaskAwareRequest() {

@Override
public void setParentTask(TaskId taskId) {}

@Override
public void setRequestId(long requestId) {}

@Override
public TaskId getParentTask() {
return TaskId.EMPTY_TASK_ID;
}
});
}

verifyNoInteractions(mockTracer);
}

/**
* Check that unregistering a task also causes tracing to be stopped on that task.
*/
public void testUnregisterTaskStopsTracingIfTraceContextExists() {
final Tracer mockTracer = mock(Tracer.class);
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer);

final Task task = taskManager.register("testType", "testAction", new TaskAwareRequest() {

@Override
Expand All @@ -298,20 +358,21 @@ public void setRequestId(long requestId) {}

@Override
public TaskId getParentTask() {
return parentTask;
return TaskId.EMPTY_TASK_ID;
}
});

Map<String, Object> attributes = hasParentTask
? Map.of(Tracer.AttributeKeys.TASK_ID, task.getId(), Tracer.AttributeKeys.PARENT_TASK_ID, parentTask.toString())
: Map.of(Tracer.AttributeKeys.TASK_ID, task.getId());
verify(mockTracer).startTrace(any(), eq(task), eq("testAction"), eq(attributes));
// fake a trace context (trace parent)
threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent");

taskManager.unregister(task);
verify(mockTracer).stopTrace(task);
}

/**
* Check that unregistering a task also causes tracing to be stopped on that task.
*/
public void testUnregisterTaskStopsTracing() {
public void testUnregisterTaskStopsTracingIfTraceContextMissing() {
final Tracer mockTracer = mock(Tracer.class);
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer);

Expand All @@ -329,18 +390,22 @@ public TaskId getParentTask() {
}
});

taskManager.unregister(task);
// no trace context

verify(mockTracer).stopTrace(task);
taskManager.unregister(task);
verifyNoInteractions(mockTracer);
}

/**
* Check that registering and executing a task also causes tracing to be started and stopped on that task.
* Check that registering and executing a task also causes tracing to be started if a trace parent exists.
*/
public void testRegisterAndExecuteStartsAndStopsTracing() {
public void testRegisterAndExecuteStartsTracingIfTraceParentExists() {
final Tracer mockTracer = mock(Tracer.class);
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer);

// fake a trace parent
threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent");

final Task task = taskManager.registerAndExecute(
"testType",
new TransportAction<ActionRequest, ActionResponse>(
Expand Down Expand Up @@ -375,25 +440,68 @@ public TaskId getParentTask() {
verify(mockTracer).startTrace(any(), eq(task), eq("actionName"), anyMap());
}

/**
* Check that registering and executing a task skips tracing if trace parent does not exists.
*/
public void testRegisterAndExecuteSkipsTracingIfTraceParentMissing() {
final Tracer mockTracer = mock(Tracer.class);
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer);

// clean thread context without trace parent

final Task task = taskManager.registerAndExecute(
"testType",
new TransportAction<ActionRequest, ActionResponse>(
"actionName",
new ActionFilters(Set.of()),
taskManager,
EsExecutors.DIRECT_EXECUTOR_SERVICE
) {
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<ActionResponse> listener) {
listener.onResponse(new ActionResponse() {
@Override
public void writeTo(StreamOutput out) {}
});
}
},
new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public TaskId getParentTask() {
return TaskId.EMPTY_TASK_ID;
}
},
null,
ActionTestUtils.assertNoFailureListener(r -> {})
);

verifyNoInteractions(mockTracer);
}

public void testRegisterWithEnabledDisabledTracing() {
final Tracer mockTracer = mock(Tracer.class);
final TaskManager taskManager = spy(new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer));

taskManager.register("type", "action", makeTaskRequest(true, 123), false);
verify(taskManager, times(0)).startTrace(any(), any());
verify(taskManager, times(0)).maybeStartTrace(any(), any());

taskManager.register("type", "action", makeTaskRequest(false, 234), false);
verify(taskManager, times(0)).startTrace(any(), any());
verify(taskManager, times(0)).maybeStartTrace(any(), any());

clearInvocations(taskManager);

taskManager.register("type", "action", makeTaskRequest(true, 345), true);
verify(taskManager, times(1)).startTrace(any(), any());
verify(taskManager, times(1)).maybeStartTrace(any(), any());

clearInvocations(taskManager);

taskManager.register("type", "action", makeTaskRequest(false, 456), true);
verify(taskManager, times(1)).startTrace(any(), any());
verify(taskManager, times(1)).maybeStartTrace(any(), any());
}

static class CancellableRequest extends AbstractTransportRequest {
Expand Down