Skip to content

Commit e806fe6

Browse files
Merge branch 'main' into fix-remote-cluster-profile-comment
2 parents 1278643 + e149db0 commit e806fe6

File tree

43 files changed

+1452
-201
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1452
-201
lines changed

docs/changelog/129146.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129146
2+
summary: "[ML] Add IBM watsonx Completion and Chat Completion support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

docs/changelog/130421.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 130421
2+
summary: Support avg on aggregate metric double
3+
area: ES|QL
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ static TransportVersion def(int id) {
328328
public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00);
329329
public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00);
330330
public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00);
331-
331+
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00);
332332
/*
333333
* STOP! READ THIS FIRST! No, really,
334334
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,12 @@ public enum Cap {
12201220
/**
12211221
* FUSE command
12221222
*/
1223-
FUSE(Build.current().isSnapshot());
1223+
FUSE(Build.current().isSnapshot()),
1224+
1225+
/**
1226+
* Support avg with aggregate metric doubles
1227+
*/
1228+
AGGREGATE_METRIC_DOUBLE_AVG(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG);
12241229

12251230
private final boolean enabled;
12261231

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import static java.util.Collections.emptyList;
2929
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
3030
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
31+
import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE;
3132

3233
public class Avg extends AggregateFunction implements SurrogateExpression {
3334
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::new);
@@ -50,7 +51,7 @@ public Avg(
5051
Source source,
5152
@Param(
5253
name = "number",
53-
type = { "double", "integer", "long" },
54+
type = { "aggregate_metric_double", "double", "integer", "long" },
5455
description = "Expression that outputs values to average."
5556
) Expression field
5657
) {
@@ -65,10 +66,10 @@ public Avg(Source source, Expression field, Expression filter) {
6566
protected Expression.TypeResolution resolveType() {
6667
return isType(
6768
field(),
68-
dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
69+
dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG || dt == AGGREGATE_METRIC_DOUBLE,
6970
sourceText(),
7071
DEFAULT,
71-
"numeric except unsigned_long or counter types"
72+
"aggregate_metric_double or numeric except unsigned_long or counter types"
7273
);
7374
}
7475

@@ -105,9 +106,12 @@ public Avg withFilter(Expression filter) {
105106
public Expression surrogate() {
106107
var s = source();
107108
var field = field();
108-
109-
return field().foldable()
110-
? new MvAvg(s, field)
111-
: new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
109+
if (field.foldable()) {
110+
return new MvAvg(s, field);
111+
}
112+
if (field.dataType() == AGGREGATE_METRIC_DOUBLE) {
113+
return new Div(s, new Sum(s, field, filter()).surrogate(), new Count(s, field, filter()).surrogate());
114+
}
115+
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
112116
}
113117
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,7 +2006,7 @@ public void testUnsupportedTypesInStats() {
20062006
| stats avg(x), count_distinct(x), max(x), median(x), median_absolute_deviation(x), min(x), percentile(x, 10), sum(x)
20072007
""", """
20082008
Found 8 problems
2009-
line 2:12: argument of [avg(x)] must be [numeric except unsigned_long or counter types],\
2009+
line 2:12: argument of [avg(x)] must be [aggregate_metric_double or numeric except unsigned_long or counter types],\
20102010
found value [x] type [unsigned_long]
20112011
line 2:20: argument of [count_distinct(x)] must be [any exact type except unsigned_long, _source, or counter types],\
20122012
found value [x] type [unsigned_long]
@@ -2028,7 +2028,7 @@ public void testUnsupportedTypesInStats() {
20282028
| stats avg(x), median(x), median_absolute_deviation(x), percentile(x, 10), sum(x)
20292029
""", """
20302030
Found 5 problems
2031-
line 2:10: argument of [avg(x)] must be [numeric except unsigned_long or counter types],\
2031+
line 2:10: argument of [avg(x)] must be [aggregate_metric_double or numeric except unsigned_long or counter types],\
20322032
found value [x] type [version]
20332033
line 2:18: argument of [median(x)] must be [numeric except unsigned_long or counter types],\
20342034
found value [x] type [version]

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ public void testAggsExpressionsInStatsAggs() {
359359
error("from test | stats max(max(salary)) by first_name")
360360
);
361361
assertEquals(
362-
"1:25: argument of [avg(first_name)] must be [numeric except unsigned_long or counter types],"
362+
"1:25: argument of [avg(first_name)] must be [aggregate_metric_double or numeric except unsigned_long or counter types],"
363363
+ " found value [first_name] type [keyword]",
364364
error("from test | stats count(avg(first_name)) by first_name")
365365
);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgErrorTests.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ protected Expression build(Source source, List<Expression> args) {
3232

3333
@Override
3434
protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
35-
return equalTo(typeErrorMessage(false, validPerPosition, signature, (v, p) -> "numeric except unsigned_long or counter types"));
35+
return equalTo(
36+
typeErrorMessage(
37+
false,
38+
validPerPosition,
39+
signature,
40+
(v, p) -> "aggregate_metric_double or numeric except unsigned_long or counter types"
41+
)
42+
);
3643
}
3744
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
151151
"completion_test_service",
152152
"hugging_face",
153153
"amazon_sagemaker",
154-
"mistral"
154+
"mistral",
155+
"watsonxai"
155156
).toArray()
156157
)
157158
);
@@ -169,7 +170,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
169170
"hugging_face",
170171
"amazon_sagemaker",
171172
"googlevertexai",
172-
"mistral"
173+
"mistral",
174+
"watsonxai"
173175
).toArray()
174176
)
175177
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
9696
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
9797
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
98+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings;
9899
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
99100
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
100101
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -469,6 +470,13 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
469470
namedWriteables.add(
470471
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
471472
);
473+
namedWriteables.add(
474+
new NamedWriteableRegistry.Entry(
475+
ServiceSettings.class,
476+
IbmWatsonxChatCompletionServiceSettings.NAME,
477+
IbmWatsonxChatCompletionServiceSettings::new
478+
)
479+
);
472480
}
473481

474482
private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

0 commit comments

Comments
 (0)