Skip to content

Commit 30316f9

Browse files
committed
Adds UT for proving models depend on xContentRegistry for prediction
Signed-off-by: Brian Flores <[email protected]>
1 parent 67c562a commit 30316f9

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.opensearch.core.action.ActionListener;
3232
import org.opensearch.core.xcontent.NamedXContentRegistry;
3333
import org.opensearch.ingest.IngestDocument;
34+
import org.opensearch.ml.common.FunctionName;
3435
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
3536
import org.opensearch.ml.common.input.MLInput;
3637
import org.opensearch.ml.common.output.model.MLResultDataType;
@@ -138,6 +139,57 @@ public void testExecute_Exception() throws Exception {
138139

139140
}
140141

142+
/**
143+
* Models that use the parameters field need to have a valid NamedXContentRegistry object to create valid MLInputs. For example
144+
* <pre>
145+
* PUT /_plugins/_ml/_predict/text_embedding/model_id
146+
* {
147+
* "parameters": {
148+
* "content_type" : "query"
149+
* },
150+
* "text_docs" : ["what day is it today?"],
151+
* "target_response" : ["sentence_embedding"]
152+
* }
153+
* </pre>
154+
* These types of models like Local Asymmetric embedding models use the parameters field.
155+
* And as such we need to test that having the contentRegistry throws an exception as it can not
156+
* properly create a valid MLInput to perform prediction
157+
*
158+
* @implNote If you check the stack trace of the test you will see it tells you that it's a direct consequence of xContentRegistry being null
159+
*/
160+
public void testExecute_xContentRegistryNullWithLocalModel_throwsException() throws Exception {
161+
// Set the registry to null and reset after exiting the test
162+
xContentRegistry = null;
163+
164+
String localModelInput =
165+
"{ \"text_docs\": [\"What day is it today?\"],\"target_response\": [\"sentence_embedding\"], \"parameters\": { \"contentType\" : \"query\"} }";
166+
167+
MLInferenceIngestProcessor processor = createMLInferenceProcessor(
168+
"local_model_id",
169+
null,
170+
null,
171+
null,
172+
false,
173+
FunctionName.TEXT_EMBEDDING.toString(),
174+
false,
175+
false,
176+
false,
177+
localModelInput
178+
);
179+
try {
180+
String npeMessage =
181+
"Cannot invoke \"org.opensearch.ml.common.input.MLInput.setAlgorithm(org.opensearch.ml.common.FunctionName)\" because \"mlInput\" is null";
182+
183+
processor.execute(ingestDocument, handler);
184+
verify(handler)
185+
.accept(isNull(), argThat(exception -> exception instanceof NullPointerException && exception.getMessage().equals(npeMessage)));
186+
} catch (Exception e) {
187+
assertEquals("this catch block should not get executed.", e.getMessage());
188+
}
189+
// reset to mocked object
190+
xContentRegistry = mock(NamedXContentRegistry.class);
191+
}
192+
141193
/**
142194
* test nested object document with array of Map<String,String>
143195
*/

0 commit comments

Comments
 (0)