Skip to content

Commit c08fdb5

Browse files
Adding rate limit for auth task creation
1 parent c2286e1 commit c08fdb5

File tree

5 files changed

+151
-16
lines changed

5 files changed

+151
-16
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ protected Settings nodeSettings() {
112112
// Ensure that the polling logic only occurs once so we can deterministically control when an authorization response is
113113
// received
114114
.put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false)
115+
// Use very short intervals for testing purposes so that waiting for the task to be recreated is fast
116+
.put(ElasticInferenceServiceSettings.AUTHORIZATION_REQUEST_INTERVAL.getKey(), TimeValue.timeValueMillis(1))
117+
.put(ElasticInferenceServiceSettings.MAX_AUTHORIZATION_REQUEST_JITTER.getKey(), TimeValue.timeValueMillis(1))
115118
.build();
116119
}
117120

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.integration;
99

1010
import org.elasticsearch.common.settings.Settings;
11+
import org.elasticsearch.core.TimeValue;
1112
import org.elasticsearch.inference.TaskType;
1213
import org.elasticsearch.license.LicenseSettings;
1314
import org.elasticsearch.plugins.Plugin;
@@ -91,6 +92,9 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
9192
.put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
9293
.put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl)
9394
.put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false)
95+
// Use very short intervals for testing purposes so that waiting for the task to be recreated is fast
96+
.put(ElasticInferenceServiceSettings.AUTHORIZATION_REQUEST_INTERVAL.getKey(), TimeValue.timeValueMillis(1))
97+
.put(ElasticInferenceServiceSettings.MAX_AUTHORIZATION_REQUEST_JITTER.getKey(), TimeValue.timeValueMillis(1))
9498
.build();
9599
}
96100

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ public class ElasticInferenceServiceSettings {
4444
);
4545

