Skip to content

Commit ffb8a4e

Browse files
authored
tune model config: change pooling mode to optional (#724)
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 6baf6eb commit ffb8a4e

File tree

5 files changed: 23 additions (+23) and 14 deletions (-14).

common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,7 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
         }
         this.embeddingDimension = embeddingDimension;
         this.frameworkType = frameworkType;
-        if (poolingMode != null) {
-            this.poolingMode = poolingMode;
-        } else {
-            this.poolingMode = PoolingMode.MEAN;
-        }
+        this.poolingMode = poolingMode;
         this.normalizeResult = normalizeResult;
         this.modelMaxLength = modelMaxLength;
     }
@@ -69,7 +65,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
         Integer embeddingDimension = null;
         FrameworkType frameworkType = null;
         String allConfig = null;
-        PoolingMode poolingMode = PoolingMode.MEAN;
+        PoolingMode poolingMode = null;
         boolean normalizeResult = false;
         Integer modelMaxLength = null;
 
@@ -117,7 +113,11 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
         super(in);
         embeddingDimension = in.readInt();
         frameworkType = in.readEnum(FrameworkType.class);
-        poolingMode = in.readEnum(PoolingMode.class);
+        if (in.readBoolean()) {
+            poolingMode = in.readEnum(PoolingMode.class);
+        } else {
+            poolingMode = null;
+        }
         normalizeResult = in.readBoolean();
         modelMaxLength = in.readOptionalInt();
     }
@@ -127,7 +127,12 @@ public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
         out.writeInt(embeddingDimension);
         out.writeEnum(frameworkType);
-        out.writeEnum(poolingMode);
+        if (poolingMode != null) {
+            out.writeBoolean(true);
+            out.writeEnum(poolingMode);
+        } else {
+            out.writeBoolean(false);
+        }
         out.writeBoolean(normalizeResult);
         out.writeOptionalInt(modelMaxLength);
     }
@@ -150,8 +155,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         if (modelMaxLength != null) {
             builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
         }
-        builder.field(POOLING_MODE_FIELD, poolingMode);
-        builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
+        if (poolingMode != null) {
+            builder.field(POOLING_MODE_FIELD, poolingMode);
+        }
+        if (normalizeResult) {
+            builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
+        }
         builder.endObject();
         return builder;
     }

common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void toXContent() throws IOException {
         config.toXContent(builder, EMPTY_PARAMS);
         String configContent = TestHelper.xContentBuilderToString(builder);
         System.out.println(configContent);
-        assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"pooling_mode\":\"MEAN\",\"normalize_result\":false}", configContent);
+        assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent);
     }
 
     @Test

common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public class MLUploadInputTest {
     private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," +
             "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," +
             "\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" +
-            ",\"pooling_mode\":\"MEAN\",\"normalize_result\":false},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
+            "},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
     private final FunctionName functionName = FunctionName.LINEAR_REGRESSION;
     private final String modelName = "modelName";
     private final String version = "version";

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public class HuggingfaceTextEmbeddingTranslatorFactory implements TranslatorFact
     private final boolean neuron;
 
     public HuggingfaceTextEmbeddingTranslatorFactory(TextEmbeddingModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType, boolean neuron) {
-        this.poolingMode = poolingMode;
+        this.poolingMode = poolingMode == null ? TextEmbeddingModelConfig.PoolingMode.MEAN : poolingMode;
         this.normalizeResult = normalizeResult;
         this.modelType = modelType;
         this.neuron = neuron;

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public class ONNXSentenceTransformerTextEmbeddingTranslator implements ServingTr
     private String modelType;
 
     public ONNXSentenceTransformerTextEmbeddingTranslator(TextEmbeddingModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType) {
-        this.poolingMode = poolingMode;
+        this.poolingMode = poolingMode == null ? TextEmbeddingModelConfig.PoolingMode.MEAN : poolingMode;
         this.normalizeResult = normalizeResult;
         this.modelType = modelType;
     }

0 commit comments

Comments
 (0)