Skip to content

Commit 52a51bb

Browse files
author
Max Hniebergall
committed
Change to action and tests
1 parent b5d6fa0 commit 52a51bb

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,15 @@ public static class Request extends ActionRequest {
5959
public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30);
6060
public static final ParseField INPUT = new ParseField("input");
6161
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
62+
public static final ParseField PARAMETERS = new ParseField("parameters");
6263
public static final ParseField QUERY = new ParseField("query");
6364
public static final ParseField TIMEOUT = new ParseField("timeout");
6465

6566
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
6667
static {
6768
PARSER.declareStringArray(Request.Builder::setInput, INPUT);
6869
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
70+
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), PARAMETERS);
6971
PARSER.declareString(Request.Builder::setQuery, QUERY);
7072
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
7173
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,77 @@ public void testParsing() throws IOException {
7373
}
7474
}
7575

76+
public void testParsingWithTaskSettings() throws IOException {
77+
String requestText = """
78+
{
79+
"input": "single text input",
80+
"task_settings": {
81+
"foo": "bar"
82+
}
83+
}
84+
""";
85+
try (var parser = createParser(JsonXContent.jsonXContent, requestText)) {
86+
var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser).build();
87+
assertThat(request.getInput(), contains("single text input"));
88+
assertThat(request.getTaskSettings(), is(Map.of("foo", "bar")));
89+
}
90+
}
91+
92+
public void testParsingWithParameters() throws IOException {
93+
String requestText = """
94+
{
95+
"input": "single text input",
96+
"parameters": {
97+
"foo": "bar"
98+
}
99+
}
100+
""";
101+
try (var parser = createParser(JsonXContent.jsonXContent, requestText)) {
102+
var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser).build();
103+
assertThat(request.getInput(), contains("single text input"));
104+
assertThat(request.getTaskSettings(), is(Map.of("foo", "bar")));
105+
}
106+
}
107+
108+
public void testParsingWithTaskSettingsAndParameters() throws IOException {
109+
{
110+
String singleInputRequest = """
111+
{
112+
"input": "single text input",
113+
"parameters": {
114+
"foo": "bar"
115+
},
116+
"task_settings": {
117+
"food": "bard"
118+
}
119+
}
120+
""";
121+
try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) {
122+
var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser).build();
123+
assertThat(request.getInput(), contains("single text input"));
124+
assertThat(request.getTaskSettings(), is(Map.of("food", "bard")));
125+
}
126+
}
127+
{
128+
String singleInputRequest = """
129+
{
130+
"input": "single text input",
131+
"task_settings": {
132+
"food": "bard"
133+
},
134+
"parameters": {
135+
"foo": "bar"
136+
}
137+
}
138+
""";
139+
try (var parser = createParser(JsonXContent.jsonXContent, singleInputRequest)) {
140+
var request = InferenceAction.Request.parseRequest("model_id", TaskType.SPARSE_EMBEDDING, parser).build();
141+
assertThat(request.getInput(), contains("single text input"));
142+
assertThat(request.getTaskSettings(), is(Map.of("foo", "bar")));
143+
}
144+
}
145+
}
146+
76147
public void testValidation_TextEmbedding() {
77148
InferenceAction.Request request = new InferenceAction.Request(
78149
TaskType.TEXT_EMBEDDING,

0 commit comments

Comments
 (0)