Skip to content

Commit a33775d

Browse files
[ML] Implement JSONPath replacement for Inference API (#127036) (#127072)
* Adding initial extractor * Finishing tests * Addressing feedback
1 parent 16cd8a2 commit a33775d

File tree

2 files changed

+396
-0
lines changed

2 files changed

+396
-0
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.common;
9+
10+
import org.elasticsearch.common.Strings;
11+
12+
import java.util.ArrayList;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.regex.Pattern;
16+
17+
/**
18+
* Extracts fields from a {@link Map}.
19+
*
20+
* Uses a subset of the JSONPath schema to extract fields from a map.
21+
* For more information <a href="https://en.wikipedia.org/wiki/JSONPath">see here</a>.
22+
*
23+
* This implementation differs in how it handles lists in that JSONPath will flatten inner lists. This implementation
24+
* preserves inner lists.
25+
*
26+
* Examples of the schema:
27+
*
28+
* <pre>
29+
* {@code
30+
* $.field1.array[*].field2
31+
* $.field1.field2
32+
* }
33+
* </pre>
34+
*
35+
* Given the map
36+
* <pre>
37+
* {@code
38+
* {
39+
* "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4",
40+
* "latency": 38,
41+
* "usage": {
42+
* "token_count": 3072
43+
* },
44+
* "result": {
45+
* "embeddings": [
46+
* {
47+
* "index": 0,
48+
* "embedding": [
49+
* 2,
50+
* 4
51+
* ]
52+
* },
53+
* {
54+
* "index": 1,
55+
* "embedding": [
56+
* 1,
57+
* 2
58+
* ]
59+
* }
60+
* ]
61+
* }
62+
* }
63+
* }
64+
* </pre>
65+
*
66+
* <pre>
67+
* {@code
68+
* var embeddings = MapPathExtractor.extract(map, "$.result.embeddings[*].embedding");
69+
* }
70+
* </pre>
71+
*
72+
* Will result in:
73+
*
74+
* <pre>
75+
* {@code
76+
* [
77+
* [2, 4],
78+
* [1, 2]
79+
* ]
80+
* }
81+
* </pre>
82+
*
83+
* This implementation differs from JSONPath when handling a list of maps. JSONPath will flatten the result and return a single array.
84+
* this implementation will preserve each nested list while gathering the results.
85+
*
86+
* For example
87+
*
88+
* <pre>
89+
* {@code
90+
* {
91+
* "result": [
92+
* {
93+
* "key": [
94+
* {
95+
* "a": 1.1
96+
* },
97+
* {
98+
* "a": 2.2
99+
* }
100+
* ]
101+
* },
102+
* {
103+
* "key": [
104+
* {
105+
* "a": 3.3
106+
* },
107+
* {
108+
* "a": 4.4
109+
* }
110+
* ]
111+
* }
112+
* ]
113+
* }
114+
* }
115+
* {@code var embeddings = MapPathExtractor.extract(map, "$.result[*].key[*].a");}
116+
*
117+
* JSONPath: {@code [1.1, 2.2, 3.3, 4.4]}
118+
* This implementation: {@code [[1.1, 2.2], [3.3, 4.4]]}
119+
* </pre>
120+
*/
121+
public class MapPathExtractor {
122+
123+
private static final String DOLLAR = "$";
124+
125+
// default for testing
126+
static final Pattern dotFieldPattern = Pattern.compile("^\\.([^.\\[]+)(.*)");
127+
static final Pattern arrayWildcardPattern = Pattern.compile("^\\[\\*\\](.*)");
128+
129+
public static Object extract(Map<String, Object> data, String path) {
130+
if (data == null || data.isEmpty() || path == null || path.trim().isEmpty()) {
131+
return null;
132+
}
133+
134+
var cleanedPath = path.trim();
135+
136+
if (cleanedPath.startsWith(DOLLAR)) {
137+
cleanedPath = cleanedPath.substring(DOLLAR.length());
138+
} else {
139+
throw new IllegalArgumentException(Strings.format("Path [%s] must start with a dollar sign ($)", cleanedPath));
140+
}
141+
142+
return navigate(data, cleanedPath);
143+
}
144+
145+
private static Object navigate(Object current, String remainingPath) {
146+
if (current == null || remainingPath == null || remainingPath.isEmpty()) {
147+
return current;
148+
}
149+
150+
var dotFieldMatcher = dotFieldPattern.matcher(remainingPath);
151+
var arrayWildcardMatcher = arrayWildcardPattern.matcher(remainingPath);
152+
153+
if (dotFieldMatcher.matches()) {
154+
String field = dotFieldMatcher.group(1);
155+
if (field == null || field.isEmpty()) {
156+
throw new IllegalArgumentException(
157+
Strings.format(
158+
"Unable to extract field from remaining path [%s]. Fields must be delimited by a dot character.",
159+
remainingPath
160+
)
161+
);
162+
}
163+
164+
String nextPath = dotFieldMatcher.group(2);
165+
if (current instanceof Map<?, ?> currentMap) {
166+
var fieldFromMap = currentMap.get(field);
167+
if (fieldFromMap == null) {
168+
throw new IllegalArgumentException(Strings.format("Unable to find field [%s] in map", field));
169+
}
170+
171+
return navigate(currentMap.get(field), nextPath);
172+
} else {
173+
throw new IllegalArgumentException(
174+
Strings.format(
175+
"Current path [%s] matched the dot field pattern but the current object is not a map, "
176+
+ "found invalid type [%s] instead.",
177+
remainingPath,
178+
current.getClass().getSimpleName()
179+
)
180+
);
181+
}
182+
} else if (arrayWildcardMatcher.matches()) {
183+
String nextPath = arrayWildcardMatcher.group(1);
184+
if (current instanceof List<?> list) {
185+
List<Object> results = new ArrayList<>();
186+
187+
for (Object item : list) {
188+
Object result = navigate(item, nextPath);
189+
if (result != null) {
190+
results.add(result);
191+
}
192+
}
193+
194+
return results;
195+
} else {
196+
throw new IllegalArgumentException(
197+
Strings.format(
198+
"Current path [%s] matched the array field pattern but the current object is not a list, "
199+
+ "found invalid type [%s] instead.",
200+
remainingPath,
201+
current.getClass().getSimpleName()
202+
)
203+
);
204+
}
205+
}
206+
207+
throw new IllegalArgumentException(Strings.format("Invalid path received [%s], unable to extract a field name.", remainingPath));
208+
}
209+
}
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.common;
9+
10+
import org.elasticsearch.test.ESTestCase;
11+
12+
import java.util.List;
13+
import java.util.Map;
14+
15+
import static org.hamcrest.Matchers.is;
16+
17+
public class MapPathExtractorTests extends ESTestCase {
18+
public void testExtract_RetrievesListOfLists() {
19+
Map<String, Object> input = Map.of(
20+
"result",
21+
Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4))))
22+
);
23+
24+
assertThat(MapPathExtractor.extract(input, "$.result.embeddings[*].embedding"), is(List.of(List.of(1, 2), List.of(3, 4))));
25+
}
26+
27+
public void testExtract_IteratesListOfMapsToListOfStrings() {
28+
Map<String, Object> input = Map.of(
29+
"result",
30+
List.of(Map.of("key", List.of("value1", "value2")), Map.of("key", List.of("value3", "value4")))
31+
);
32+
33+
assertThat(
34+
MapPathExtractor.extract(input, "$.result[*].key[*]"),
35+
is(List.of(List.of("value1", "value2"), List.of("value3", "value4")))
36+
);
37+
}
38+
39+
public void testExtract_IteratesListOfMapsToListOfMapsOfStringToDoubles() {
40+
Map<String, Object> input = Map.of(
41+
"result",
42+
List.of(
43+
Map.of("key", List.of(Map.of("a", 1.1d), Map.of("a", 2.2d))),
44+
Map.of("key", List.of(Map.of("a", 3.3d), Map.of("a", 4.4d)))
45+
)
46+
);
47+
48+
assertThat(MapPathExtractor.extract(input, "$.result[*].key[*].a"), is(List.of(List.of(1.1d, 2.2d), List.of(3.3d, 4.4d))));
49+
}
50+
51+
public void testExtract_ReturnsNullForEmptyList() {
52+
Map<String, Object> input = Map.of();
53+
54+
assertNull(MapPathExtractor.extract(input, "$.awesome"));
55+
}
56+
57+
public void testExtract_ReturnsNull_WhenTheInputMapIsNull() {
58+
assertNull(MapPathExtractor.extract(null, "$.result"));
59+
}
60+
61+
public void testExtract_ReturnsNull_WhenPathIsNull() {
62+
assertNull(MapPathExtractor.extract(Map.of("key", "value"), null));
63+
}
64+
65+
public void testExtract_ReturnsNull_WhenPathIsWhiteSpace() {
66+
assertNull(MapPathExtractor.extract(Map.of("key", "value"), " "));
67+
}
68+
69+
public void testExtract_ThrowsException_WhenPathDoesNotStartWithDollarSign() {
70+
var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(Map.of("key", "value"), ".key"));
71+
assertThat(exception.getMessage(), is("Path [.key] must start with a dollar sign ($)"));
72+
}
73+
74+
public void testExtract_ThrowsException_WhenCannotFindField() {
75+
Map<String, Object> input = Map.of("result", "key");
76+
77+
var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(input, "$.awesome"));
78+
assertThat(exception.getMessage(), is("Unable to find field [awesome] in map"));
79+
}
80+
81+
public void testExtract_ThrowsAnException_WhenThePathIsInvalid() {
82+
Map<String, Object> input = Map.of("result", "key");
83+
84+
var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(input, "$awesome"));
85+
assertThat(exception.getMessage(), is("Invalid path received [awesome], unable to extract a field name."));
86+
}
87+
88+
public void testExtract_ThrowsException_WhenMissingArraySyntax() {
89+
Map<String, Object> input = Map.of(
90+
"result",
91+
Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4))))
92+
);
93+
94+
var exception = expectThrows(
95+
IllegalArgumentException.class,
96+
// embeddings is missing [*] to indicate that it is an array
97+
() -> MapPathExtractor.extract(input, "$.result.embeddings.embedding")
98+
);
99+
assertThat(
100+
exception.getMessage(),
101+
is(
102+
"Current path [.embedding] matched the dot field pattern but the current object "
103+
+ "is not a map, found invalid type [List12] instead."
104+
)
105+
);
106+
}
107+
108+
public void testExtract_ThrowsException_WhenHasArraySyntaxButIsAMap() {
109+
Map<String, Object> input = Map.of(
110+
"result",
111+
Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4))))
112+
);
113+
114+
var exception = expectThrows(
115+
IllegalArgumentException.class,
116+
// result is not an array
117+
() -> MapPathExtractor.extract(input, "$.result[*].embeddings[*].embedding")
118+
);
119+
assertThat(
120+
exception.getMessage(),
121+
is(
122+
"Current path [[*].embeddings[*].embedding] matched the array field pattern but the current "
123+
+ "object is not a list, found invalid type [Map1] instead."
124+
)
125+
);
126+
}
127+
128+
public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty() {
129+
Map<String, Object> input = Map.of("result", List.of());
130+
131+
assertThat(MapPathExtractor.extract(input, "$.result"), is(List.of()));
132+
}
133+
134+
public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty_PathIncludesArray() {
135+
Map<String, Object> input = Map.of("result", List.of());
136+
137+
assertThat(MapPathExtractor.extract(input, "$.result[*]"), is(List.of()));
138+
}
139+
140+
public void testDotFieldPattern() {
141+
{
142+
var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc.123");
143+
assertTrue(matcher.matches());
144+
assertThat(matcher.group(1), is("abc"));
145+
assertThat(matcher.group(2), is(".123"));
146+
}
147+
{
148+
var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[*].123");
149+
assertTrue(matcher.matches());
150+
assertThat(matcher.group(1), is("abc"));
151+
assertThat(matcher.group(2), is("[*].123"));
152+
}
153+
{
154+
var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[.123");
155+
assertTrue(matcher.matches());
156+
assertThat(matcher.group(1), is("abc"));
157+
assertThat(matcher.group(2), is("[.123"));
158+
}
159+
{
160+
var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc");
161+
assertTrue(matcher.matches());
162+
assertThat(matcher.group(1), is("abc"));
163+
assertThat(matcher.group(2), is(""));
164+
}
165+
}
166+
167+
public void testArrayWildcardPattern() {
168+
{
169+
var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*].abc.123");
170+
assertTrue(matcher.matches());
171+
assertThat(matcher.group(1), is(".abc.123"));
172+
}
173+
{
174+
var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*]");
175+
assertTrue(matcher.matches());
176+
assertThat(matcher.group(1), is(""));
177+
}
178+
{
179+
var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[1].abc");
180+
assertFalse(matcher.matches());
181+
}
182+
{
183+
var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[].abc");
184+
assertFalse(matcher.matches());
185+
}
186+
}
187+
}

0 commit comments

Comments
 (0)