Skip to content

Commit 6342cc3

Browse files
Adding timed listener tests
1 parent 9f83e91 commit 6342cc3

File tree

2 files changed

+155
-12
lines changed

2 files changed

+155
-12
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,7 @@ public void testRequest_ReturnsTimeoutException() {
8787
);
8888

8989
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
90-
assertThat(
91-
thrownException.getMessage(),
92-
is(format("Request timed out waiting to be sent after [%s]", TimeValue.timeValueMillis(1)))
93-
);
90+
assertThat(thrownException.getMessage(), is(format("Request timed out after [%s]", TimeValue.timeValueMillis(1))));
9491
assertTrue(requestTask.hasCompleted());
9592
assertTrue(requestTask.getRequestCompletedFunction().get());
9693
assertThat(thrownException.status().getStatus(), is(408));
@@ -117,10 +114,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio
117114

118115
ArgumentCaptor<Exception> argument = ArgumentCaptor.forClass(Exception.class);
119116
verify(listener, times(1)).onFailure(argument.capture());
120-
assertThat(
121-
argument.getValue().getMessage(),
122-
is(format("Request timed out waiting to be sent after [%s]", TimeValue.timeValueMillis(1)))
123-
);
117+
assertThat(argument.getValue().getMessage(), is(format("Request timed out after [%s]", TimeValue.timeValueMillis(1))));
124118
assertTrue(requestTask.hasCompleted());
125119
assertTrue(requestTask.getRequestCompletedFunction().get());
126120

