Skip to content

Commit c1345c4

Browse files
update tests
1 parent 906b1e9 commit c1345c4

File tree

23 files changed

+272
-162
lines changed

23 files changed

+272
-162
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.TransportVersion;
1212
import org.elasticsearch.TransportVersions;
1313
import org.elasticsearch.action.ActionListener;
14+
import org.elasticsearch.cluster.service.ClusterService;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.common.ValidationException;
1617
import org.elasticsearch.common.util.LazyInitializable;
@@ -137,6 +138,31 @@ public ElasticInferenceService(
137138
);
138139
}
139140

141+
// for testing
142+
public ElasticInferenceService(
143+
HttpRequestSender.Factory factory,
144+
ServiceComponents serviceComponents,
145+
ElasticInferenceServiceSettings elasticInferenceServiceSettings,
146+
ModelRegistry modelRegistry,
147+
ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler,
148+
ClusterService service
149+
) {
150+
super(factory, serviceComponents, service);
151+
this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents(
152+
elasticInferenceServiceSettings.getElasticInferenceServiceUrl()
153+
);
154+
authorizationHandler = new ElasticInferenceServiceAuthorizationHandler(
155+
serviceComponents,
156+
modelRegistry,
157+
authorizationRequestHandler,
158+
initDefaultEndpoints(elasticInferenceServiceComponents),
159+
IMPLEMENTED_TASK_TYPES,
160+
this,
161+
getSender(),
162+
elasticInferenceServiceSettings
163+
);
164+
}
165+
140166
private static Map<String, DefaultModelConfig> initDefaultEndpoints(
141167
ElasticInferenceServiceComponents elasticInferenceServiceComponents
142168
) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,29 @@ public SageMakerService(
8484
this.clusterService = Objects.requireNonNull(context.clusterService());
8585
}
8686

87+
// for testing
88+
public SageMakerService(
89+
SageMakerModelBuilder modelBuilder,
90+
SageMakerClient client,
91+
SageMakerSchemas schemas,
92+
ThreadPool threadPool,
93+
CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> configurationMap,
94+
ClusterService clusterService
95+
) {
96+
this.modelBuilder = modelBuilder;
97+
this.client = client;
98+
this.schemas = schemas;
99+
this.threadPool = threadPool;
100+
this.configuration = new LazyInitializable<>(
101+
() -> new InferenceServiceConfiguration.Builder().setService(NAME)
102+
.setName(DISPLAY_NAME)
103+
.setTaskTypes(supportedTaskTypes())
104+
.setConfigurations(configurationMap.get())
105+
.build()
106+
);
107+
this.clusterService = clusterService;
108+
}
109+
87110
@Override
88111
public String name() {
89112
return NAME;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.support.PlainActionFuture;
13+
import org.elasticsearch.cluster.service.ClusterService;
1314
import org.elasticsearch.common.ValidationException;
1415
import org.elasticsearch.core.TimeValue;
1516
import org.elasticsearch.inference.ChunkedInference;
@@ -103,7 +104,7 @@ public void testStart_CallingStartTwiceKeepsSameSenderReference() throws IOExcep
103104

104105
private static final class TestSenderService extends SenderService {
105106
TestSenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
106-
super(factory, serviceComponents);
107+
super(factory, serviceComponents, mock(ClusterService.class));
107108
}
108109

109110
@Override

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
1212
import org.elasticsearch.action.support.PlainActionFuture;
13+
import org.elasticsearch.cluster.service.ClusterService;
1314
import org.elasticsearch.common.ValidationException;
1415
import org.elasticsearch.common.bytes.BytesArray;
1516
import org.elasticsearch.common.bytes.BytesReference;
@@ -77,11 +78,13 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
7778
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
7879
private ThreadPool threadPool;
7980
private HttpClientManager clientManager;
81+
private ClusterService clusterService;
8082

8183
@Before
8284
public void init() throws Exception {
8385
threadPool = createThreadPool(inferenceUtilityPool());
8486
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
87+
clusterService = mock(ClusterService.class);
8588
}
8689

8790
@After
@@ -91,7 +94,7 @@ public void shutdown() throws IOException {
9194
}
9295

9396
public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException {
94-
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
97+
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), clusterService)) {
9598
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
9699
assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class));
97100

@@ -116,7 +119,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException
116119
}
117120

118121
public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
119-
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
122+
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), clusterService)) {
120123
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
121124
assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class));
122125

@@ -143,7 +146,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP
143146
}
144147

145148
public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
146-
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
149+
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), clusterService)) {
147150
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
148151
assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class));
149152

@@ -169,7 +172,7 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN
169172
}
170173

171174
public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
172-
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
175+
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), clusterService)) {
173176
var model = service.parsePersistedConfig(
174177
"id",
175178
TaskType.TEXT_EMBEDDING,
@@ -190,7 +193,7 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting
190193
}
191194

192195
public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
193-
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
196+
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), clusterService)) {
194197
var model = service.parsePersistedConfig(
195198
"id",
196199
TaskType.TEXT_EMBEDDING,
@@ -210,7 +213,7 @@ public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSetting
210213
}
211214

212215
public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
213-
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
216+
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), clusterService)) {
214217
var persistedConfig = getPersistedConfigMap(
215218
AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"),
216219
AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
@@ -235,7 +238,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun
235238
}
236239

237240
public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
238-
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
241+
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), clusterService)) {
239242
var persistedConfig = getPersistedConfigMap(
240243
AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"),
241244
AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
@@ -262,7 +265,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun
262265
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
263266
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
264267

265-
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
268+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), clusterService)) {
266269
var model = OpenAiChatCompletionModelTests.createCompletionModel(
267270
randomAlphaOfLength(10),
268271
randomAlphaOfLength(10),
@@ -279,7 +282,7 @@ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IO
279282

280283
public void testUpdateModelWithEmbeddingDetails_UpdatesEmbeddingSizeAndSimilarity() throws IOException {
281284
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
282-
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
285+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), clusterService)) {
283286
var embeddingSize = randomNonNegativeInt();
284287
var model = AlibabaCloudSearchEmbeddingsModelTests.createModel(
285288
randomAlphaOfLength(10),
@@ -316,7 +319,7 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType_TextEmbedding() t
316319
taskSettingsMap,
317320
secretSettingsMap
318321
);
319-
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
322+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), clusterService)) {
320323
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
321324
var thrownException = expectThrows(
322325
ValidationException.class,
@@ -360,7 +363,7 @@ public void testInfer_ThrowsValidationExceptionForInvalidInputType_SparseEmbeddi
360363
taskSettingsMap,
361364
secretSettingsMap
362365
);
363-
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
366+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), clusterService)) {
364367
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
365368
var thrownException = expectThrows(
366369
ValidationException.class,
@@ -404,7 +407,7 @@ public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOExc
404407
taskSettingsMap,
405408
secretSettingsMap
406409
);
407-
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
410+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), clusterService)) {
408411
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
409412
var thrownException = expectThrows(
410413
ValidationException.class,
@@ -452,7 +455,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin
452455

453456
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
454457

455-
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
458+
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), clusterService)) {
456459
var model = createModelForTaskType(taskType, chunkingSettings);
457460

458461
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
@@ -482,7 +485,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin
482485

483486
@SuppressWarnings("checkstyle:LineLength")
484487
public void testGetConfiguration() throws Exception {
485-
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) {
488+
try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool), clusterService)) {
486489
String content = XContentHelper.stripWhitespace(
487490
"""
488491
{

0 commit comments

Comments
 (0)