Skip to content

Commit d5f47b4

Browse files
tkykenmtylwu-amzn
andauthored
fix post_process_function on rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md (opensearch-project#3296)
* fix post_process_function bug on sort results for rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md (opensearch-project#3247) Signed-off-by: tkykenmt <[email protected]> * fix typo Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: tkykenmt <[email protected]> Signed-off-by: Yaliang Wu <[email protected]> Co-authored-by: Yaliang Wu <[email protected]>
1 parent bf48f99 commit d5f47b4

File tree

1 file changed

+127
-14
lines changed

1 file changed

+127
-14
lines changed

docs/tutorials/rerank/rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md

Lines changed: 127 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,39 @@ result = predictor.predict(data={
5959
]
6060
})
6161

62-
print(json.dumps(sorted(result, key=lambda x: x['index']), indent=2))
62+
print(json.dumps(result, indent=2))
6363
```
6464

65-
The reranking results are as follows:
65+
The reranking results are ordered by score, with the highest score first:
66+
```
67+
[
68+
{
69+
"index": 2,
70+
"score": 0.92879725
71+
},
72+
{
73+
"index": 0,
74+
"score": 0.013636836
75+
},
76+
{
77+
"index": 1,
78+
"score": 0.000593021
79+
},
80+
{
81+
"index": 3,
82+
"score": 0.00012148176
83+
}
84+
]
85+
```
86+
87+
You can sort the results by index number:
88+
89+
```python
90+
print(json.dumps(sorted(result, key=lambda x: x['index']),indent=2))
91+
92+
```
93+
94+
The results are as follows:
6695

6796
```
6897
[
@@ -121,9 +150,51 @@ POST /_plugins/_ml/connectors/_create
121150
"headers": {
122151
"content-type": "application/json"
123152
},
124-
"request_body": "{ \"query\": \"${parameters.query}\", \"texts\": ${parameters.texts} }",
125-
"pre_process_function": "\n def query_text = params.query_text;\n def text_docs = params.text_docs;\n def textDocsBuilder = new StringBuilder('[');\n for (int i=0; i<text_docs.length; i++) {\n textDocsBuilder.append('\"');\n textDocsBuilder.append(text_docs[i]);\n textDocsBuilder.append('\"');\n if (i<text_docs.length - 1) {\n textDocsBuilder.append(',');\n }\n }\n textDocsBuilder.append(']');\n def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';\n return '{\"parameters\": ' + parameters + '}';\n",
126-
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n def sorted_outputs = outputs;\n for (int i=0; i<outputs.length; i++) {\n def idx = new BigDecimal(outputs[i].index.toString()).intValue();\n sorted_outputs[idx] = outputs[i];\n }\n def resultBuilder = new StringBuilder('[');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
153+
"pre_process_function": """
154+
def query_text = params.query_text;
155+
def text_docs = params.text_docs;
156+
def textDocsBuilder = new StringBuilder('[');
157+
for (int i=0; i<text_docs.length; i++) {
158+
textDocsBuilder.append('"');
159+
textDocsBuilder.append(text_docs[i]);
160+
textDocsBuilder.append('"');
161+
if (i<text_docs.length - 1) {
162+
textDocsBuilder.append(',');
163+
}
164+
}
165+
textDocsBuilder.append(']');
166+
def parameters = '{ "query": "' + query_text + '", "texts": ' + textDocsBuilder.toString() + ' }';
167+
return '{"parameters": ' + parameters + '}';
168+
""",
169+
"request_body": """
170+
{
171+
"query": "${parameters.query}",
172+
"texts": ${parameters.texts}
173+
}
174+
""",
175+
"post_process_function": """
176+
if (params.result == null || params.result.length == 0) {
177+
throw new IllegalArgumentException("Post process function input is empty.");
178+
}
179+
def outputs = params.result;
180+
def scores = new Double[outputs.length];
181+
for (int i=0; i<outputs.length; i++) {
182+
def index = new BigDecimal(outputs[i].index.toString()).intValue();
183+
scores[index] = outputs[i].score;
184+
}
185+
def resultBuilder = new StringBuilder('[');
186+
for (int i=0; i<scores.length; i++) {
187+
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
188+
resultBuilder.append('"data": [');
189+
resultBuilder.append(scores[i]);
190+
resultBuilder.append(']}');
191+
if (i<outputs.length - 1) {
192+
resultBuilder.append(',');
193+
}
194+
}
195+
resultBuilder.append(']');
196+
return resultBuilder.toString();
197+
"""
127198
}
128199
]
129200
}
@@ -152,9 +223,51 @@ POST /_plugins/_ml/connectors/_create
152223
"headers": {
153224
"content-type": "application/json"
154225
},
155-
"request_body": "{ \"query\": \"${parameters.query}\", \"texts\": ${parameters.texts} }",
156-
"pre_process_function": "\n def query_text = params.query_text;\n def text_docs = params.text_docs;\n def textDocsBuilder = new StringBuilder('[');\n for (int i=0; i<text_docs.length; i++) {\n textDocsBuilder.append('\"');\n textDocsBuilder.append(text_docs[i]);\n textDocsBuilder.append('\"');\n if (i<text_docs.length - 1) {\n textDocsBuilder.append(',');\n }\n }\n textDocsBuilder.append(']');\n def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';\n return '{\"parameters\": ' + parameters + '}';\n",
157-
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n def sorted_outputs = outputs;\n for (int i=0; i<outputs.length; i++) {\n def idx = new BigDecimal(outputs[i].index.toString()).intValue();\n sorted_outputs[idx] = outputs[i];\n }\n def resultBuilder = new StringBuilder('[');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
226+
"pre_process_function": """
227+
def query_text = params.query_text;
228+
def text_docs = params.text_docs;
229+
def textDocsBuilder = new StringBuilder('[');
230+
for (int i=0; i<text_docs.length; i++) {
231+
textDocsBuilder.append('"');
232+
textDocsBuilder.append(text_docs[i]);
233+
textDocsBuilder.append('"');
234+
if (i<text_docs.length - 1) {
235+
textDocsBuilder.append(',');
236+
}
237+
}
238+
textDocsBuilder.append(']');
239+
def parameters = '{ "query": "' + query_text + '", "texts": ' + textDocsBuilder.toString() + ' }';
240+
return '{"parameters": ' + parameters + '}';
241+
""",
242+
"request_body": """
243+
{
244+
"query": "${parameters.query}",
245+
"texts": ${parameters.texts}
246+
}
247+
""",
248+
"post_process_function": """
249+
if (params.result == null || params.result.length == 0) {
250+
throw new IllegalArgumentException("Post process function input is empty.");
251+
}
252+
def outputs = params.result;
253+
def scores = new Double[outputs.length];
254+
for (int i=0; i<outputs.length; i++) {
255+
def index = new BigDecimal(outputs[i].index.toString()).intValue();
256+
scores[index] = outputs[i].score;
257+
}
258+
def resultBuilder = new StringBuilder('[');
259+
for (int i=0; i<scores.length; i++) {
260+
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
261+
resultBuilder.append('"data": [');
262+
resultBuilder.append(scores[i]);
263+
resultBuilder.append(']}');
264+
if (i<outputs.length - 1) {
265+
resultBuilder.append(',');
266+
}
267+
}
268+
resultBuilder.append(']');
269+
return resultBuilder.toString();
270+
"""
158271
}
159272
]
160273
}
@@ -188,7 +301,7 @@ POST _plugins/_ml/models/your_model_id/_predict
188301
}
189302
```
190303

191-
Each item in the `inputs` array comprises a `query_text` and a `text_docs` string, separated by a ` . `
304+
Each item in the array comprises a `query_text` and a `text_docs` string, separated by a ` . `
192305

193306
Alternatively, you can test the model as follows:
194307
```json
@@ -209,6 +322,10 @@ The connector `pre_process_function` transforms the input into the format requir
209322
By default, the SageMaker model output has the following format:
210323
```json
211324
[
325+
{
326+
"index": 2,
327+
"score": 0.92879725
328+
},
212329
{
213330
"index": 0,
214331
"score": 0.013636836
@@ -217,18 +334,14 @@ By default, the SageMaker model output has the following format:
217334
"index": 1,
218335
"score": 0.000593021
219336
},
220-
{
221-
"index": 2,
222-
"score": 0.92879725
223-
},
224337
{
225338
"index": 3,
226339
"score": 0.00012148176
227340
}
228341
]
229342
```
230343

231-
The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret. This adapted format is as follows:
344+
The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret and orders the results by index. This adapted format is as follows:
232345
```json
233346
{
234347
"inference_results": [

0 commit comments

Comments
 (0)