Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/125520.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 125520
summary: Add `FailedShardEntry` info to shard-failed task source string
area: Allocation
type: enhancement
issues:
- 102606
Original file line number Diff line number Diff line change
Expand Up @@ -276,23 +276,21 @@ public void onTimeout(TimeValue timeout) {
}

// TODO: Make this a TransportMasterNodeAction and remove duplication of master failover retrying from upstream code
private static class ShardFailedTransportHandler implements TransportRequestHandler<FailedShardEntry> {
static class ShardFailedTransportHandler implements TransportRequestHandler<FailedShardEntry> {
private final MasterServiceTaskQueue<FailedShardUpdateTask> taskQueue;

ShardFailedTransportHandler(
ClusterService clusterService,
ShardFailedClusterStateTaskExecutor shardFailedClusterStateTaskExecutor
) {
taskQueue = clusterService.createTaskQueue(TASK_SOURCE, Priority.HIGH, shardFailedClusterStateTaskExecutor);
taskQueue = clusterService.createTaskQueue("shard-failed", Priority.HIGH, shardFailedClusterStateTaskExecutor);
}

private static final String TASK_SOURCE = "shard-failed";

@Override
public void messageReceived(FailedShardEntry request, TransportChannel channel, Task task) {
logger.debug(() -> format("%s received shard failed for [%s]", request.getShardId(), request), request.failure);
taskQueue.submitTask(
TASK_SOURCE,
"shard-failed " + request.toStringNoFailureStackTrace(),
new FailedShardUpdateTask(request, new ChannelActionListener<>(channel).map(ignored -> TransportResponse.Empty.INSTANCE)),
null
);
Expand Down Expand Up @@ -501,14 +499,22 @@ public void writeTo(StreamOutput out) throws IOException {

/**
 * Returns the description of this entry with the failure, if any, rendered as a full stack trace.
 */
@Override
public String toString() {
return toString(true);
}

/**
 * Returns the same description as {@link #toString()}, except that the failure is reduced to its message
 * rather than a full stack trace.
 */
public String toStringNoFailureStackTrace() {
return toString(false);
}

/**
 * Formats the shard id, allocation id, primary term, message, stale flag and failure of this entry into a
 * description string.
 *
 * @param includeStackTrace when {@code true} a non-null failure is rendered via
 *                          {@code ExceptionsHelper.stackTrace}; when {@code false} only
 *                          {@code failure.getMessage()} is used. A {@code null} failure is rendered as
 *                          {@code null} in both modes.
 * @return the formatted description
 */
// NOTE(review): Throwable#getMessage() may itself be null and drops the exception type, so the short form can
// print "failure [null]" for a real failure — confirm that is acceptable for the task source string.
private String toString(boolean includeStackTrace) {
return Strings.format(
"FailedShardEntry{shardId [%s], allocationId [%s], primary term [%d], message [%s], markAsStale [%b], failure [%s]}",
shardId,
allocationId,
primaryTerm,
message,
markAsStale,
failure != null ? ExceptionsHelper.stackTrace(failure) : null
failure == null ? null : (includeStackTrace ? ExceptionsHelper.stackTrace(failure) : failure.getMessage())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@

import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.replication.ClusterStateCreationUtils;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateObserver;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterStateTaskListener;
import org.elasticsearch.cluster.NotMasterException;
import org.elasticsearch.cluster.action.shard.ShardStateAction.FailedShardEntry;
import org.elasticsearch.cluster.action.shard.ShardStateAction.StartedShardEntry;
Expand All @@ -28,19 +31,27 @@
import org.elasticsearch.cluster.routing.ShardsIterator;
import org.elasticsearch.cluster.routing.allocation.AllocationService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.service.MasterServiceTaskQueue;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.shard.ShardLongFieldRange;
import org.elasticsearch.index.shard.ShardLongFieldRangeWireTests;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.NodeNotConnectedException;
import org.elasticsearch.transport.TestTransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse;
Expand All @@ -51,7 +62,11 @@
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Phaser;
Expand Down Expand Up @@ -628,6 +643,77 @@ public void testStartedShardEntrySerializationWithOlderTransportVersion() throws
}
}

public void testShardFailedTransportHandlerSubmitTaskSourceStringIncludesRequestInfo() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer an integration test for this change which will more directly get the information from MasterService instead of mocking the queue and bypassing MasterService. Technically, what a user sees is the source field of PendingClusterTask and that is what we want to fix. Currently it is indeed copied all the way from the source argument of submitTask method. But I'd rather not make that implementation assumption in the test.

Concretely, I think we can add a test to ShardStateIT that does the following:

  1. Create an index and find its associated node and IndicesService similar to this. For simplicity, the index can have just 1 shard and no replica.
  2. Create a blocking task queue on the masterService and submit a task to ensure it is blocked similar to this
  3. Fail the shard similar to this
  4. While the MasterService is blocked, assert that it receives a new pending task for shard failure and check its source string, e.g. something like assertThat(clusterService.getMasterService().pendingTasks().stream().anyMatch(t -> t.getSource()...) wrapped in an assertBusy.
  5. Unblock MasterService
  6. Wait for the index to recover and finish the test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ywangd, I've switched to an integration test in 0f7b047 per your outline.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I simplified the test per our call earlier, commit c3698fa. I'll use a separate branch to investigate the possible race condition in the version that attempts to block and wait for the shard-started task.

// Create a modified ClusterService that returns task capturing task queues.
final var taskQueueMap = new HashMap<String, TaskCollectingQueue<? extends ClusterStateTaskListener>>();
final var modifiedClusterService = new ClusterService(
Settings.EMPTY,
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
THREAD_POOL,
null
) {
@SuppressWarnings("unchecked")
@Override
public <T extends ClusterStateTaskListener> MasterServiceTaskQueue<T> createTaskQueue(
String name,
Priority priority,
ClusterStateTaskExecutor<T> executor
) {
return (MasterServiceTaskQueue<T>) taskQueueMap.computeIfAbsent(
name,
k -> new TaskCollectingQueue<T>(super.createTaskQueue(name, priority, executor))
);
}
};

final var simulatedException = new RuntimeException("fake exception");
final var failedShardEntry = new FailedShardEntry(
new ShardId(new Index("foo-idx", "foo-idx-id"), 0),
"alloc-id",
0L,
"FAILURE MSG",
simulatedException,
false
);
final var now = new TimeValue(System.currentTimeMillis());
final var shardFailedTask = new Task(
82L,
"transport",
ShardStateAction.SHARD_FAILED_ACTION_NAME,
"",
null,
now.millis(),
now.nanos(),
Map.of()
);

final var handler = new ShardStateAction.ShardFailedTransportHandler(
modifiedClusterService,
new ShardStateAction.ShardFailedClusterStateTaskExecutor(null, null)
);

// Check that the submitted task's 'source' string doesn't include the exception stack trace.
handler.messageReceived(failedShardEntry, new TestTransportChannel(ActionListener.noop()), shardFailedTask);
final var taskQueue = taskQueueMap.get("shard-failed");
assertNotNull(taskQueue);
final var tasks = taskQueue.getTasks();
assertEquals(1, tasks.size());
final var task = tasks.getFirst();
final var stackTraceInfo = ExceptionsHelper.stackTrace(simulatedException);
assertEquals("shard-failed " + failedShardEntry.toStringNoFailureStackTrace(), task.source());
assertNotNull(failedShardEntry.failure);
assertFalse("Shard failed task's source string included the exception stack trace", task.source.contains(stackTraceInfo));
assertTrue(
"Shard failed task's source string didn't include the exception message",
task.source.contains(simulatedException.getMessage())
);
assertTrue(
"FailedShardEntry.toString() didn't include the exception stack trace",
failedShardEntry.toString().contains(stackTraceInfo)
);
assertTrue(task.task instanceof ShardStateAction.FailedShardUpdateTask);
}

BytesReference serialize(Writeable writeable, TransportVersion version) throws IOException {
try (BytesStreamOutput out = new BytesStreamOutput()) {
out.setTransportVersion(version);
Expand Down Expand Up @@ -663,4 +749,27 @@ void await() throws InterruptedException {
latch.await();
}
}

/**
 * Test double for a {@link MasterServiceTaskQueue} that records every submission and then forwards it,
 * unchanged, to the queue it wraps.
 *
 * @param <T> the type of task accepted by the queue
 */
private static class TaskCollectingQueue<T extends ClusterStateTaskListener> implements MasterServiceTaskQueue<T> {

    /** One recorded call to {@link #submitTask}, holding the arguments exactly as they were passed in. */
    record Entry<T>(String source, T task, TimeValue timeout) {}

    private final MasterServiceTaskQueue<T> delegate;
    private final List<Entry<T>> tasks = new ArrayList<>();

    /**
     * @param delegate the queue that every submitted task is forwarded to
     */
    TaskCollectingQueue(MasterServiceTaskQueue<T> delegate) {
        this.delegate = delegate;
    }

    /**
     * Records the submission first, then hands it to the wrapped queue.
     */
    @Override
    public void submitTask(String source, T task, TimeValue timeout) {
        final Entry<T> recorded = new Entry<>(source, task, timeout);
        tasks.add(recorded);
        delegate.submitTask(source, task, timeout);
    }

    /**
     * @return the recorded submissions in submission order; this is the live backing list, not a copy
     */
    List<Entry<T>> getTasks() {
        return tasks;
    }
}
}
Loading