1919import java .io .IOException ;
2020import java .util .List ;
2121import java .util .Locale ;
22+ import java .util .Objects ;
2223import java .util .Optional ;
2324
2425import org .opensearch .common .util .concurrent .ThreadContext ;
@@ -83,27 +84,30 @@ public List<Route> routes() {
8384
8485 @ Override
8586 public RestChannelConsumer prepareRequest (RestRequest request , NodeClient client ) throws IOException {
86- String algorithm = request .param (PARAMETER_ALGORITHM );
87+ String userAlgorithm = request .param (PARAMETER_ALGORITHM );
8788 String modelId = getParameterId (request , PARAMETER_MODEL_ID );
8889 Optional <FunctionName > functionName = modelManager .getOptionalModelFunctionName (modelId );
8990
90- if (algorithm == null && functionName .isPresent ()) {
91- algorithm = functionName .get ().name ();
92- }
93-
94- if (algorithm != null ) {
95- MLPredictionTaskRequest mlPredictionTaskRequest = getRequest (modelId , algorithm , request );
96- return channel -> client
97- .execute (MLPredictionTaskAction .INSTANCE , mlPredictionTaskRequest , new RestToXContentListener <>(channel ));
91+ // check if the model is in cache
92+ if (functionName .isPresent ()) {
93+ MLPredictionTaskRequest predictionRequest = getRequest (
94+ modelId ,
95+ functionName .get ().name (),
96+ Objects .requireNonNullElse (userAlgorithm , functionName .get ().name ()),
97+ request
98+ );
99+ return channel -> client .execute (MLPredictionTaskAction .INSTANCE , predictionRequest , new RestToXContentListener <>(channel ));
98100 }
99101
102+ // If the model isn't in cache
100103 return channel -> {
101104 ActionListener <MLModel > listener = ActionListener .wrap (mlModel -> {
102- String algoName = mlModel .getAlgorithm ().name ();
105+ String modelType = mlModel .getAlgorithm ().name ();
106+ String modelAlgorithm = Objects .requireNonNullElse (userAlgorithm , mlModel .getAlgorithm ().name ());
103107 client
104108 .execute (
105109 MLPredictionTaskAction .INSTANCE ,
106- getRequest (modelId , algoName , request ),
110+ getRequest (modelId , modelType , modelAlgorithm , request ),
107111 new RestToXContentListener <>(channel )
108112 );
109113 }, e -> {
@@ -126,18 +130,22 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
126130 }
127131
128132 /**
129- * Creates a MLPredictionTaskRequest from a RestRequest
133+ * Creates a MLPredictionTaskRequest from a RestRequest. This method validates the request based on
134+ * enabled features and model types, and parses the input data for prediction.
130135 *
131- * @param request RestRequest
132- * @return MLPredictionTaskRequest
136+ * @param modelId The ID of the ML model to use for prediction
137+ * @param modelType The type of the ML model, extracted from model cache to specify if it's a remote model or a local model
138+ * @param userAlgorithm The algorithm specified by the user for prediction, this is used to determine the interface of the model
139+ * @param request The REST request containing prediction input data
140+ * @return MLPredictionTaskRequest configured with the model and input parameters
133141 */
134142 @ VisibleForTesting
135- MLPredictionTaskRequest getRequest (String modelId , String algorithm , RestRequest request ) throws IOException {
143+ MLPredictionTaskRequest getRequest (String modelId , String modelType , String userAlgorithm , RestRequest request ) throws IOException {
136144 String tenantId = getTenantID (mlFeatureEnabledSetting .isMultiTenancyEnabled (), request );
137145 ActionType actionType = ActionType .from (getActionTypeFromRestRequest (request ));
138- if (FunctionName .REMOTE .name ().equals (algorithm ) && !mlFeatureEnabledSetting .isRemoteInferenceEnabled ()) {
146+ if (FunctionName .REMOTE .name ().equals (modelType ) && !mlFeatureEnabledSetting .isRemoteInferenceEnabled ()) {
139147 throw new IllegalStateException (REMOTE_INFERENCE_DISABLED_ERR_MSG );
140- } else if (FunctionName .isDLModel (FunctionName .from (algorithm .toUpperCase (Locale .ROOT )))
148+ } else if (FunctionName .isDLModel (FunctionName .from (modelType .toUpperCase (Locale .ROOT )))
141149 && !mlFeatureEnabledSetting .isLocalModelEnabled ()) {
142150 throw new IllegalStateException (LOCAL_MODEL_DISABLED_ERR_MSG );
143151 } else if (ActionType .BATCH_PREDICT == actionType && !mlFeatureEnabledSetting .isOfflineBatchInferenceEnabled ()) {
@@ -148,7 +156,7 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
148156
149157 XContentParser parser = request .contentParser ();
150158 ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .nextToken (), parser );
151- MLInput mlInput = MLInput .parse (parser , algorithm , actionType );
159+ MLInput mlInput = MLInput .parse (parser , userAlgorithm , actionType );
152160 return new MLPredictionTaskRequest (modelId , mlInput , null , tenantId );
153161 }
154162
0 commit comments