| 
 | 1 | +/*  | 
 | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one  | 
 | 3 | + * or more contributor license agreements. Licensed under the Elastic License  | 
 | 4 | + * 2.0; you may not use this file except in compliance with the Elastic License  | 
 | 5 | + * 2.0.  | 
 | 6 | + */  | 
 | 7 | + | 
 | 8 | +package org.elasticsearch.xpack.core.inference.action;  | 
 | 9 | + | 
 | 10 | +import org.elasticsearch.action.ActionRequest;  | 
 | 11 | +import org.elasticsearch.action.ActionRequestValidationException;  | 
 | 12 | +import org.elasticsearch.action.ActionType;  | 
 | 13 | +import org.elasticsearch.common.io.stream.StreamInput;  | 
 | 14 | +import org.elasticsearch.common.io.stream.StreamOutput;  | 
 | 15 | +import org.elasticsearch.inference.TaskType;  | 
 | 16 | +import org.elasticsearch.xcontent.XContentParser;  | 
 | 17 | + | 
 | 18 | +import java.io.IOException;  | 
 | 19 | +import java.util.Objects;  | 
 | 20 | + | 
 | 21 | +public class UnifiedCompletionAction extends ActionType<InferenceAction.Response> {  | 
 | 22 | +    public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction();  | 
 | 23 | +    public static final String NAME = "cluster:monitor/xpack/inference/unified";  | 
 | 24 | + | 
 | 25 | +    public UnifiedCompletionAction() {  | 
 | 26 | +        super(NAME);  | 
 | 27 | +    }  | 
 | 28 | + | 
 | 29 | +    public static class Request extends ActionRequest {  | 
 | 30 | +        public static Request parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {  | 
 | 31 | +            var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null);  | 
 | 32 | +            return new Request(inferenceEntityId, taskType, unifiedRequest);  | 
 | 33 | +        }  | 
 | 34 | + | 
 | 35 | +        private final String inferenceEntityId;  | 
 | 36 | +        private final TaskType taskType;  | 
 | 37 | +        private final UnifiedCompletionRequest unifiedCompletionRequest;  | 
 | 38 | + | 
 | 39 | +        public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest) {  | 
 | 40 | +            this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);  | 
 | 41 | +            this.taskType = Objects.requireNonNull(taskType);  | 
 | 42 | +            this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest);  | 
 | 43 | +        }  | 
 | 44 | + | 
 | 45 | +        public Request(StreamInput in) throws IOException {  | 
 | 46 | +            super(in);  | 
 | 47 | +            this.inferenceEntityId = in.readString();  | 
 | 48 | +            this.taskType = TaskType.fromStream(in);  | 
 | 49 | +            this.unifiedCompletionRequest = new UnifiedCompletionRequest(in);  | 
 | 50 | +        }  | 
 | 51 | + | 
 | 52 | +        public TaskType getTaskType() {  | 
 | 53 | +            return taskType;  | 
 | 54 | +        }  | 
 | 55 | + | 
 | 56 | +        public String getInferenceEntityId() {  | 
 | 57 | +            return inferenceEntityId;  | 
 | 58 | +        }  | 
 | 59 | + | 
 | 60 | +        public UnifiedCompletionRequest getUnifiedCompletionRequest() {  | 
 | 61 | +            return unifiedCompletionRequest;  | 
 | 62 | +        }  | 
 | 63 | + | 
 | 64 | +        public boolean isStreaming() {  | 
 | 65 | +            return Objects.requireNonNullElse(unifiedCompletionRequest.stream(), false);  | 
 | 66 | +        }  | 
 | 67 | + | 
 | 68 | +        @Override  | 
 | 69 | +        public ActionRequestValidationException validate() {  | 
 | 70 | +            if (unifiedCompletionRequest == null || unifiedCompletionRequest.messages() == null) {  | 
 | 71 | +                var e = new ActionRequestValidationException();  | 
 | 72 | +                e.addValidationError("Field [messages] cannot be null");  | 
 | 73 | +                return e;  | 
 | 74 | +            }  | 
 | 75 | + | 
 | 76 | +            if (unifiedCompletionRequest.messages().isEmpty()) {  | 
 | 77 | +                var e = new ActionRequestValidationException();  | 
 | 78 | +                e.addValidationError("Field [messages] cannot be an empty array");  | 
 | 79 | +                return e;  | 
 | 80 | +            }  | 
 | 81 | + | 
 | 82 | +            if (taskType != TaskType.COMPLETION) {  | 
 | 83 | +                var e = new ActionRequestValidationException();  | 
 | 84 | +                e.addValidationError("Field [taskType] must be [completion]");  | 
 | 85 | +                return e;  | 
 | 86 | +            }  | 
 | 87 | + | 
 | 88 | +            return null;  | 
 | 89 | +        }  | 
 | 90 | + | 
 | 91 | +        @Override  | 
 | 92 | +        public void writeTo(StreamOutput out) throws IOException {  | 
 | 93 | +            super.writeTo(out);  | 
 | 94 | +            out.writeString(inferenceEntityId);  | 
 | 95 | +            taskType.writeTo(out);  | 
 | 96 | +            unifiedCompletionRequest.writeTo(out);  | 
 | 97 | +        }  | 
 | 98 | + | 
 | 99 | +        @Override  | 
 | 100 | +        public boolean equals(Object o) {  | 
 | 101 | +            if (o == null || getClass() != o.getClass()) return false;  | 
 | 102 | +            Request request = (Request) o;  | 
 | 103 | +            return Objects.equals(inferenceEntityId, request.inferenceEntityId)  | 
 | 104 | +                && taskType == request.taskType  | 
 | 105 | +                && Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest);  | 
 | 106 | +        }  | 
 | 107 | + | 
 | 108 | +        @Override  | 
 | 109 | +        public int hashCode() {  | 
 | 110 | +            return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest);  | 
 | 111 | +        }  | 
 | 112 | +    }  | 
 | 113 | + | 
 | 114 | +}  | 
0 commit comments