22
22
import org .elasticsearch .common .util .concurrent .ThreadContext ;
23
23
import org .elasticsearch .core .TimeValue ;
24
24
import org .elasticsearch .index .query .MatchAllQueryBuilder ;
25
+ import org .elasticsearch .tasks .Task ;
26
+ import org .elasticsearch .tasks .TaskCancellationService ;
27
+ import org .elasticsearch .tasks .TaskId ;
25
28
import org .elasticsearch .test .ESTestCase ;
26
29
import org .elasticsearch .test .transport .MockTransportService ;
27
30
import org .elasticsearch .threadpool .ScalingExecutorBuilder ;
31
34
import java .util .Collections ;
32
35
import java .util .List ;
33
36
import java .util .concurrent .CopyOnWriteArrayList ;
37
+ import java .util .concurrent .CountDownLatch ;
38
+ import java .util .concurrent .ExecutionException ;
34
39
import java .util .concurrent .ExecutorService ;
35
40
import java .util .concurrent .Executors ;
36
41
import java .util .concurrent .TimeUnit ;
42
+ import java .util .concurrent .atomic .AtomicLong ;
37
43
44
+ import static org .elasticsearch .test .tasks .MockTaskManager .SPY_TASK_MANAGER_SETTING ;
38
45
import static org .hamcrest .Matchers .equalTo ;
46
+ import static org .hamcrest .Matchers .instanceOf ;
47
+ import static org .mockito .ArgumentMatchers .anyString ;
48
+ import static org .mockito .ArgumentMatchers .eq ;
49
+ import static org .mockito .Mockito .verify ;
39
50
40
51
public class RemoteClusterAwareClientTests extends ESTestCase {
41
52
@@ -62,6 +73,89 @@ private MockTransportService startTransport(String id, List<DiscoveryNode> known
62
73
);
63
74
}
64
75
76
+ public void testRemoteTaskCancellationOnFailedResponse () throws Exception {
77
+ Settings .Builder remoteTransportSettingsBuilder = Settings .builder ();
78
+ remoteTransportSettingsBuilder .put (SPY_TASK_MANAGER_SETTING .getKey (), true );
79
+ try (
80
+ MockTransportService remoteTransport = RemoteClusterConnectionTests .startTransport (
81
+ "seed_node" ,
82
+ new CopyOnWriteArrayList <>(),
83
+ VersionInformation .CURRENT ,
84
+ TransportVersion .current (),
85
+ threadPool ,
86
+ remoteTransportSettingsBuilder .build ()
87
+ )
88
+ ) {
89
+ remoteTransport .getTaskManager ().setTaskCancellationService (new TaskCancellationService (remoteTransport ));
90
+ Settings .Builder builder = Settings .builder ();
91
+ builder .putList ("cluster.remote.cluster1.seeds" , remoteTransport .getLocalDiscoNode ().getAddress ().toString ());
92
+ try (
93
+ MockTransportService localService = MockTransportService .createNewService (
94
+ builder .build (),
95
+ VersionInformation .CURRENT ,
96
+ TransportVersion .current (),
97
+ threadPool ,
98
+ null
99
+ )
100
+ ) {
101
+ // the TaskCancellationService references the same TransportService instance
102
+ // this is identically to how it works in the Node constructor
103
+ localService .getTaskManager ().setTaskCancellationService (new TaskCancellationService (localService ));
104
+ localService .start ();
105
+ localService .acceptIncomingRequests ();
106
+
107
+ SearchShardsRequest searchShardsRequest = new SearchShardsRequest (
108
+ new String [] { "test-index" },
109
+ IndicesOptions .strictExpandOpen (),
110
+ new MatchAllQueryBuilder (),
111
+ null ,
112
+ "index_not_found" , // this request must fail
113
+ randomBoolean (),
114
+ null
115
+ );
116
+ Task parentTask = localService .getTaskManager ().register ("test_type" , "test_action" , searchShardsRequest );
117
+ TaskId parentTaskId = new TaskId ("test-mock-node-id" , parentTask .getId ());
118
+ searchShardsRequest .setParentTask (parentTaskId );
119
+ var client = new RemoteClusterAwareClient (
120
+ localService ,
121
+ "cluster1" ,
122
+ threadPool .executor (TEST_THREAD_POOL_NAME ),
123
+ randomBoolean ()
124
+ );
125
+
126
+ CountDownLatch cancelChildReceived = new CountDownLatch (1 );
127
+ remoteTransport .addRequestHandlingBehavior (
128
+ TaskCancellationService .CANCEL_CHILD_ACTION_NAME ,
129
+ (handler , request , channel , task ) -> {
130
+ handler .messageReceived (request , channel , task );
131
+ cancelChildReceived .countDown ();
132
+ }
133
+ );
134
+ AtomicLong searchShardsRequestId = new AtomicLong (-1 );
135
+ CountDownLatch cancelChildSent = new CountDownLatch (1 );
136
+ localService .addSendBehavior (remoteTransport , (connection , requestId , action , request , options ) -> {
137
+ connection .sendRequest (requestId , action , request , options );
138
+ if (action .equals ("indices:admin/search/search_shards" )) {
139
+ searchShardsRequestId .set (requestId );
140
+ } else if (action .equals (TaskCancellationService .CANCEL_CHILD_ACTION_NAME )) {
141
+ cancelChildSent .countDown ();
142
+ }
143
+ });
144
+
145
+ // assert original request failed
146
+ var future = new PlainActionFuture <SearchShardsResponse >();
147
+ client .execute (TransportSearchShardsAction .REMOTE_TYPE , searchShardsRequest , future );
148
+ ExecutionException e = expectThrows (ExecutionException .class , future ::get );
149
+ assertThat (e .getCause (), instanceOf (RemoteTransportException .class ));
150
+
151
+ // assert remote task is cancelled
152
+ safeAwait (cancelChildSent );
153
+ safeAwait (cancelChildReceived );
154
+ verify (remoteTransport .getTaskManager ()).cancelChildLocal (eq (parentTaskId ), eq (searchShardsRequestId .get ()), anyString ());
155
+ }
156
+ }
157
+ }
158
+
65
159
public void testSearchShards () throws Exception {
66
160
List <DiscoveryNode > knownNodes = new CopyOnWriteArrayList <>();
67
161
try (
0 commit comments