1919import org .elasticsearch .threadpool .ThreadPool ;
2020import org .elasticsearch .xcontent .XContentType ;
2121import org .elasticsearch .xpack .core .inference .action .InferenceAction ;
22- import org .elasticsearch .xpack .inference .external .action .cohere . CohereActionCreator ;
22+ import org .elasticsearch .xpack .inference .external .action .voyageai . VoyageAIActionCreator ;
2323import org .elasticsearch .xpack .inference .external .http .HttpClientManager ;
24- import org .elasticsearch .xpack .inference .external .http .sender .ChatCompletionInput ;
2524import org .elasticsearch .xpack .inference .external .http .sender .DocumentsOnlyInput ;
2625import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSenderTests ;
2726import org .elasticsearch .xpack .inference .logging .ThrottlerManager ;
28- import org .elasticsearch .xpack .inference .services .cohere .CohereTruncation ;
29- import org .elasticsearch .xpack .inference .services .cohere .completion .CohereCompletionModelTests ;
30- import org .elasticsearch .xpack .inference .services .cohere .embeddings .CohereEmbeddingType ;
31- import org .elasticsearch .xpack .inference .services .cohere .embeddings .CohereEmbeddingsModelTests ;
32- import org .elasticsearch .xpack .inference .services .cohere .embeddings .CohereEmbeddingsTaskSettings ;
33- import org .elasticsearch .xpack .inference .services .cohere .embeddings .CohereEmbeddingsTaskSettingsTests ;
27+ import org .elasticsearch .xpack .inference .services .voyageai .embeddings .VoyageAIEmbeddingType ;
28+ import org .elasticsearch .xpack .inference .services .voyageai .embeddings .VoyageAIEmbeddingsModelTests ;
29+ import org .elasticsearch .xpack .inference .services .voyageai .embeddings .VoyageAIEmbeddingsTaskSettings ;
30+ import org .elasticsearch .xpack .inference .services .voyageai .embeddings .VoyageAIEmbeddingsTaskSettingsTests ;
3431import org .hamcrest .MatcherAssert ;
3532import org .junit .After ;
3633import org .junit .Before ;
4542import static org .elasticsearch .xpack .inference .external .http .Utils .entityAsMap ;
4643import static org .elasticsearch .xpack .inference .external .http .Utils .getUrl ;
4744import static org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSenderTests .createSender ;
48- import static org .elasticsearch .xpack .inference .results .ChatCompletionResultsTests .buildExpectationCompletion ;
4945import static org .elasticsearch .xpack .inference .results .TextEmbeddingResultsTests .buildExpectationFloat ;
5046import static org .elasticsearch .xpack .inference .services .ServiceComponentsTests .createWithEmptySettings ;
5147import static org .hamcrest .Matchers .is ;
@@ -73,50 +69,42 @@ public void shutdown() throws IOException {
7369 webServer .close ();
7470 }
7571
76- public void testCreate_CohereEmbeddingsModel () throws IOException {
72+ public void testCreate_VoyageAIEmbeddingsModel () throws IOException {
7773 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
7874
7975 try (var sender = createSender (senderFactory )) {
8076 sender .start ();
8177
8278 String responseJson = """
8379 {
84- "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
85- "texts": [
86- "hello"
87- ],
88- "embeddings": {
89- "float": [
90- [
80+ "object": "list",
81+ "data": [{
82+ "object": "embedding",
83+ "embedding": [
9184 0.123,
9285 -0.123
93- ]
94- ]
95- },
96- "meta": {
97- "api_version": {
98- "version": "1"
99- },
100- "billed_units": {
101- "input_tokens": 1
102- }
103- },
104- "response_type": "embeddings_by_type"
86+ ],
87+ "index": 0
88+ }],
89+ "model": "voyage-3-large",
90+ "usage": {
91+ "total_tokens": 123
92+ }
10593 }
10694 """ ;
10795 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
10896
109- var model = CohereEmbeddingsModelTests .createModel (
97+ var model = VoyageAIEmbeddingsModelTests .createModel (
11098 getUrl (webServer ),
11199 "secret" ,
112- new CohereEmbeddingsTaskSettings (InputType .INGEST , CohereTruncation . START ),
100+ new VoyageAIEmbeddingsTaskSettings (InputType .INGEST , true ),
113101 1024 ,
114102 1024 ,
115103 "model" ,
116- CohereEmbeddingType .FLOAT
104+ VoyageAIEmbeddingType .FLOAT
117105 );
118- var actionCreator = new CohereActionCreator (sender , createWithEmptySettings (threadPool ));
119- var overriddenTaskSettings = CohereEmbeddingsTaskSettingsTests .getTaskSettingsMap (InputType .SEARCH , CohereTruncation . END );
106+ var actionCreator = new VoyageAIActionCreator (sender , createWithEmptySettings (threadPool ));
107+ var overriddenTaskSettings = VoyageAIEmbeddingsTaskSettingsTests .getTaskSettingsMap (InputType .SEARCH );
120108 var action = actionCreator .create (model , overriddenTaskSettings , InputType .UNSPECIFIED );
121109
122110 PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
@@ -138,139 +126,15 @@ public void testCreate_CohereEmbeddingsModel() throws IOException {
138126 requestMap ,
139127 is (
140128 Map .of (
141- "texts" ,
142- List .of ("abc" ),
143- "model" ,
144- "model" ,
145- "input_type" ,
146- "search_query" ,
147- "embedding_types" ,
148- List .of ("float" ),
149- "truncate" ,
150- "end"
129+ "output_dtype" ,"float" ,
130+ "truncation" , true ,
131+ "input_type" , "query" ,
132+ "output_dimension" ,1024 ,
133+ "input" , List .of ("abc" ),
134+ "model" , "model"
151135 )
152136 )
153137 );
154138 }
155139 }
156-
157- public void testCreate_CohereCompletionModel_WithModelSpecified () throws IOException {
158- var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
159-
160- try (var sender = createSender (senderFactory )) {
161- sender .start ();
162-
163- String responseJson = """
164- {
165- "response_id": "some id",
166- "text": "result",
167- "generation_id": "some id",
168- "chat_history": [
169- {
170- "role": "USER",
171- "message": "input"
172- },
173- {
174- "role": "CHATBOT",
175- "message": "result"
176- }
177- ],
178- "finish_reason": "COMPLETE",
179- "meta": {
180- "api_version": {
181- "version": "1"
182- },
183- "billed_units": {
184- "input_tokens": 4,
185- "output_tokens": 191
186- },
187- "tokens": {
188- "input_tokens": 70,
189- "output_tokens": 191
190- }
191- }
192- }
193- """ ;
194-
195- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
196-
197- var model = CohereCompletionModelTests .createModel (getUrl (webServer ), "secret" , "model" );
198- var actionCreator = new CohereActionCreator (sender , createWithEmptySettings (threadPool ));
199- var action = actionCreator .create (model , Map .of ());
200-
201- PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
202- action .execute (new ChatCompletionInput (List .of ("abc" )), InferenceAction .Request .DEFAULT_TIMEOUT , listener );
203-
204- var result = listener .actionGet (TIMEOUT );
205-
206- assertThat (result .asMap (), is (buildExpectationCompletion (List .of ("result" ))));
207- assertThat (webServer .requests (), hasSize (1 ));
208- assertNull (webServer .requests ().get (0 ).getUri ().getQuery ());
209- assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .CONTENT_TYPE ), is (XContentType .JSON .mediaType ()));
210- assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .AUTHORIZATION ), is ("Bearer secret" ));
211-
212- var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
213- assertThat (requestMap , is (Map .of ("message" , "abc" , "model" , "model" )));
214- }
215- }
216-
217- public void testCreate_CohereCompletionModel_WithoutModelSpecified () throws IOException {
218- var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
219-
220- try (var sender = createSender (senderFactory )) {
221- sender .start ();
222-
223- String responseJson = """
224- {
225- "response_id": "some id",
226- "text": "result",
227- "generation_id": "some id",
228- "chat_history": [
229- {
230- "role": "USER",
231- "message": "input"
232- },
233- {
234- "role": "CHATBOT",
235- "message": "result"
236- }
237- ],
238- "finish_reason": "COMPLETE",
239- "meta": {
240- "api_version": {
241- "version": "1"
242- },
243- "billed_units": {
244- "input_tokens": 4,
245- "output_tokens": 191
246- },
247- "tokens": {
248- "input_tokens": 70,
249- "output_tokens": 191
250- }
251- }
252- }
253- """ ;
254-
255- webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
256-
257- var model = CohereCompletionModelTests .createModel (getUrl (webServer ), "secret" , null );
258- var actionCreator = new CohereActionCreator (sender , createWithEmptySettings (threadPool ));
259- var action = actionCreator .create (model , Map .of ());
260-
261- PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
262- action .execute (new ChatCompletionInput (List .of ("abc" )), InferenceAction .Request .DEFAULT_TIMEOUT , listener );
263-
264- var result = listener .actionGet (TIMEOUT );
265-
266- assertThat (result .asMap (), is (buildExpectationCompletion (List .of ("result" ))));
267- assertThat (webServer .requests (), hasSize (1 ));
268- assertNull (webServer .requests ().get (0 ).getUri ().getQuery ());
269- assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .CONTENT_TYPE ), is (XContentType .JSON .mediaType ()));
270- assertThat (webServer .requests ().get (0 ).getHeader (HttpHeaders .AUTHORIZATION ), is ("Bearer secret" ));
271-
272- var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
273- assertThat (requestMap , is (Map .of ("message" , "abc" )));
274- }
275- }
276140}
0 commit comments