Skip to content

Commit b2d6dd2

Browse files
committed
Add internal get reranker size action
1 parent 63258f4 commit b2d6dd2

File tree

6 files changed

+243
-1
lines changed

6 files changed

+243
-1
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.action.ActionRequest;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ActionResponse;
13+
import org.elasticsearch.action.ActionType;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
17+
import java.io.IOException;
18+
import java.util.Objects;
19+
20+
public class GetRerankerAction extends ActionType<GetRerankerAction.Response> {
21+
22+
public static final GetRerankerAction INSTANCE = new GetRerankerAction();
23+
public static final String NAME = "cluster:internal/xpack/inference/rerank/get";
24+
25+
26+
public GetRerankerAction() {
27+
super(NAME);
28+
}
29+
30+
public static class Request extends ActionRequest {
31+
32+
private final String inferenceEntityId;
33+
34+
public Request(String inferenceEntityId) {
35+
this.inferenceEntityId = inferenceEntityId;
36+
}
37+
38+
public Request(StreamInput in) throws IOException {
39+
super(in);
40+
this.inferenceEntityId = in.readString();
41+
}
42+
43+
public String getInferenceEntityId() {
44+
return inferenceEntityId;
45+
}
46+
47+
@Override
48+
public void writeTo(StreamOutput out) throws IOException {
49+
super.writeTo(out);
50+
out.writeString(inferenceEntityId);
51+
}
52+
53+
@Override
54+
public ActionRequestValidationException validate() {
55+
return null;
56+
}
57+
58+
@Override
59+
public boolean equals(Object o) {
60+
if (o == null || getClass() != o.getClass()) return false;
61+
Request request = (Request) o;
62+
return Objects.equals(inferenceEntityId, request.inferenceEntityId);
63+
}
64+
65+
@Override
66+
public int hashCode() {
67+
return Objects.hashCode(inferenceEntityId);
68+
}
69+
}
70+
71+
public static class Response extends ActionResponse {
72+
73+
private final int windowSize;
74+
75+
public Response(int windowSize) {
76+
this.windowSize = windowSize;
77+
}
78+
79+
public Response(StreamInput in) throws IOException {
80+
this.windowSize = in.readVInt();
81+
}
82+
83+
@Override
84+
public void writeTo(StreamOutput out) throws IOException {
85+
out.writeVInt(windowSize);
86+
}
87+
88+
@Override
89+
public boolean equals(Object o) {
90+
if (o == null || getClass() != o.getClass()) return false;
91+
Response response = (Response) o;
92+
return windowSize == response.windowSize;
93+
}
94+
95+
@Override
96+
public int hashCode() {
97+
return Objects.hashCode(windowSize);
98+
}
99+
}
100+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import org.elasticsearch.xpack.core.inference.action.GetInferenceDiagnosticsAction;
6363
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
6464
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
65+
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
6566
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
6667
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
6768
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
@@ -72,6 +73,7 @@
7273
import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction;
7374
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
7475
import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction;
76+
import org.elasticsearch.xpack.inference.action.TransportGetRerankerAction;
7577
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
7678
import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy;
7779
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
@@ -233,7 +235,8 @@ public List<ActionHandler> getActions() {
233235
new ActionHandler(XPackUsageFeatureAction.INFERENCE, TransportInferenceUsageAction.class),
234236
new ActionHandler(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class),
235237
new ActionHandler(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class),
236-
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class)
238+
new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class),
239+
new ActionHandler(GetRerankerAction.INSTANCE, TransportGetRerankerAction.class)
237240
);
238241
}
239242

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.support.ActionFilters;
13+
import org.elasticsearch.action.support.HandledTransportAction;
14+
import org.elasticsearch.action.support.SubscribableListener;
15+
import org.elasticsearch.common.util.concurrent.EsExecutors;
16+
import org.elasticsearch.inference.InferenceService;
17+
import org.elasticsearch.inference.InferenceServiceRegistry;
18+
import org.elasticsearch.inference.TaskType;
19+
import org.elasticsearch.inference.UnparsedModel;
20+
import org.elasticsearch.injection.guice.Inject;
21+
import org.elasticsearch.rest.RestStatus;
22+
import org.elasticsearch.tasks.Task;
23+
import org.elasticsearch.threadpool.ThreadPool;
24+
import org.elasticsearch.transport.TransportService;
25+
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
26+
import org.elasticsearch.xpack.inference.InferencePlugin;
27+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
28+
29+
import java.util.concurrent.Executor;
30+
31+
public class TransportGetRerankerAction extends HandledTransportAction<GetRerankerAction.Request, GetRerankerAction.Response> {
32+
33+
private final ModelRegistry modelRegistry;
34+
private final InferenceServiceRegistry serviceRegistry;
35+
private final Executor executor;
36+
37+
@Inject
38+
public TransportGetRerankerAction(
39+
TransportService transportService,
40+
ActionFilters actionFilters,
41+
ThreadPool threadPool,
42+
ModelRegistry modelRegistry,
43+
InferenceServiceRegistry serviceRegistry
44+
) {
45+
super(GetRerankerAction.NAME, transportService, actionFilters, GetRerankerAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE);
46+
this.modelRegistry = modelRegistry;
47+
this.serviceRegistry = serviceRegistry;
48+
this.executor = threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
49+
}
50+
51+
@Override
52+
protected void doExecute(Task task, GetRerankerAction.Request request, ActionListener<GetRerankerAction.Response> listener) {
53+
54+
SubscribableListener.<UnparsedModel>newForked(l -> modelRegistry.getModel(request.getInferenceEntityId(), l))
55+
.andThen((l2, model) -> {
56+
if (model.taskType() != TaskType.RERANK) {
57+
l2.onFailure(
58+
new ElasticsearchStatusException(
59+
"Inference endpoint [{}] is not a reranker",
60+
RestStatus.BAD_REQUEST,
61+
request.getInferenceEntityId()
62+
)
63+
);
64+
return;
65+
}
66+
67+
var service = serviceRegistry.getService(model.service());
68+
l2.onResponse(new GetRerankerAction.Response(rerankWindowSize(service.get())));
69+
});
70+
}
71+
72+
public int rerankWindowSize(InferenceService service) {
73+
return 0;
74+
}
75+
76+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.common.io.stream.Writeable;
11+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
12+
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
13+
14+
import java.io.IOException;
15+
16+
public class GetRerankerActionRequestTests extends AbstractWireSerializingTestCase<GetRerankerAction.Request> {
17+
@Override
18+
protected Writeable.Reader<GetRerankerAction.Request> instanceReader() {
19+
return GetRerankerAction.Request::new;
20+
}
21+
22+
@Override
23+
protected GetRerankerAction.Request createTestInstance() {
24+
return new GetRerankerAction.Request(randomAlphaOfLength(8));
25+
}
26+
27+
@Override
28+
protected GetRerankerAction.Request mutateInstance(GetRerankerAction.Request instance) throws IOException {
29+
return randomValueOtherThan(instance, this::createTestInstance);
30+
}
31+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.common.io.stream.Writeable;
11+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
12+
import org.elasticsearch.xpack.core.inference.action.GetRerankerAction;
13+
14+
import java.io.IOException;
15+
16+
public class GetRerankerActionResponseTests extends AbstractWireSerializingTestCase<GetRerankerAction.Response> {
17+
@Override
18+
protected Writeable.Reader<GetRerankerAction.Response> instanceReader() {
19+
return GetRerankerAction.Response::new;
20+
}
21+
22+
@Override
23+
protected GetRerankerAction.Response createTestInstance() {
24+
return new GetRerankerAction.Response(randomNonNegativeInt());
25+
}
26+
27+
@Override
28+
protected GetRerankerAction.Response mutateInstance(GetRerankerAction.Response instance) throws IOException {
29+
return randomValueOtherThan(instance, this::createTestInstance);
30+
}
31+
}

x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ public class Constants {
326326
"cluster:admin/xpack/watcher/watch/put",
327327
"cluster:internal/remote_cluster/nodes",
328328
"cluster:internal/xpack/inference",
329+
"cluster:internal/xpack/inference/rerank/get",
329330
"cluster:internal/xpack/inference/unified",
330331
"cluster:internal/xpack/ml/coordinatedinference",
331332
"cluster:internal/xpack/ml/datafeed/isolate",

0 commit comments

Comments
 (0)