Skip to content

Commit 9aaddfd

Browse files
Refactor OpenShift AI service tests to use constants for URL, model ID, and API key
1 parent 8b5c407 commit 9aaddfd

File tree

1 file changed

+57
-69
lines changed

1 file changed

+57
-69
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/OpenShiftAiServiceTests.java

Lines changed: 57 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
import org.elasticsearch.xpack.inference.services.openshiftai.embeddings.OpenShiftAiEmbeddingsServiceSettings;
6161
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
6262
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
63-
import org.hamcrest.CoreMatchers;
6463
import org.hamcrest.Matchers;
6564
import org.junit.After;
6665
import org.junit.Before;
@@ -94,14 +93,19 @@
9493
import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModelTests.createChatCompletionModel;
9594
import static org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionServiceSettingsTests.getServiceSettingsMap;
9695
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
97-
import static org.hamcrest.CoreMatchers.is;
9896
import static org.hamcrest.Matchers.equalTo;
9997
import static org.hamcrest.Matchers.hasSize;
10098
import static org.hamcrest.Matchers.instanceOf;
99+
import static org.hamcrest.Matchers.is;
101100
import static org.hamcrest.Matchers.isA;
102101
import static org.mockito.Mockito.mock;
103102

104103
public class OpenShiftAiServiceTests extends AbstractInferenceServiceTests {
104+
private static final String URL = "http://www.abc.com";
105+
private static final String MODEL_ID = "model_id";
106+
private static final String USER_ROLE = "user";
107+
private static final String API_KEY = "secret";
108+
private static final String INFERENCE_ID = "id";
105109
private final MockWebServer webServer = new MockWebServer();
106110
private ThreadPool threadPool;
107111
private HttpClientManager clientManager;
@@ -168,32 +172,32 @@ private static void assertModel(Model model, TaskType taskType, boolean modelInc
168172
private static void assertTextEmbeddingModel(Model model, boolean modelIncludesSecrets) {
169173
var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets);
170174

171-
assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.TEXT_EMBEDDING));
175+
assertThat(openShiftAiModel.getTaskType(), is(TaskType.TEXT_EMBEDDING));
172176
}
173177

174178
private static OpenShiftAiModel assertCommonModelFields(Model model, boolean modelIncludesSecrets) {
175179
assertThat(model, instanceOf(OpenShiftAiModel.class));
176180

177181
var openShiftAiModel = (OpenShiftAiModel) model;
178-
assertThat(openShiftAiModel.getServiceSettings().modelId(), is("model_id"));
179-
assertThat(openShiftAiModel.getServiceSettings().uri.toString(), Matchers.is("http://www.abc.com"));
180-
assertThat(openShiftAiModel.getTaskSettings(), Matchers.is(EmptyTaskSettings.INSTANCE));
182+
assertThat(openShiftAiModel.getServiceSettings().modelId(), is(MODEL_ID));
183+
assertThat(openShiftAiModel.getServiceSettings().uri.toString(), is(URL));
184+
assertThat(openShiftAiModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
181185

182186
if (modelIncludesSecrets) {
183-
assertThat(openShiftAiModel.getSecretSettings().apiKey(), Matchers.is(new SecureString("secret".toCharArray())));
187+
assertThat(openShiftAiModel.getSecretSettings().apiKey(), is(new SecureString(API_KEY.toCharArray())));
184188
}
185189

186190
return openShiftAiModel;
187191
}
188192

189193
private static void assertCompletionModel(Model model, boolean modelIncludesSecrets) {
190194
var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets);
191-
assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.COMPLETION));
195+
assertThat(openShiftAiModel.getTaskType(), is(TaskType.COMPLETION));
192196
}
193197

194198
private static void assertChatCompletionModel(Model model, boolean modelIncludesSecrets) {
195199
var openShiftAiModel = assertCommonModelFields(model, modelIncludesSecrets);
196-
assertThat(openShiftAiModel.getTaskType(), Matchers.is(TaskType.CHAT_COMPLETION));
200+
assertThat(openShiftAiModel.getTaskType(), is(TaskType.CHAT_COMPLETION));
197201
}
198202

199203
public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
@@ -202,9 +206,7 @@ public static SenderService createService(ThreadPool threadPool, HttpClientManag
202206
}
203207

