@@ -33,19 +33,19 @@ public class TextEmbeddingModelConfig extends MLModelConfig {
3333
3434 public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension" ;
3535 public static final String FRAMEWORK_TYPE_FIELD = "framework_type" ;
36- public static final String POOLING_METHOD_FIELD = "pooling_method " ;
36+ public static final String POOLING_MODE_FIELD = "pooling_mode " ;
3737 public static final String NORMALIZE_RESULT_FIELD = "normalize_result" ;
3838 public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length" ;
3939
4040 private final Integer embeddingDimension ;
4141 private final FrameworkType frameworkType ;
42- private final PoolingMethod poolingMethod ;
42+ private final PoolingMode poolingMode ;
4343 private final boolean normalizeResult ;
4444 private final Integer modelMaxLength ;
4545
4646 @ Builder (toBuilder = true )
4747 public TextEmbeddingModelConfig (String modelType , Integer embeddingDimension , FrameworkType frameworkType , String allConfig ,
48- PoolingMethod poolingMethod , boolean normalizeResult , Integer modelMaxLength ) {
48+ PoolingMode poolingMode , boolean normalizeResult , Integer modelMaxLength ) {
4949 super (modelType , allConfig );
5050 if (embeddingDimension == null ) {
5151 throw new IllegalArgumentException ("embedding dimension is null" );
@@ -55,10 +55,10 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
5555 }
5656 this .embeddingDimension = embeddingDimension ;
5757 this .frameworkType = frameworkType ;
58- if (poolingMethod != null ) {
59- this .poolingMethod = poolingMethod ;
58+ if (poolingMode != null ) {
59+ this .poolingMode = poolingMode ;
6060 } else {
61- this .poolingMethod = PoolingMethod .MEAN ;
61+ this .poolingMode = PoolingMode .MEAN ;
6262 }
6363 this .normalizeResult = normalizeResult ;
6464 this .modelMaxLength = modelMaxLength ;
@@ -69,7 +69,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
6969 Integer embeddingDimension = null ;
7070 FrameworkType frameworkType = null ;
7171 String allConfig = null ;
72- PoolingMethod poolingMethod = PoolingMethod .MEAN ;
72+ PoolingMode poolingMode = PoolingMode .MEAN ;
7373 boolean normalizeResult = false ;
7474 Integer modelMaxLength = null ;
7575
@@ -91,8 +91,8 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
9191 case ALL_CONFIG_FIELD :
9292 allConfig = parser .text ();
9393 break ;
94- case POOLING_METHOD_FIELD :
95- poolingMethod = PoolingMethod .from (parser .text ().toUpperCase (Locale .ROOT ));
94+ case POOLING_MODE_FIELD :
95+ poolingMode = PoolingMode .from (parser .text ().toUpperCase (Locale .ROOT ));
9696 break ;
9797 case NORMALIZE_RESULT_FIELD :
9898 normalizeResult = parser .booleanValue ();
@@ -105,7 +105,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
105105 break ;
106106 }
107107 }
108- return new TextEmbeddingModelConfig (modelType , embeddingDimension , frameworkType , allConfig , poolingMethod , normalizeResult , modelMaxLength );
108+ return new TextEmbeddingModelConfig (modelType , embeddingDimension , frameworkType , allConfig , poolingMode , normalizeResult , modelMaxLength );
109109 }
110110
111111 @ Override
@@ -117,7 +117,7 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
117117 super (in );
118118 embeddingDimension = in .readInt ();
119119 frameworkType = in .readEnum (FrameworkType .class );
120- poolingMethod = in .readEnum (PoolingMethod .class );
120+ poolingMode = in .readEnum (PoolingMode .class );
121121 normalizeResult = in .readBoolean ();
122122 modelMaxLength = in .readOptionalInt ();
123123 }
@@ -127,7 +127,7 @@ public void writeTo(StreamOutput out) throws IOException {
127127 super .writeTo (out );
128128 out .writeInt (embeddingDimension );
129129 out .writeEnum (frameworkType );
130- out .writeEnum (poolingMethod );
130+ out .writeEnum (poolingMode );
131131 out .writeBoolean (normalizeResult );
132132 out .writeOptionalInt (modelMaxLength );
133133 }
@@ -150,19 +150,32 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
150150 if (modelMaxLength != null ) {
151151 builder .field (MODEL_MAX_LENGTH_FIELD , modelMaxLength );
152152 }
153- builder .field (POOLING_METHOD_FIELD , poolingMethod );
153+ builder .field (POOLING_MODE_FIELD , poolingMode );
154154 builder .field (NORMALIZE_RESULT_FIELD , normalizeResult );
155155 builder .endObject ();
156156 return builder ;
157157 }
158158
159- public enum PoolingMethod {
160- MEAN ,
161- CLS ;
159+ public enum PoolingMode {
160+ MEAN ("mean" ),
161+ MEAN_SQRT_LEN ("mean_sqrt_len" ),
162+ MAX ("max" ),
163+ WEIGHTED_MEAN ("weightedmean" ),
164+ CLS ("cls" ),
165+ LAST_TOKEN ("lasttoken" );
162166
163- public static PoolingMethod from (String value ) {
167+ private String name ;
168+
169+ public String getName () {
170+ return name ;
171+ }
172+ PoolingMode (String name ) {
173+ this .name = name ;
174+ }
175+
176+ public static PoolingMode from (String value ) {
164177 try {
165- return PoolingMethod .valueOf (value );
178+ return PoolingMode .valueOf (value );
166179 } catch (Exception e ) {
167180 throw new IllegalArgumentException ("Wrong pooling method" );
168181 }
0 commit comments