Skip to content

Commit cf7b221

Browse files
committed
feat(model/rerank): support text-rereank model
1 parent 377f231 commit cf7b221

File tree

7 files changed

+302
-0
lines changed

7 files changed

+302
-0
lines changed

samples/TextReRankTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (c) Alibaba, Inc. and its affiliates.
2+
3+
import com.alibaba.dashscope.exception.InputRequiredException;
4+
import com.alibaba.dashscope.rerank.TextReRank;
5+
import com.alibaba.dashscope.rerank.TextReRankParam;
6+
import com.alibaba.dashscope.rerank.TextReRankResult;
7+
import com.alibaba.dashscope.exception.ApiException;
8+
import com.alibaba.dashscope.exception.NoApiKeyException;
9+
import com.alibaba.dashscope.utils.JsonUtils;
10+
11+
import java.util.Arrays;
12+
13+
public class TextReRankTest {
14+
15+
public static void main(String[] args) {
16+
try {
17+
// Create TextReRank instance
18+
TextReRank textReRank = new TextReRank();
19+
20+
// Create parameters
21+
TextReRankParam param = TextReRankParam.builder()
22+
.model(TextReRank.Models.GTE_RERANK_V2)
23+
.query("什么是文本排序模型")
24+
.documents(Arrays.asList(
25+
"文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序",
26+
"量子计算是计算科学的一个前沿领域",
27+
"预训练语言模型的发展给文本排序模型带来了新的进展"
28+
))
29+
.topN(10)
30+
.returnDocuments(true)
31+
.build();
32+
33+
// Call the API
34+
TextReRankResult result = textReRank.call(param);
35+
36+
System.out.println("Rerank Result:");
37+
System.out.println(JsonUtils.toJson(result));
38+
} catch (NoApiKeyException e) {
39+
System.err.println("API key not found: " + e.getMessage());
40+
} catch (ApiException e) {
41+
System.err.println("API call failed: " + e.getMessage());
42+
} catch (InputRequiredException e) {
43+
throw new RuntimeException(e);
44+
}
45+
}
46+
}