@@ -149,10 +143,7 @@ public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception {
149143

150144
ArgumentCaptor<Exception> argument = ArgumentCaptor.forClass(Exception.class);
151145
verify(listener, times(1)).onFailure(argument.capture());
152-
assertThat(
153-
argument.getValue().getMessage(),
154-
is(format("Request timed out waiting to be sent after [%s]", TimeValue.timeValueMillis(1)))
155-
);
146+
assertThat(argument.getValue().getMessage(), is(format("Request timed out after [%s]", TimeValue.timeValueMillis(1))));
156147
assertTrue(requestTask.hasCompleted());
157148
assertTrue(requestTask.getRequestCompletedFunction().get());
158149

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.support.PlainActionFuture;
13+
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.inference.InferenceServiceResults;
15+
import org.elasticsearch.test.ESTestCase;
16+
import org.elasticsearch.threadpool.Scheduler;
17+
import org.elasticsearch.threadpool.ThreadPool;
18+
import org.junit.After;
19+
import org.junit.Before;
20+
import org.mockito.ArgumentCaptor;
21+
22+
import java.util.concurrent.CountDownLatch;
23+
import java.util.concurrent.ExecutorService;
24+
import java.util.concurrent.TimeUnit;
25+
import java.util.concurrent.atomic.AtomicReference;
26+
27+
import static org.elasticsearch.core.Strings.format;
28+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
29+
import static org.hamcrest.Matchers.is;
30+
import static org.mockito.ArgumentMatchers.any;
31+
import static org.mockito.Mockito.doAnswer;
32+
import static org.mockito.Mockito.mock;
33+
import static org.mockito.Mockito.times;
34+
import static org.mockito.Mockito.verify;
35+
import static org.mockito.Mockito.verifyNoMoreInteractions;
36+
import static org.mockito.Mockito.when;
37+
38+
public class TimedListenerTests extends ESTestCase {
39+
40+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
41+
private ThreadPool threadPool;
42+
43+
@Before
44+
public void init() throws Exception {
45+
threadPool = createThreadPool(inferenceUtilityPool());
46+
}
47+
48+
@After
49+
public void shutdown() {
50+
terminate(threadPool);
51+
}
52+
53+
public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentException() {
54+
AtomicReference<Runnable> onTimeout = new AtomicReference<>();
55+
var mockThreadPool = mockThreadPoolForTimeout(onTimeout);
56+
57+
@SuppressWarnings("unchecked")
58+
ActionListener<InferenceServiceResults> listener = mock(ActionListener.class);
59+
var timedListener = new TimedListener<>(TimeValue.timeValueMillis(1), listener, mockThreadPool);
60+
61+
timedListener.getListener().onFailure(new IllegalArgumentException("failed"));
62+
verify(listener, times(1)).onFailure(any());
63+
assertTrue(timedListener.hasCompleted());
64+
65+
onTimeout.get().run();
66+
verifyNoMoreInteractions(listener);
67+
}
68+
69+
public void testRequest_ReturnsTimeoutException() {
70+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
71+
var timedListener = new TimedListener<>(TimeValue.timeValueMillis(1), listener, threadPool);
72+
73+
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
74+
assertThat(thrownException.getMessage(), is(format("Request timed out after [%s]", TimeValue.timeValueMillis(1))));
75+
assertTrue(timedListener.hasCompleted());
76+
assertThat(thrownException.status().getStatus(), is(408));
77+
}
78+
79+
public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exception {
80+
@SuppressWarnings("unchecked")
81+
ActionListener<InferenceServiceResults> listener = mock(ActionListener.class);
82+
var calledOnFailureLatch = new CountDownLatch(1);
83+
doAnswer(invocation -> {
84+
calledOnFailureLatch.countDown();
85+
return Void.TYPE;
86+
}).when(listener).onFailure(any());
87+
88+
var timedListener = new TimedListener<>(TimeValue.timeValueMillis(1), listener, threadPool);
89+
90+
calledOnFailureLatch.await(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
91+
92+
ArgumentCaptor<Exception> argument = ArgumentCaptor.forClass(Exception.class);
93+
verify(listener, times(1)).onFailure(argument.capture());
94+
assertThat(argument.getValue().getMessage(), is(format("Request timed out after [%s]", TimeValue.timeValueMillis(1))));
95+
assertTrue(timedListener.hasCompleted());
96+
97+
timedListener.getListener().onFailure(new IllegalArgumentException("failed"));
98+
verifyNoMoreInteractions(listener);
99+
}
100+
101+
public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception {
102+
@SuppressWarnings("unchecked")
103+
ActionListener<InferenceServiceResults> listener = mock(ActionListener.class);
104+
var calledOnFailureLatch = new CountDownLatch(1);
105+
doAnswer(invocation -> {
106+
calledOnFailureLatch.countDown();
107+
return Void.TYPE;
108+
}).when(listener).onFailure(any());
109+
110+
var timedListener = new TimedListener<>(TimeValue.timeValueMillis(1), listener, threadPool);
111+
112+
calledOnFailureLatch.await(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
113+
114+
ArgumentCaptor<Exception> argument = ArgumentCaptor.forClass(Exception.class);
115+
verify(listener, times(1)).onFailure(argument.capture());
116+
assertThat(argument.getValue().getMessage(), is(format("Request timed out after [%s]", TimeValue.timeValueMillis(1))));
117+
assertTrue(timedListener.hasCompleted());
118+
119+
timedListener.getListener().onResponse(mock(InferenceServiceResults.class));
120+
verifyNoMoreInteractions(listener);
121+
}
122+
123+
public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResponse() throws Exception {
124+
AtomicReference<Runnable> onTimeout = new AtomicReference<>();
125+
var mockThreadPool = mockThreadPoolForTimeout(onTimeout);
126+
127+
@SuppressWarnings("unchecked")
128+
ActionListener<InferenceServiceResults> listener = mock(ActionListener.class);
129+
var timedListener = new TimedListener<>(TimeValue.timeValueMillis(1), listener, mockThreadPool);
130+
131+
timedListener.getListener().onResponse(mock(InferenceServiceResults.class));
132+
verify(listener, times(1)).onResponse(any());
133+
assertTrue(timedListener.hasCompleted());
134+
135+
onTimeout.get().run();
136+
verifyNoMoreInteractions(listener);
137+
}
138+
139+
private ThreadPool mockThreadPoolForTimeout(AtomicReference<Runnable> onTimeoutRunnable) {
140+
var mockThreadPool = mock(ThreadPool.class);
141+
when(mockThreadPool.executor(any())).thenReturn(mock(ExecutorService.class));
142+
when(mockThreadPool.getThreadContext()).thenReturn(threadPool.getThreadContext());
143+
144+
doAnswer(invocation -> {
145+
Runnable runnable = (Runnable) invocation.getArguments()[0];
146+
onTimeoutRunnable.set(runnable);
147+
return mock(Scheduler.ScheduledCancellable.class);
148+
}).when(mockThreadPool).schedule(any(Runnable.class), any(), any());
149+
150+
return mockThreadPool;
151+
}
152+
}

0 commit comments

Comments
 (0)