204208
private static Map<String, Object> createServiceSettingsMap(TaskType taskType) {
205-
Map<String, Object> settingsMap = new HashMap<>(
206-
Map.of(ServiceFields.URL, "http://www.abc.com", ServiceFields.MODEL_ID, "model_id")
207-
);
209+
Map<String, Object> settingsMap = new HashMap<>(Map.of(ServiceFields.URL, URL, ServiceFields.MODEL_ID, MODEL_ID));
208210

209211
if (taskType == TaskType.TEXT_EMBEDDING) {
210212
settingsMap.putAll(
@@ -223,27 +225,17 @@ private static Map<String, Object> createServiceSettingsMap(TaskType taskType) {
223225
}
224226

225227
private static Map<String, Object> createSecretSettingsMap() {
226-
return new HashMap<>(Map.of("api_key", "secret"));
228+
return new HashMap<>(Map.of("api_key", API_KEY));
227229
}
228230

229231
private static OpenShiftAiEmbeddingsModel createInternalEmbeddingModel(@Nullable SimilarityMeasure similarityMeasure) {
230-
var inferenceId = "inference_id";
231-
232232
return new OpenShiftAiEmbeddingsModel(
233-
inferenceId,
233+
INFERENCE_ID,
234234
TaskType.TEXT_EMBEDDING,
235235
OpenShiftAiService.NAME,
236-
new OpenShiftAiEmbeddingsServiceSettings(
237-
"model_id",
238-
"http://www.abc.com",
239-
1536,
240-
similarityMeasure,
241-
512,
242-
new RateLimitSettings(10_000),
243-
true
244-
),
236+
new OpenShiftAiEmbeddingsServiceSettings(MODEL_ID, URL, 1536, similarityMeasure, 512, new RateLimitSettings(10_000), true),
245237
createRandomChunkingSettings(),
246-
new DefaultSecretSettings(new SecureString("secret".toCharArray()))
238+
new DefaultSecretSettings(new SecureString(API_KEY.toCharArray()))
247239
);
248240
}
249241

@@ -267,19 +259,15 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsP
267259
assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class));
268260

269261
var embeddingsModel = (OpenShiftAiEmbeddingsModel) model;
270-
assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
262+
assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL));
271263
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
272-
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
264+
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY));
273265
}, e -> fail("parse request should not fail " + e.getMessage()));
274266

275267
service.parseRequestConfig(
276-
"id",
268+
INFERENCE_ID,
277269
TaskType.TEXT_EMBEDDING,
278-
getRequestConfigMap(
279-
getServiceSettingsMap("model", "url"),
280-
createRandomChunkingSettingsMap(),
281-
getSecretSettingsMap("secret")
282-
),
270+
getRequestConfigMap(getServiceSettingsMap(MODEL_ID, URL), createRandomChunkingSettingsMap(), getSecretSettingsMap(API_KEY)),
283271
modelVerificationActionListener
284272
);
285273
}
@@ -291,49 +279,43 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsN
291279
assertThat(model, instanceOf(OpenShiftAiEmbeddingsModel.class));
292280

293281
var embeddingsModel = (OpenShiftAiEmbeddingsModel) model;
294-
assertThat(embeddingsModel.getServiceSettings().uri().toString(), is("url"));
282+
assertThat(embeddingsModel.getServiceSettings().uri().toString(), is(URL));
295283
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
296-
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
284+
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY));
297285
}, e -> fail("parse request should not fail " + e.getMessage()));
298286

299287
service.parseRequestConfig(
300-
"id",
288+
INFERENCE_ID,
301289
TaskType.TEXT_EMBEDDING,
302-
getRequestConfigMap(getServiceSettingsMap("model", "url"), getSecretSettingsMap("secret")),
290+
getRequestConfigMap(getServiceSettingsMap(MODEL_ID, URL), getSecretSettingsMap(API_KEY)),
303291
modelVerificationActionListener
304292
);
305293
}
306294
}
307295

308296
public void testParseRequestConfig_WithoutModelId_Success() throws IOException {
309-
var url = "url";
310-
var secret = "secret";
311-
312297
try (var service = createService()) {
313298
ActionListener<Model> modelVerificationListener = ActionListener.wrap(m -> {
314299
assertThat(m, instanceOf(OpenShiftAiChatCompletionModel.class));
315300

316301
var chatCompletionModel = (OpenShiftAiChatCompletionModel) m;
317302

318-
assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(url));
303+
assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is(URL));
319304
assertNull(chatCompletionModel.getServiceSettings().modelId());
320-
assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret"));
305+
assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is(API_KEY));
321306

322307
}, e -> fail("parse request should not fail " + e.getMessage()));
323308

324309
service.parseRequestConfig(
325-
"id",
310+
INFERENCE_ID,
326311
TaskType.CHAT_COMPLETION,
327-
getRequestConfigMap(getServiceSettingsMap(null, url), getSecretSettingsMap(secret)),
312+
getRequestConfigMap(getServiceSettingsMap(null, URL), getSecretSettingsMap(API_KEY)),
328313
modelVerificationListener
329314
);
330315
}
331316
}
332317

