Skip to content

Commit e9fbce7

Browse files
Refactor OpenShift AI Rerank handler to use JinaAIResponseHandler, use overridden model
1 parent dbc1c56 commit e9fbce7

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreator.java

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@
1919
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2020
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
2121
import org.elasticsearch.xpack.inference.services.ServiceComponents;
22-
import org.elasticsearch.xpack.inference.services.cohere.CohereResponseHandler;
23-
import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity;
22+
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIResponseHandler;
23+
import org.elasticsearch.xpack.inference.services.jinaai.response.JinaAIRerankResponseEntity;
2424
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
2525
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
2626
import org.elasticsearch.xpack.inference.services.openshiftai.completion.OpenShiftAiChatCompletionModel;
@@ -57,11 +57,10 @@ public class OpenShiftAiActionCreator implements OpenShiftAiActionVisitor {
5757
"OpenShift AI completion",
5858
OpenAiChatCompletionResponseEntity::fromResponse
5959
);
60-
// OpenShift AI Rerank task uses the same response format as Cohere, therefore we can reuse the CohereResponseHandler
61-
private static final ResponseHandler RERANK_HANDLER = new CohereResponseHandler(
60+
// OpenShift AI Rerank task uses the same response format as JinaAI, therefore we can reuse the JinaAIResponseHandler
61+
private static final ResponseHandler RERANK_HANDLER = new JinaAIResponseHandler(
6262
"OpenShift AI rerank",
63-
(request, response) -> CohereRankedResponseEntity.fromResponse(response),
64-
false
63+
(request, response) -> JinaAIRerankResponseEntity.fromResponse(response)
6564
);
6665

6766
private final Sender sender;
@@ -122,11 +121,11 @@ public ExecutableAction create(OpenShiftAiRerankModel model, Map<String, Object>
122121
inputs.getChunks(),
123122
inputs.getReturnDocuments(),
124123
inputs.getTopN(),
125-
model
124+
overriddenModel
126125
),
127126
QueryAndDocsInputs.class
128127
);
129-
var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId());
128+
var errorMessage = buildErrorMessage(TaskType.RERANK, overriddenModel.getInferenceEntityId());
130129
return new SenderExecutableAction(sender, manager, errorMessage);
131130
}
132131

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openshiftai/action/OpenShiftAiActionCreatorTests.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -735,8 +735,7 @@ public void testCreate_OpenShiftAiRerankModel_FailsFromInvalidResponseFormat() t
735735

736736
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
737737
assertThat(thrownException.getMessage(), is("""
738-
Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Failed to find required\
739-
field [results] in Cohere rerank response"""));
738+
Failed to send OpenShift AI rerank request from inference entity id [inferenceEntityId]. Cause: Required [results]"""));
740739
}
741740
assertRerankActionCreator(documents);
742741
}

0 commit comments

Comments
 (0)