@@ -77,15 +77,51 @@ public void shutdown() throws IOException {
7777 webServer .close ();
7878 }
7979
80- public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction () throws IOException {
81- testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction ("overridden_user" );
80+ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithUser_WithDimensions_DimensionsSetByUserFalse ()
81+ throws IOException {
82+ testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction ("overridden_user" , 384 , false , null );
8283 }
8384
84- public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithoutUser () throws IOException {
85- testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction (null );
85+ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_NoUser_WithDimensions_DimensionsSetByUserFalse ()
86+ throws IOException {
87+ testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction (null , 384 , false , null );
8688 }
8789
88- private void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction (String user ) throws IOException {
90+ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithUser_NoDimensions_DimensionsSetByUserFalse ()
91+ throws IOException {
92+ testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction ("overridden_user" , null , false , null );
93+ }
94+
95+ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_NoUser_NoDimensions_DimensionsSetByUserFalse ()
96+ throws IOException {
97+ testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction (null , null , false , null );
98+ }
99+
100+ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithUser_WithDimensions_DimensionsSetByUserTrue ()
101+ throws IOException {
102+ testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction ("overridden_user" , 384 , true , 384 );
103+ }
104+
105+ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_NoUser_WithDimensions_DimensionsSetByUserTrue ()
106+ throws IOException {
107+ testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction (null , 384 , true , 384 );
108+ }
109+
110+ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_WithUser_NoDimensions_DimensionsSetByUserTrue ()
111+ throws IOException {
112+ testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction ("overridden_user" , null , true , null );
113+ }
114+
115+ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction_NoUser_NoDimensions_DimensionsSetByUserTrue () throws IOException {
116+ testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction (null , null , true , null );
117+ }
118+
119+ private void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction (
120+ String user ,
121+ Integer dimensions ,
122+ boolean dimensionsSetByUser ,
123+ Integer expectedDimensions
124+ ) throws IOException {
89125 var senderFactory = HttpRequestSenderTests .createSenderFactory (threadPool , clientManager );
90126
91127 try (var sender = createSender (senderFactory )) {
@@ -113,25 +149,68 @@ private void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction(String us
113149 """ ;
114150 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
115151
116- PlainActionFuture <InferenceServiceResults > listener = createEmbeddingsFuture (sender , createWithEmptySettings (threadPool ), user );
152+ PlainActionFuture <InferenceServiceResults > listener = createEmbeddingsFuture (
153+ sender ,
154+ createWithEmptySettings (threadPool ),
155+ user ,
156+ dimensions ,
157+ dimensionsSetByUser
158+ );
117159
118160 var result = listener .actionGet (TIMEOUT );
119161
120162 assertThat (result .asMap (), is (TextEmbeddingFloatResultsTests .buildExpectationFloat (List .of (new float [] { -0.123F , 0.123F }))));
121163
122- assertEmbeddingsRequest (user );
164+ assertEmbeddingsRequest (user , expectedDimensions );
123165 }
124166 }
125167
126- public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction () throws IOException {
127- testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction ("overridden_user" );
168+ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithUser_WithDimensions_DimensionsSetByUserFalse ()
169+ throws IOException {
170+ testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction ("overridden_user" , 384 , false , null );
171+ }
172+
173+ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_NoUser_WithDimensions_DimensionsSetByUserFalse ()
174+ throws IOException {
175+ testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction (null , 384 , false , null );
128176 }
129177
130- public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithoutUser () throws IOException {
131- testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction (null );
178+ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithUser_NoDimensions_DimensionsSetByUserFalse ()
179+ throws IOException {
180+ testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction ("overridden_user" , null , false , null );
132181 }
133182
134- private void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction (String user ) throws IOException {
183+ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_NoUser_NoDimensions_DimensionsSetByUserFalse ()
184+ throws IOException {
185+ testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction (null , null , false , null );
186+ }
187+
188+ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithUser_WithDimensions_DimensionsSetByUserTrue ()
189+ throws IOException {
190+ testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction ("overridden_user" , 384 , true , 384 );
191+ }
192+
193+ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_NoUser_WithDimensions_DimensionsSetByUserTrue ()
194+ throws IOException {
195+ testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction (null , 384 , true , 384 );
196+ }
197+
198+ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_WithUser_NoDimensions_DimensionsSetByUserTrue ()
199+ throws IOException {
200+ testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction ("overridden_user" , null , true , null );
201+ }
202+
203+ public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction_NoUser_NoDimensions_DimensionsSetByUserTrue ()
204+ throws IOException {
205+ testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction (null , null , true , null );
206+ }
207+
208+ private void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction (
209+ String user ,
210+ Integer dimensions ,
211+ boolean dimensionsSetByUser ,
212+ Integer expectedDimensions
213+ ) throws IOException {
135214 var settings = buildSettingsWithRetryFields (
136215 TimeValue .timeValueMillis (1 ),
137216 TimeValue .timeValueMinutes (1 ),
@@ -149,15 +228,21 @@ private void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction(Stri
149228 """ ;
150229 webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
151230
152- PlainActionFuture <InferenceServiceResults > listener = createEmbeddingsFuture (sender , createWithEmptySettings (threadPool ), user );
231+ PlainActionFuture <InferenceServiceResults > listener = createEmbeddingsFuture (
232+ sender ,
233+ createWithEmptySettings (threadPool ),
234+ user ,
235+ dimensions ,
236+ dimensionsSetByUser
237+ );
153238
154239 var thrownException = expectThrows (ElasticsearchException .class , () -> listener .actionGet (TIMEOUT ));
155240 assertThat (
156241 thrownException .getMessage (),
157242 is ("Failed to send Llama text_embedding request from inference entity id [id]. Cause: Required [data]" )
158243 );
159244
160- assertEmbeddingsRequest (user );
245+ assertEmbeddingsRequest (user , expectedDimensions );
161246 }
162247 }
163248
@@ -262,8 +347,21 @@ private void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction(Stri
262347 }
263348 }
264349
265- private PlainActionFuture <InferenceServiceResults > createEmbeddingsFuture (Sender sender , ServiceComponents threadPool , String user ) {
266- var model = LlamaEmbeddingsModelTests .createEmbeddingsModel ("model" , getUrl (webServer ), "secret" , user );
350+ private PlainActionFuture <InferenceServiceResults > createEmbeddingsFuture (
351+ Sender sender ,
352+ ServiceComponents threadPool ,
353+ String user ,
354+ Integer dimensions ,
355+ boolean dimensionsSetByUser
356+ ) {
357+ var model = LlamaEmbeddingsModelTests .createEmbeddingsModel (
358+ "model" ,
359+ getUrl (webServer ),
360+ "secret" ,
361+ user ,
362+ dimensions ,
363+ dimensionsSetByUser
364+ );
267365 var actionCreator = new LlamaActionCreator (sender , threadPool );
268366 var overriddenTaskSettings = createRequestTaskSettingsMap (user );
269367 var action = actionCreator .create (model , overriddenTaskSettings );
@@ -305,19 +403,15 @@ private void assertCompletionRequest(String user) throws IOException {
305403 }
306404
307405 @ SuppressWarnings ("unchecked" )
308- private void assertEmbeddingsRequest (String user ) throws IOException {
406+ private void assertEmbeddingsRequest (String user , Integer dimensions ) throws IOException {
309407 assertCommonRequestProperties ();
310408
311409 var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
312- if (user == null ) {
313- assertThat (requestMap .size (), is (2 ));
314- } else {
315- assertThat (requestMap .size (), is (3 ));
316- }
317410 assertThat (requestMap .get ("input" ), instanceOf (List .class ));
318411 var inputList = (List <String >) requestMap .get ("input" );
319412 assertThat (inputList , contains ("abc" ));
320413 assertThat (requestMap .get ("user" ), is (user ));
414+ assertThat (requestMap .get ("dimensions" ), is (dimensions ));
321415 }
322416
323417 private void assertCommonRequestProperties () {
0 commit comments