4646
private static final TimeValue DEFAULT_AUTH_REQUEST_INTERVAL = TimeValue.timeValueMinutes(10);
47-
static final Setting<TimeValue> AUTHORIZATION_REQUEST_INTERVAL = Setting.timeSetting(
47+
public static final Setting<TimeValue> AUTHORIZATION_REQUEST_INTERVAL = Setting.timeSetting(
4848
"xpack.inference.elastic.authorization_request_interval",
4949
DEFAULT_AUTH_REQUEST_INTERVAL,
5050
Setting.Property.NodeScope,
5151
Setting.Property.Dynamic
5252
);
5353

5454
private static final TimeValue DEFAULT_AUTH_REQUEST_JITTER = TimeValue.timeValueMinutes(5);
55-
static final Setting<TimeValue> MAX_AUTHORIZATION_REQUEST_JITTER = Setting.timeSetting(
55+
public static final Setting<TimeValue> MAX_AUTHORIZATION_REQUEST_JITTER = Setting.timeSetting(
5656
"xpack.inference.elastic.max_authorization_request_jitter",
5757
DEFAULT_AUTH_REQUEST_JITTER,
5858
Setting.Property.NodeScope,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.cluster.ClusterState;
1919
import org.elasticsearch.cluster.ClusterStateListener;
2020
import org.elasticsearch.cluster.service.ClusterService;
21+
import org.elasticsearch.common.Randomness;
2122
import org.elasticsearch.common.Strings;
2223
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
2324
import org.elasticsearch.common.io.stream.StreamInput;
@@ -45,6 +46,8 @@
4546
import org.elasticsearch.xpack.inference.common.BroadcastMessageAction;
4647

4748
import java.io.IOException;
49+
import java.time.Clock;
50+
import java.time.Instant;
4851
import java.util.List;
4952
import java.util.Map;
5053
import java.util.Objects;
@@ -71,6 +74,8 @@ public class AuthorizationTaskExecutor extends PersistentTasksExecutor<Authoriza
7174
private final AtomicReference<AuthorizationPoller> currentTask = new AtomicReference<>();
7275
private final AtomicBoolean running = new AtomicBoolean(false);
7376
private final FeatureService featureService;
77+
private Instant nextCreateTaskAttemptTime;
78+
private final Clock clock;
7479

7580
public static AuthorizationTaskExecutor create(
7681
ClusterService clusterService,
@@ -84,7 +89,8 @@ public static AuthorizationTaskExecutor create(
8489
clusterService,
8590
new PersistentTasksService(clusterService, parameters.serviceComponents().threadPool(), parameters.client()),
8691
featureService,
87-
parameters
92+
parameters,
93+
Clock.systemUTC()
8894
);
8995
}
9096

@@ -93,13 +99,16 @@ public static AuthorizationTaskExecutor create(
9399
ClusterService clusterService,
94100
PersistentTasksService persistentTasksService,
95101
FeatureService featureService,
96-
AuthorizationPoller.Parameters pollerParameters
102+
AuthorizationPoller.Parameters pollerParameters,
103+
Clock clock
97104
) {
98105
super(TASK_NAME, pollerParameters.serviceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME));
99106
this.clusterService = Objects.requireNonNull(clusterService);
100107
this.featureService = Objects.requireNonNull(featureService);
101108
this.persistentTasksService = Objects.requireNonNull(persistentTasksService);
102109
this.pollerParameters = Objects.requireNonNull(pollerParameters);
110+
this.clock = Objects.requireNonNull(clock);
111+
this.nextCreateTaskAttemptTime = Instant.MIN;
103112
}
104113

105114
/**
@@ -144,6 +153,8 @@ private void sendStartRequest(@Nullable ClusterState state) {
144153
return;
145154
}
146155

156+
updateNextCreateTaskAttemptTime();
157+
147158
persistentTasksService.sendClusterStartRequest(
148159
TASK_NAME,
149160
TASK_NAME,
@@ -166,7 +177,10 @@ private boolean shouldSkipCreatingTask(@Nullable ClusterState state) {
166177
return true;
167178
}
168179

169-
return clusterCanSupportFeature(state) == false || running.get() == false || authorizationTaskExists(state);
180+
return clusterCanSupportFeature(state) == false
181+
|| running.get() == false
182+
|| authorizationTaskExists(state)
183+
|| hasAttemptedToCreateTaskRecently();
170184
}
171185

172186
private boolean clusterCanSupportFeature(@Nullable ClusterState state) {
@@ -185,6 +199,25 @@ private static boolean authorizationTaskExists(@Nullable ClusterState state) {
185199
return ClusterPersistentTasksCustomMetadata.getTaskWithId(state, TASK_NAME) != null;
186200
}
187201

202+
private boolean hasAttemptedToCreateTaskRecently() {
203+
return Instant.now(clock).isBefore(nextCreateTaskAttemptTime);
204+
}
205+
206+
private void updateNextCreateTaskAttemptTime() {
207+
var random = Randomness.get();
208+
var jitter = (long) (pollerParameters.elasticInferenceServiceSettings().getMaxAuthorizationRequestJitter().millis() * random
209+
.nextDouble());
210+
var waitTimeMillis = pollerParameters.elasticInferenceServiceSettings().getAuthRequestInterval().millis() + jitter;
211+
212+
nextCreateTaskAttemptTime = Instant.now(clock).plusMillis(waitTimeMillis);
213+
logger.debug(
214+
"Create task rate limit info interval: [{}] ms, jitter: [{}] ms",
215+
pollerParameters.elasticInferenceServiceSettings().getAuthRequestInterval().millis(),
216+
jitter
217+
);
218+
logger.debug("Next create task attempt time [{}]", nextCreateTaskAttemptTime);
219+
}
220+
188221
public synchronized void stop() {
189222
if (running.compareAndSet(true, false)) {
190223
logger.info("Shutting down authorization task executor");

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import org.junit.Before;
3030
import org.mockito.Mockito;
3131

32+
import java.time.Clock;
33+
import java.time.Duration;
34+
3235
import static org.elasticsearch.cluster.metadata.Metadata.EMPTY_METADATA;
3336
import static org.elasticsearch.persistent.PersistentTasksExecutor.NO_NODE_FOUND;
3437
import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
@@ -86,7 +89,8 @@ public void testMultipleCallsToStart_OnlyRegistersOnce() {
8689
mock(Client.class),
8790
createMockCCMFeature(false),
8891
createMockCCMService(false)
89-
)
92+
),
93+
Clock.systemUTC()
9094
);
9195
executor.startAndImmediatelyCreateTask();
9296
executor.startAndImmediatelyCreateTask();
@@ -117,7 +121,8 @@ public void testStartLazy_OnlyRegistersOnce_NeverCallsPersistentTaskService() {
117121
mock(Client.class),
118122
createMockCCMFeature(false),
119123
createMockCCMService(false)
120-
)
124+
),
125+
Clock.systemUTC()
121126
);
122127
executor.startAndLazyCreateTask();
123128
executor.startAndLazyCreateTask();
@@ -154,7 +159,8 @@ public void testDoesNotRegisterListener_IfUrlIsEmpty() {
154159
mock(Client.class),
155160
createMockCCMFeature(false),
156161
createMockCCMService(false)
157-
)
162+
),
163+
Clock.systemUTC()
158164
);
159165
executor.startAndImmediatelyCreateTask();
160166
executor.startAndImmediatelyCreateTask();
@@ -170,6 +176,17 @@ public void testDoesNotRegisterListener_IfUrlIsEmpty() {
170176
}
171177

172178
public void testMultipleCallsToStart_AndStop() {
179+
var now = Clock.systemUTC().instant();
180+
var oneDayInFuture = now.plus(Duration.ofDays(1));
181+
var clock = mock(Clock.class);
182+
// The AuthorizationTaskExecutor makes these calls to clock.instant():
183+
// 1. Check if the create task rate limit has expired (first call to instant()),
184+
// this will pass so a call to create the task will occur
185+
// 2. Then it will set the next allowed create task attempt time (second call to instant())
186+
// 3. On the next cluster state change, it will check if the create task rate limit has expired (third call to instant()),
187+
// we'll return now + 1 day to ensure that it has expired, which allows another call to create the task
188+
when(clock.instant()).thenReturn(now).thenReturn(now).thenReturn(oneDayInFuture);
189+
173190
var eisUrl = "abc";
174191
var mockClusterService = createMockEmptyClusterService();
175192
var executor = new AuthorizationTaskExecutor(
@@ -185,7 +202,8 @@ public void testMultipleCallsToStart_AndStop() {
185202
mock(Client.class),
186203
createMockCCMFeature(false),
187204
createMockCCMService(false)
188-
)
205+
),
206+
clock
189207
);
190208
executor.startAndImmediatelyCreateTask();
191209
executor.startAndImmediatelyCreateTask();
@@ -216,6 +234,62 @@ public void testMultipleCallsToStart_AndStop() {
216234
verify(persistentTasksService, times(2)).sendClusterRemoveRequest(eq(AuthorizationPoller.TASK_NAME), any(), any());
217235
}
218236

237+
public void testMultipleCallsToStart_OnlyCallsSendClusterStartRequestOnce_WhenRateLimited() {
238+
var now = Clock.systemUTC().instant();
239+
var clock = mock(Clock.class);
240+
when(clock.instant()).thenReturn(now);
241+
242+
var eisUrl = "abc";
243+
var mockClusterService = createMockEmptyClusterService();
244+
var executor = new AuthorizationTaskExecutor(
245+
mockClusterService,
246+
persistentTasksService,
247+
enabledFeatureServiceMock,
248+
new AuthorizationPoller.Parameters(
249+
createWithEmptySettings(threadPool),
250+
mock(ElasticInferenceServiceAuthorizationRequestHandler.class),
251+
mock(Sender.class),
252+
ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
253+
mock(ModelRegistry.class),
254+
mock(Client.class),
255+
createMockCCMFeature(false),
256+
createMockCCMService(false)
257+
),
258+
clock
259+
);
260+
executor.startAndImmediatelyCreateTask();
261+
executor.startAndImmediatelyCreateTask();
262+
executor.stop();
263+
executor.stop();
264+
verify(mockClusterService, times(1)).addListener(executor);
265+
verify(persistentTasksService, times(1)).sendClusterStartRequest(
266+
eq(AuthorizationPoller.TASK_NAME),
267+
eq(AuthorizationPoller.TASK_NAME),
268+
eq(AuthorizationTaskParams.INSTANCE),
269+
any(),
270+
any()
271+
);
272+
verify(mockClusterService, times(1)).removeListener(executor);
273+
verify(persistentTasksService, times(1)).sendClusterRemoveRequest(eq(AuthorizationPoller.TASK_NAME), any(), any());
274+
275+
Mockito.clearInvocations(persistentTasksService);
276+
Mockito.clearInvocations(mockClusterService);
277+
278+
executor.startAndImmediatelyCreateTask();
279+
executor.stop();
280+
verify(mockClusterService, times(1)).addListener(executor);
281+
// No additional calls because time hasn't advanced to allow another task creation call
282+
verify(persistentTasksService, never()).sendClusterStartRequest(
283+
eq(AuthorizationPoller.TASK_NAME),
284+
eq(AuthorizationPoller.TASK_NAME),
285+
eq(AuthorizationTaskParams.INSTANCE),
286+
any(),
287+
any()
288+
);
289+
verify(mockClusterService, times(1)).removeListener(executor);
290+
verify(persistentTasksService, times(1)).sendClusterRemoveRequest(eq(AuthorizationPoller.TASK_NAME), any(), any());
291+
}
292+
219293
public void testCallsSendClusterStartRequest_WhenStartIsCalled() {
220294
var eisUrl = "abc";
221295
var mockClusterService = createMockEmptyClusterService();
@@ -232,7 +306,8 @@ public void testCallsSendClusterStartRequest_WhenStartIsCalled() {
232306
mock(Client.class),
233307
createMockCCMFeature(false),
234308
createMockCCMService(false)
235-
)
309+
),
310+
Clock.systemUTC()
236311
);
237312
executor.startAndImmediatelyCreateTask();
238313

@@ -282,7 +357,8 @@ public void testDoesNotCallSendClusterStartRequest_WhenStartIsCalled_WhenItIsAlr
282357
mock(Client.class),
283358
createMockCCMFeature(false),
284359
createMockCCMService(false)
285-
)
360+
),
361+
Clock.systemUTC()
286362
);
287363
executor.startAndImmediatelyCreateTask();
288364

@@ -297,6 +373,17 @@ public void testDoesNotCallSendClusterStartRequest_WhenStartIsCalled_WhenItIsAlr
297373
}
298374

299375
public void testCreatesTask_WhenItDoesNotExistOnClusterStateChange() {
376+
var now = Clock.systemUTC().instant();
377+
var oneDayInFuture = now.plus(Duration.ofDays(1));
378+
var clock = mock(Clock.class);
379+
// The AuthorizationTaskExecutor makes these calls to clock.instant():
380+
// 1. Check if the create task rate limit has expired (first call to instant()),
381+
// this will pass so a call to create the task will occur
382+
// 2. Then it will set the next allowed create task attempt time (second call to instant())
383+
// 3. On the next cluster state change, it will check if the create task rate limit has expired (third call to instant()),
384+
// we'll return now + 1 day to ensure that it has expired, which allows another call to create the task
385+
when(clock.instant()).thenReturn(now).thenReturn(now).thenReturn(oneDayInFuture);
386+
300387
var eisUrl = "abc";
301388

302389
var executor = new AuthorizationTaskExecutor(
@@ -312,7 +399,8 @@ public void testCreatesTask_WhenItDoesNotExistOnClusterStateChange() {
312399
mock(Client.class),
313400
createMockCCMFeature(false),
314401
createMockCCMService(false)
315-
)
402+
),
403+
clock
316404
);
317405
executor.startAndImmediatelyCreateTask();
318406

@@ -329,6 +417,9 @@ public void testCreatesTask_WhenItDoesNotExistOnClusterStateChange() {
329417
);
330418

331419
Mockito.clearInvocations(persistentTasksService);
420+
Mockito.clearInvocations(clock);
421+
when(clock.instant()).thenReturn(oneDayInFuture.plus(Duration.ofDays(1)));
422+
332423
// Ensure that if the task is gone, it will be recreated.
333424
var listener2 = new PlainActionFuture<Void>();
334425
clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener2);
@@ -369,7 +460,8 @@ public void testDoesNotCreateTask_WhenFeatureIsNotSupported() {
369460
mock(Client.class),
370461
createMockCCMFeature(false),
371462
createMockCCMService(false)
372-
)
463+
),
464+
Clock.systemUTC()
373465
);
374466
executor.startAndImmediatelyCreateTask();
375467

@@ -400,7 +492,8 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis
400492
mock(Client.class),
401493
createMockCCMFeature(false),
402494
createMockCCMService(false)
403-
)
495+
),
496+
Clock.systemUTC()
404497
);
405498
executor.startAndImmediatelyCreateTask();
406499

@@ -430,7 +523,8 @@ public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEis
430523
mock(Client.class),
431524
createMockCCMFeature(false),
432525
createMockCCMService(false)
433-
)
526+
),
527+
Clock.systemUTC()
434528
);
435529
executor.startAndImmediatelyCreateTask();
436530

@@ -483,7 +577,8 @@ public void testDoesNotCreateTask_OnClusterStateChange_WhenItAlreadyExists() {
483577
mock(Client.class),
484578
createMockCCMFeature(false),
485579
createMockCCMService(false)
486-
)
580+
),
581+
Clock.systemUTC()
487582
);
488583
executor.startAndImmediatelyCreateTask();
489584

0 commit comments

Comments
 (0)