333318
public void testParseRequestConfig_WithoutUrl_ThrowsException() throws IOException {
334-
var model = "model";
335-
var secret = "secret";
336-
337319
try (var service = createService()) {
338320
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
339321
m -> fail("Expected exception, but got model: " + m),
@@ -347,9 +329,9 @@ public void testParseRequestConfig_WithoutUrl_ThrowsException() throws IOExcepti
347329
);
348330

349331
service.parseRequestConfig(
350-
"id",
332+
INFERENCE_ID,
351333
TaskType.CHAT_COMPLETION,
352-
getRequestConfigMap(getServiceSettingsMap(model, null), getSecretSettingsMap(secret)),
334+
getRequestConfigMap(getServiceSettingsMap(MODEL_ID, null), getSecretSettingsMap(API_KEY)),
353335
modelVerificationListener
354336
);
355337
}
@@ -386,12 +368,14 @@ public void testUnifiedCompletionInfer() throws Exception {
386368

387369
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
388370
try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
389-
var model = createChatCompletionModel(getUrl(webServer), "secret", "model");
371+
var model = createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID);
390372
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
391373
service.unifiedCompletionInfer(
392374
model,
393375
UnifiedCompletionRequest.of(
394-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
376+
List.of(
377+
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null)
378+
)
395379
),
396380
InferenceAction.Request.DEFAULT_TIMEOUT,
397381
listener
@@ -426,12 +410,14 @@ public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception {
426410

427411
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
428412
try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
429-
var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
413+
var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID);
430414
var latch = new CountDownLatch(1);
431415
service.unifiedCompletionInfer(
432416
model,
433417
UnifiedCompletionRequest.of(
434-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
418+
List.of(
419+
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null)
420+
)
435421
),
436422
InferenceAction.Request.DEFAULT_TIMEOUT,
437423
ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
@@ -510,12 +496,14 @@ public void testInfer_StreamRequest() throws Exception {
510496
private void testStreamError(String expectedResponse) throws Exception {
511497
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
512498
try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
513-
var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model");
499+
var model = OpenShiftAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), API_KEY, MODEL_ID);
514500
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
515501
service.unifiedCompletionInfer(
516502
model,
517503
UnifiedCompletionRequest.of(
518-
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
504+
List.of(
505+
new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), USER_ROLE, null, null)
506+
)
519507
),
520508
InferenceAction.Request.DEFAULT_TIMEOUT,
521509
listener
@@ -597,7 +585,7 @@ public void testSupportsStreaming() throws IOException {
597585

598586
public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException {
599587
try (var service = createService()) {
600-
var secretSettings = getSecretSettingsMap("secret");
588+
var secretSettings = getSecretSettingsMap(API_KEY);
601589
secretSettings.put("extra_key", "value");
602590

603591
var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(), secretSettings);
@@ -613,21 +601,21 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe
613601
}
614602
);
615603

616-
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
604+
service.parseRequestConfig(INFERENCE_ID, TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
617605
}
618606
}
619607

620608
public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
621-
var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), "api_key", "model", 1234, false, 1536, null);
609+
var model = OpenShiftAiEmbeddingsModelTests.createModel(getUrl(webServer), API_KEY, MODEL_ID, 1234, false, 1536, null);
622610

623611
testChunkedInfer(model);
624612
}
625613

626614
public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
627615
var model = OpenShiftAiEmbeddingsModelTests.createModel(
628616
getUrl(webServer),
629-
"api_key",
630-
"model",
617+
API_KEY,
618+
MODEL_ID,
631619
1234,
632620
false,
633621
1536,
@@ -691,7 +679,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio
691679

692680
assertThat(results, hasSize(2));
693681
{
694-
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
682+
assertThat(results.get(0), Matchers.instanceOf(ChunkedInferenceEmbedding.class));
695683
var floatResult = (ChunkedInferenceEmbedding) results.get(0);
696684
assertThat(floatResult.chunks(), hasSize(1));
697685
assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
@@ -703,7 +691,7 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio
703691
);
704692
}
705693
{
706-
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
694+
assertThat(results.get(1), Matchers.instanceOf(ChunkedInferenceEmbedding.class));
707695
var floatResult = (ChunkedInferenceEmbedding) results.get(1);
708696
assertThat(floatResult.chunks(), hasSize(1));
709697
assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(DenseEmbeddingFloatResults.Embedding.class));
@@ -721,12 +709,12 @@ public void testChunkedInfer(OpenShiftAiEmbeddingsModel model) throws IOExceptio
721709
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
722710
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
723711
);
724-
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer api_key"));
712+
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
725713

726714
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
727-
assertThat(requestMap.size(), Matchers.is(2));
728-
assertThat(requestMap.get("input"), Matchers.is(List.of("abc", "def")));
729-
assertThat(requestMap.get("model"), Matchers.is("model"));
715+
assertThat(requestMap.size(), is(2));
716+
assertThat(requestMap.get("input"), is(List.of("abc", "def")));
717+
assertThat(requestMap.get("model"), is(MODEL_ID));
730718
}
731719
}
732720

@@ -795,7 +783,7 @@ public void testGetConfiguration() throws Exception {
795783
private InferenceEventsAssertion streamCompletion() throws Exception {
796784
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
797785
try (var service = new OpenShiftAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
798-
var model = OpenShiftAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model");
786+
var model = OpenShiftAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), API_KEY, MODEL_ID);
799787
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
800788
service.infer(
801789
model,
@@ -842,7 +830,7 @@ private Map<String, Object> getRequestConfigMap(Map<String, Object> serviceSetti
842830
}
843831

844832
private static Map<String, Object> getEmbeddingsServiceSettingsMap() {
845-
return buildServiceSettingsMap("id", "url", SimilarityMeasure.COSINE.toString(), null, null, null);
833+
return buildServiceSettingsMap(INFERENCE_ID, URL, SimilarityMeasure.COSINE.toString(), null, null, null);
846834
}
847835

848836
@Override

0 commit comments

Comments
 (0)