src/main/java/com/alibaba/dashscope/common/TaskGroup.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ public enum TaskGroup {
66
EMBEDDINGS("embeddings"),
77
AUDIO("audio"),
88
NLP("nlp"),
9+
RERANK("rerank"),
910
;
1011

1112
private final String value;
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) Alibaba, Inc. and its affiliates.
2+
package com.alibaba.dashscope.rerank;
3+
4+
import com.alibaba.dashscope.api.SynchronizeHalfDuplexApi;
5+
import com.alibaba.dashscope.common.*;
6+
import com.alibaba.dashscope.exception.ApiException;
7+
import com.alibaba.dashscope.exception.InputRequiredException;
8+
import com.alibaba.dashscope.exception.NoApiKeyException;
9+
import com.alibaba.dashscope.protocol.*;
10+
import lombok.extern.slf4j.Slf4j;
11+
12+
13+
@Slf4j
14+
public final class TextReRank {
15+
16+
private final SynchronizeHalfDuplexApi<TextReRankParam> syncApi;
17+
private final ApiServiceOption serviceOption;
18+
19+
public static class Models {
20+
public static final String GTE_RERANK_V2 = "gte-rerank-v2";
21+
}
22+
23+
private ApiServiceOption defaultApiServiceOption() {
24+
return ApiServiceOption.builder()
25+
.protocol(Protocol.HTTP)
26+
.httpMethod(HttpMethod.POST)
27+
.streamingMode(StreamingMode.NONE)
28+
.outputMode(OutputMode.ACCUMULATE)
29+
.taskGroup(TaskGroup.RERANK.getValue())
30+
.task("text-rerank")
31+
.function("text-rerank")
32+
.build();
33+
}
34+
35+
public TextReRank() {
36+
serviceOption = defaultApiServiceOption();
37+
syncApi = new SynchronizeHalfDuplexApi<>(serviceOption);
38+
}
39+
40+
public TextReRank(String protocol) {
41+
serviceOption = defaultApiServiceOption();
42+
serviceOption.setProtocol(Protocol.of(protocol));
43+
syncApi = new SynchronizeHalfDuplexApi<>(serviceOption);
44+
}
45+
46+
public TextReRank(String protocol, String baseUrl) {
47+
serviceOption = defaultApiServiceOption();
48+
serviceOption.setProtocol(Protocol.of(protocol));
49+
if (Protocol.HTTP.getValue().equals(protocol)) {
50+
serviceOption.setBaseHttpUrl(baseUrl);
51+
} else {
52+
serviceOption.setBaseWebSocketUrl(baseUrl);
53+
}
54+
syncApi = new SynchronizeHalfDuplexApi<>(serviceOption);
55+
}
56+
57+
public TextReRank(
58+
String protocol, String baseUrl, ConnectionOptions connectionOptions) {
59+
serviceOption = defaultApiServiceOption();
60+
serviceOption.setProtocol(Protocol.of(protocol));
61+
if (Protocol.HTTP.getValue().equals(protocol)) {
62+
serviceOption.setBaseHttpUrl(baseUrl);
63+
} else {
64+
serviceOption.setBaseWebSocketUrl(baseUrl);
65+
}
66+
syncApi = new SynchronizeHalfDuplexApi<>(connectionOptions, serviceOption);
67+
}
68+
69+
/**
70+
* Call the server to get the whole result.
71+
*
72+
* @param param The input param of class `TextReRankParam`.
73+
* @return The output structure of `TextReRankResult`.
74+
* @throws NoApiKeyException Can not find api key
75+
* @throws ApiException The request failed, possibly due to a network or data error.
76+
*/
77+
public TextReRankResult call(TextReRankParam param)
78+
throws ApiException, NoApiKeyException, InputRequiredException {
79+
param.validate();
80+
serviceOption.setIsSSE(false);
81+
serviceOption.setStreamingMode(StreamingMode.NONE);
82+
return TextReRankResult.fromDashScopeResult(syncApi.call(param));
83+
}
84+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) Alibaba, Inc. and its affiliates.
2+
package com.alibaba.dashscope.rerank;
3+
4+
import com.google.gson.annotations.SerializedName;
5+
import java.util.List;
6+
import lombok.Data;
7+
8+
@Data
9+
public class TextReRankOutput {
10+
11+
@Data
12+
public static class Result {
13+
private Integer index;
14+
15+
@SerializedName("relevance_score")
16+
private Double relevanceScore;
17+
18+
private Document document;
19+
}
20+
21+
@Data
22+
public static class Document {
23+
private String text;
24+
}
25+
26+
private List<Result> results;
27+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright (c) Alibaba, Inc. and its affiliates.
2+
package com.alibaba.dashscope.rerank;
3+
4+
import com.alibaba.dashscope.base.HalfDuplexServiceParam;
5+
import com.alibaba.dashscope.exception.InputRequiredException;
6+
import com.alibaba.dashscope.utils.ApiKeywords;
7+
import com.alibaba.dashscope.utils.JsonUtils;
8+
import com.google.gson.JsonObject;
9+
import java.nio.ByteBuffer;
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
import lombok.Data;
14+
import lombok.EqualsAndHashCode;
15+
import lombok.Singular;
16+
import lombok.experimental.SuperBuilder;
17+
18+
@EqualsAndHashCode(callSuper = true)
19+
@Data
20+
@SuperBuilder
21+
public class TextReRankParam extends HalfDuplexServiceParam {
22+
23+
/** The query text for reranking. Maximum length is 4,000 tokens. */
24+
private String query;
25+
26+
/** The list of candidate documents to be reranked. Maximum 500 documents. */
27+
@Singular private List<String> documents;
28+
29+
/**
30+
* The number of top documents to return.
31+
* If not specified, returns all candidate documents.
32+
* If top_n is greater than the number of input documents, returns all documents.
33+
*/
34+
private Integer topN;
35+
36+
/**
37+
* Whether to return the original document text in the results.
38+
* Default is false.
39+
*/
40+
private Boolean returnDocuments;
41+
42+
@Override
43+
public JsonObject getHttpBody() {
44+
JsonObject requestObject = new JsonObject();
45+
requestObject.addProperty(ApiKeywords.MODEL, getModel());
46+
requestObject.add(ApiKeywords.INPUT, getInput());
47+
Map<String, Object> params = getParameters();
48+
if (params != null && !params.isEmpty()) {
49+
requestObject.add(ApiKeywords.PARAMETERS, JsonUtils.parametersToJsonObject(params));
50+
}
51+
return requestObject;
52+
}
53+
54+
@Override
55+
public JsonObject getInput() {
56+
JsonObject jsonObject = new JsonObject();
57+
jsonObject.addProperty("query", query);
58+
jsonObject.add("documents", JsonUtils.toJsonArray(documents));
59+
return jsonObject;
60+
}
61+
62+
@Override
63+
public Map<String, Object> getParameters() {
64+
Map<String, Object> params = new HashMap<>();
65+
66+
if (topN != null) {
67+
params.put("top_n", topN);
68+
}
69+
70+
if (returnDocuments != null) {
71+
params.put("return_documents", returnDocuments);
72+
}
73+
74+
params.putAll(parameters);
75+
return params;
76+
}
77+
78+
@Override
79+
public ByteBuffer getBinaryData() {
80+
return null;
81+
}
82+
83+
@Override
84+
public void validate() throws InputRequiredException {
85+
if (query == null || query.trim().isEmpty()) {
86+
throw new InputRequiredException("Query must not be null or empty!");
87+
}
88+
89+
if (documents == null || documents.isEmpty()) {
90+
throw new InputRequiredException("Documents must not be null or empty!");
91+
}
92+
}
93+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright (c) Alibaba, Inc. and its affiliates.
2+
package com.alibaba.dashscope.rerank;
3+
4+
import com.alibaba.dashscope.common.DashScopeResult;
5+
import com.alibaba.dashscope.utils.JsonUtils;
6+
import com.google.gson.JsonObject;
7+
import com.google.gson.annotations.SerializedName;
8+
import lombok.Data;
9+
import lombok.extern.slf4j.Slf4j;
10+
11+
@Slf4j
12+
@Data
13+
public class TextReRankResult {
14+
@SerializedName("request_id")
15+
private String requestId;
16+
17+
private TextReRankUsage usage;
18+
19+
private TextReRankOutput output;
20+
21+
private TextReRankResult() {}
22+
23+
public static TextReRankResult fromDashScopeResult(DashScopeResult dashScopeResult) {
24+
TextReRankResult result = new TextReRankResult();
25+
result.setRequestId(dashScopeResult.getRequestId());
26+
if (dashScopeResult.getUsage() != null) {
27+
result.setUsage(
28+
JsonUtils.fromJsonObject(
29+
dashScopeResult.getUsage().getAsJsonObject(), TextReRankUsage.class));
30+
}
31+
if (dashScopeResult.getOutput() != null) {
32+
result.setOutput(
33+
JsonUtils.fromJsonObject(
34+
(JsonObject) dashScopeResult.getOutput(), TextReRankOutput.class));
35+
} else {
36+
log.error("Result no output: {}", dashScopeResult);
37+
}
38+
return result;
39+
}
40+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright (c) Alibaba, Inc. and its affiliates.
2+
package com.alibaba.dashscope.rerank;
3+
4+
import com.google.gson.annotations.SerializedName;
5+
import lombok.Data;
6+
7+
@Data
8+
public class TextReRankUsage {
9+
@SerializedName("total_tokens")
10+
private Integer totalTokens;
11+
}

0 commit comments

Comments
 (0)