Skip to content

Commit 803a966

Browse files
authored
[fel] add apiKey logical of rerank model. (#226)
* [fel] add apiKey logical of rerank model. * [fel] Code tidying up. * [fel] Review comments and modifications. * [fel] Review comments and modifications.
1 parent 42ef89e commit 803a966

File tree

13 files changed

+270
-168
lines changed

13 files changed

+270
-168
lines changed

framework/fel/java/fel-community/model-openai/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@
4242
<artifactId>fel-core</artifactId>
4343
</dependency>
4444

45+
<!-- Lombok -->
46+
<dependency>
47+
<groupId>org.projectlombok</groupId>
48+
<artifactId>lombok</artifactId>
49+
</dependency>
50+
4551
<!-- Test Plugins -->
4652
<dependency>
4753
<groupId>org.fitframework.plugin</groupId>

framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/OpenAiModel.java

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,32 @@
1919
import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbeddingResponse;
2020
import modelengine.fel.community.model.openai.entity.image.OpenAiImageRequest;
2121
import modelengine.fel.community.model.openai.entity.image.OpenAiImageResponse;
22+
import modelengine.fel.community.model.openai.entity.rerank.OpenAiRerankRequest;
23+
import modelengine.fel.community.model.openai.entity.rerank.OpenAiRerankResponse;
2224
import modelengine.fel.community.model.openai.enums.ModelProcessingState;
2325
import modelengine.fel.community.model.openai.util.HttpUtils;
2426
import modelengine.fel.core.chat.ChatMessage;
2527
import modelengine.fel.core.chat.ChatModel;
2628
import modelengine.fel.core.chat.ChatOption;
2729
import modelengine.fel.core.chat.Prompt;
2830
import modelengine.fel.core.chat.support.AiMessage;
31+
import modelengine.fel.core.document.MeasurableDocument;
2932
import modelengine.fel.core.embed.EmbedModel;
3033
import modelengine.fel.core.embed.EmbedOption;
3134
import modelengine.fel.core.embed.Embedding;
3235
import modelengine.fel.core.image.ImageModel;
3336
import modelengine.fel.core.image.ImageOption;
3437
import modelengine.fel.core.model.http.SecureConfig;
38+
import modelengine.fel.core.rerank.RerankModel;
39+
import modelengine.fel.core.rerank.RerankOption;
3540
import modelengine.fit.http.client.HttpClassicClient;
3641
import modelengine.fit.http.client.HttpClassicClientFactory;
3742
import modelengine.fit.http.client.HttpClassicClientRequest;
3843
import modelengine.fit.http.client.HttpClassicClientResponse;
44+
import modelengine.fit.http.entity.Entity;
3945
import modelengine.fit.http.entity.ObjectEntity;
4046
import modelengine.fit.http.protocol.HttpRequestMethod;
47+
import modelengine.fit.http.protocol.HttpResponseStatus;
4148
import modelengine.fit.security.Decryptor;
4249
import modelengine.fitframework.annotation.Component;
4350
import modelengine.fitframework.annotation.Fit;
@@ -69,7 +76,7 @@
6976
* @since 2024-08-07
7077
*/
7178
@Component
72-
public class OpenAiModel implements EmbedModel, ChatModel, ImageModel {
79+
public class OpenAiModel implements EmbedModel, ChatModel, ImageModel, RerankModel {
7380
private static final Logger log = Logger.get(OpenAiModel.class);
7481
private static final Map<String, Boolean> HTTPS_CONFIG_KEY_MAPS = MapBuilder.<String, Boolean>get()
7582
.put("client.http.secure.ignore-trust", Boolean.FALSE)
@@ -168,6 +175,42 @@ public List<Media> generate(String prompt, ImageOption option) {
168175
}
169176
}
170177

178+
@Override
179+
public List<MeasurableDocument> generate(List<MeasurableDocument> documents, RerankOption rerankOption) {
180+
notEmpty(documents, "The documents cannot be empty.");
181+
notNull(rerankOption, "The rerank option cannot be null.");
182+
String modelSource = StringUtils.blankIf(rerankOption.baseUri(), this.baseUrl);
183+
HttpClassicClientRequest request = this.getHttpClient(rerankOption.secureConfig())
184+
.createRequest(HttpRequestMethod.POST, UrlUtils.combine(modelSource, OpenAiApi.RERANK_ENDPOINT));
185+
HttpUtils.setBearerAuth(request, StringUtils.blankIf(rerankOption.apiKey(), this.defaultApiKey));
186+
List<String> docs = documents.stream().map(MeasurableDocument::text).collect(Collectors.toList());
187+
OpenAiRerankRequest fields = new OpenAiRerankRequest(rerankOption, docs);
188+
request.entity(Entity.createObject(request, fields));
189+
OpenAiRerankResponse rerankResponse = this.rerankExchange(request);
190+
191+
return rerankResponse.results()
192+
.stream()
193+
.map(result -> new MeasurableDocument(documents.get(result.index()), result.relevanceScore()))
194+
.sorted((document1, document2) -> (int) (document2.score() - document1.score()))
195+
.collect(Collectors.toList());
196+
}
197+
198+
private OpenAiRerankResponse rerankExchange(HttpClassicClientRequest request) {
199+
try (HttpClassicClientResponse<Object> response = request.exchange(OpenAiRerankResponse.class)) {
200+
if (response.statusCode() != HttpResponseStatus.OK.statusCode()) {
201+
log.error("Failed to get rerank model response. [code={}, reason={}]",
202+
response.statusCode(),
203+
response.reasonPhrase());
204+
throw new FitException("Failed to get rerank model response.");
205+
}
206+
return ObjectUtils.cast(response.objectEntity()
207+
.map(ObjectEntity::object)
208+
.orElseThrow(() -> new FitException("The response body is abnormal.")));
209+
} catch (IOException e) {
210+
throw new IllegalStateException("Failed to request rerank model.", e);
211+
}
212+
}
213+
171214
private Choir<ChatMessage> createChatStream(HttpClassicClientRequest request) {
172215
AtomicReference<ModelProcessingState> modelProcessingState =
173216
new AtomicReference<>(ModelProcessingState.INITIAL);

framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/api/OpenAiApi.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ public interface OpenAiApi {
2727
*/
2828
String IMAGE_ENDPOINT = "/images/generations";
2929

30+
/**
31+
* 重排请求的端点。
32+
*/
33+
String RERANK_ENDPOINT = "/rerank";
34+
3035
/**
3136
* 请求头模型密钥字段。
3237
*/

framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankRequest.java renamed to framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/rerank/OpenAiRerankRequest.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,39 @@
44
* Licensed under the MIT License. See License.txt in the project root for license information.
55
*--------------------------------------------------------------------------------------------*/
66

7-
package modelengine.fel.core.document.support;
7+
package modelengine.fel.community.model.openai.entity.rerank;
88

9+
import lombok.Data;
10+
import lombok.NoArgsConstructor;
11+
import modelengine.fel.core.rerank.RerankOption;
912
import modelengine.fitframework.annotation.Property;
1013
import modelengine.fitframework.inspection.Validation;
1114
import modelengine.fitframework.serialization.annotation.SerializeStrategy;
1215

1316
import java.util.List;
1417

1518
/**
16-
* 表示 Rerank API 格式的请求
19+
* 表示 OpenAI API 格式的重排请求
1720
*
1821
* @since 2024-09-27
1922
*/
23+
@Data
2024
@SerializeStrategy(include = SerializeStrategy.Include.NON_NULL)
21-
public class RerankRequest {
22-
private final String model;
23-
private final String query;
24-
private final List<String> documents;
25+
@NoArgsConstructor
26+
public class OpenAiRerankRequest {
27+
private String model;
28+
private String query;
29+
private List<String> documents;
2530
@Property(name = "top_n")
26-
private final Integer topN;
31+
private Integer topN;
2732

2833
/**
29-
* 创建 {@link RerankRequest} 的实体。
34+
* 创建 {@link OpenAiRerankRequest} 的实体。
3035
*
36+
* @param rerankOption 表示重排模型参数。
3137
* @param documents 表示要重新排序的文档对象。
32-
* @param rerankOption 表示 rerank 模型参数。
3338
*/
34-
public RerankRequest(RerankOption rerankOption, List<String> documents) {
39+
public OpenAiRerankRequest(RerankOption rerankOption, List<String> documents) {
3540
Validation.notNull(rerankOption, "The rerankOption cannot be null.");
3641
this.model = rerankOption.model();
3742
this.query = rerankOption.query();

framework/fel/java/fel-core/src/main/java/modelengine/fel/core/document/support/RerankResponse.java renamed to framework/fel/java/fel-community/model-openai/src/main/java/modelengine/fel/community/model/openai/entity/rerank/OpenAiRerankResponse.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,47 @@
44
* Licensed under the MIT License. See License.txt in the project root for license information.
55
*--------------------------------------------------------------------------------------------*/
66

7-
package modelengine.fel.core.document.support;
7+
package modelengine.fel.community.model.openai.entity.rerank;
88

9+
import lombok.AllArgsConstructor;
10+
import lombok.Data;
11+
import lombok.NoArgsConstructor;
912
import modelengine.fitframework.annotation.Property;
1013
import modelengine.fitframework.util.CollectionUtils;
1114

1215
import java.util.Collections;
1316
import java.util.List;
1417

1518
/**
16-
* 表示 Rerank API 格式的请求
19+
* 表示 OpenAI API 格式的重排响应
1720
*
1821
* @since 2024-09-27
1922
*/
20-
public class RerankResponse {
21-
private List<RerankOrder> results;
23+
@Data
24+
@NoArgsConstructor
25+
@AllArgsConstructor
26+
public class OpenAiRerankResponse {
27+
private List<OpenAiRerankResponse.RerankOrder> results;
2228

2329
/**
2430
* 获取重新排序后的文档列表。
2531
*
26-
* @return 表示重新排序后的文档列表的 {@link List}{@code <}{@link RerankOrder}{@code >}。
32+
* @return 表示重新排序后的文档列表的 {@link List}{@code <}{@link OpenAiRerankResponse.RerankOrder}{@code >}。
2733
*/
28-
public List<RerankOrder> results() {
34+
public List<OpenAiRerankResponse.RerankOrder> results() {
2935
return CollectionUtils.isEmpty(this.results)
3036
? Collections.emptyList()
3137
: Collections.unmodifiableList(this.results);
3238
}
3339

34-
static class RerankOrder {
40+
/**
41+
* 表示重排序后的文档项,包含文档在原始列表中的索引和重新计算的相关性评分。
42+
* 用于存储和访问重新排序后的文档信息。
43+
*/
44+
@Data
45+
@NoArgsConstructor
46+
@AllArgsConstructor
47+
public static class RerankOrder {
3548
private int index;
3649
@Property(name = "relevance_score")
3750
private double relevanceScore;

framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/OpenAiModelTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77
package modelengine.fel.community.model.openai;
88

99
import static org.assertj.core.api.Assertions.assertThat;
10+
import static org.junit.jupiter.api.Assertions.assertAll;
1011

1112
import modelengine.fel.community.model.openai.config.OpenAiConfig;
1213
import modelengine.fel.core.chat.ChatMessage;
1314
import modelengine.fel.core.chat.ChatOption;
1415
import modelengine.fel.core.chat.support.ChatMessages;
1516
import modelengine.fel.core.chat.support.HumanMessage;
17+
import modelengine.fel.core.document.Document;
18+
import modelengine.fel.core.document.MeasurableDocument;
1619
import modelengine.fel.core.embed.EmbedOption;
1720
import modelengine.fel.core.embed.Embedding;
1821
import modelengine.fel.core.image.ImageOption;
22+
import modelengine.fel.core.rerank.RerankOption;
1923
import modelengine.fit.http.client.HttpClassicClientFactory;
2024
import modelengine.fitframework.annotation.Fit;
2125
import modelengine.fitframework.conf.Config;
@@ -31,6 +35,8 @@
3135
import org.junit.jupiter.api.Test;
3236

3337
import java.util.Arrays;
38+
import java.util.Collections;
39+
import java.util.HashMap;
3440
import java.util.List;
3541
import java.util.stream.Collectors;
3642

@@ -41,6 +47,9 @@
4147
*/
4248
@MvcTest(classes = TestModelController.class)
4349
public class OpenAiModelTest {
50+
private static final int EXPECTED_TOP_K = 3;
51+
private static final String HIGHEST_RANKED_TEXT = "C++ offers high performance.";
52+
private static final double EXPECTED_HIGHEST_SCORE = 0.999071;
4453
private OpenAiModel openAiModel;
4554

4655
@Fit
@@ -91,4 +100,45 @@ void testOpenAiImageModel() {
91100
"456",
92101
"789");
93102
}
103+
104+
@Test
105+
@DisplayName("测试重排模型返回:应返回按相关性排序的前 K 个文档")
106+
void testOpenAiRerankModel() {
107+
// Given: 准备输入文档
108+
List<MeasurableDocument> inputDocs = Arrays.asList(doc("0", "Java is a programming language."),
109+
doc("1", "Python is great for data science."),
110+
doc("2", HIGHEST_RANKED_TEXT),
111+
doc("3", "Rust offers high performance."),
112+
doc("4", "C offers high performance."));
113+
114+
RerankOption rerankOption = RerankOption.custom().model("rerank-model").topN(EXPECTED_TOP_K).build();
115+
116+
// When: 调用重排接口
117+
List<MeasurableDocument> result = this.openAiModel.generate(inputDocs, rerankOption);
118+
119+
// Then: 验证结果
120+
assertAll(() -> assertThat(result).as("应返回 top-%d 结果", EXPECTED_TOP_K).hasSize(EXPECTED_TOP_K),
121+
122+
() -> {
123+
List<Double> scores = result.stream().map(MeasurableDocument::score).collect(Collectors.toList());
124+
assertThat(scores).as("结果应按相关性分数降序排列").isSortedAccordingTo(Collections.reverseOrder());
125+
},
126+
127+
() -> {
128+
List<String> resultTexts =
129+
result.stream().map(MeasurableDocument::text).collect(Collectors.toList());
130+
List<String> inputTexts =
131+
inputDocs.stream().map(MeasurableDocument::text).collect(Collectors.toList());
132+
assertThat(inputTexts).as("所有返回文档必须来自输入集").containsAll(resultTexts);
133+
},
134+
135+
() -> assertThat(result.get(0).text()).as("得分最高的文档应为 C++").isEqualTo(HIGHEST_RANKED_TEXT),
136+
137+
() -> assertThat(result.get(0).score()).as("最高分应与模拟响应一致").isEqualTo(EXPECTED_HIGHEST_SCORE));
138+
}
139+
140+
private MeasurableDocument doc(String id, String text) {
141+
Document document = Document.custom().id(id).text(text).metadata(new HashMap<>()).build();
142+
return new MeasurableDocument(document, 0.0);
143+
}
94144
}

framework/fel/java/fel-community/model-openai/src/test/java/modelengine/fel/community/model/openai/TestModelController.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,22 @@
99
import static modelengine.fel.community.model.openai.api.OpenAiApi.CHAT_ENDPOINT;
1010
import static modelengine.fel.community.model.openai.api.OpenAiApi.EMBEDDING_ENDPOINT;
1111
import static modelengine.fel.community.model.openai.api.OpenAiApi.IMAGE_ENDPOINT;
12+
import static modelengine.fel.community.model.openai.api.OpenAiApi.RERANK_ENDPOINT;
1213

1314
import modelengine.fel.community.model.openai.entity.embed.OpenAiEmbeddingResponse;
1415
import modelengine.fel.community.model.openai.entity.image.OpenAiImageResponse;
16+
import modelengine.fel.community.model.openai.entity.rerank.OpenAiRerankRequest;
17+
import modelengine.fel.community.model.openai.entity.rerank.OpenAiRerankResponse;
1518
import modelengine.fit.http.annotation.PostMapping;
19+
import modelengine.fit.http.annotation.RequestBody;
1620
import modelengine.fitframework.annotation.Component;
1721
import modelengine.fitframework.flowable.Choir;
1822
import modelengine.fitframework.serialization.ObjectSerializer;
1923

24+
import java.util.ArrayList;
25+
import java.util.List;
26+
import java.util.stream.Collectors;
27+
2028
/**
2129
* 表示测试使用的聊天接口。
2230
*
@@ -81,4 +89,27 @@ public OpenAiImageResponse image() {
8189
+ "\"data\":[{\"b64_json\":\"123\"}, {\"b64_json\":\"456\"}, {\"b64_json\":\"789\"}]}";
8290
return this.serializer.deserialize(json, OpenAiImageResponse.class);
8391
}
92+
93+
/**
94+
* 测试用重排接口。
95+
*
96+
* @return 表示重排响应的 {@link OpenAiRerankResponse}。
97+
*/
98+
@PostMapping(RERANK_ENDPOINT)
99+
public OpenAiRerankResponse rerank(@RequestBody OpenAiRerankRequest request) {
100+
int topN = request.getTopN();
101+
List<String> docs = request.getDocuments();
102+
// 模拟生成结果:按 index 顺序生成 relevance_score,最多返回 topN 个
103+
List<OpenAiRerankResponse.RerankOrder> results = new ArrayList<>();
104+
double[] mockScores = {0.32713068, 0.4, 0.999071, 0.7867867, 0.6}; // 对应 index 0~4
105+
List<OpenAiRerankResponse.RerankOrder> allResults = new ArrayList<>();
106+
for (int i = 0; i < mockScores.length && i < docs.size(); i++) {
107+
allResults.add(new OpenAiRerankResponse.RerankOrder(i, mockScores[i]));
108+
}
109+
allResults.sort((a, b) -> Double.compare(b.relevanceScore(), a.relevanceScore()));
110+
List<OpenAiRerankResponse.RerankOrder> limitedResults = allResults.stream()
111+
.limit(topN)
112+
.collect(Collectors.toList());
113+
return new OpenAiRerankResponse(limitedResults);
114+
}
84115
}

0 commit comments

Comments
 (0)