|
38 | 38 | import java.util.List; |
39 | 39 | import java.util.Locale; |
40 | 40 | import java.util.Map; |
| 41 | +import java.util.TreeMap; |
41 | 42 |
|
42 | 43 | import static java.util.Map.entry; |
43 | 44 | import static org.elasticsearch.common.logging.LoggerMessageFormat.format; |
@@ -66,10 +67,12 @@ public class Categorize extends GroupingFunction.NonEvaluatableGroupingFunction |
66 | 67 | Categorize::new |
67 | 68 | ); |
68 | 69 |
|
69 | | - public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries( |
70 | | - entry("analyzer", KEYWORD), |
71 | | - entry("output_format", KEYWORD), |
72 | | - entry("similarity_threshold", INTEGER) |
| 70 | + private static final String ANALYZER = "analyzer"; |
| 71 | + private static final String OUTPUT_FORMAT = "output_format"; |
| 72 | + private static final String SIMILARITY_THRESHOLD = "similarity_threshold"; |
| 73 | + |
| 74 | + private static final Map<String, DataType> ALLOWED_OPTIONS = new TreeMap<>( |
| 75 | + Map.ofEntries(entry(ANALYZER, KEYWORD), entry(OUTPUT_FORMAT, KEYWORD), entry(SIMILARITY_THRESHOLD, INTEGER)) |
73 | 76 | ); |
74 | 77 |
|
75 | 78 | private final Expression field; |
@@ -100,19 +103,19 @@ public Categorize( |
100 | 103 | description = "(Optional) Categorize additional options as <<esql-function-named-params,function named parameters>>.", |
101 | 104 | params = { |
102 | 105 | @MapParam.MapParamEntry( |
103 | | - name = "analyzer", |
| 106 | + name = ANALYZER, |
104 | 107 | type = "keyword", |
105 | 108 | valueHint = { "standard" }, |
106 | 109 | description = "Analyzer used to convert the field into tokens for text categorization." |
107 | 110 | ), |
108 | 111 | @MapParam.MapParamEntry( |
109 | | - name = "output_format", |
| 112 | + name = OUTPUT_FORMAT, |
110 | 113 | type = "keyword", |
111 | 114 | valueHint = { "regex", "tokens" }, |
112 | 115 | description = "The output format of the categories. Defaults to regex." |
113 | 116 | ), |
114 | 117 | @MapParam.MapParamEntry( |
115 | | - name = "similarity_threshold", |
| 118 | + name = SIMILARITY_THRESHOLD, |
116 | 119 | type = "integer", |
117 | 120 | valueHint = { "70" }, |
118 | 121 | description = "The minimum percentage of token weight that must match for text to be added to the category bucket. " |
@@ -166,40 +169,43 @@ public Nullability nullable() { |
166 | 169 |
|
167 | 170 | @Override |
168 | 171 | protected TypeResolution resolveType() { |
169 | | - return isString(field(), sourceText(), DEFAULT).and(Options.resolve(options, source(), SECOND, ALLOWED_OPTIONS)).and(() -> { |
170 | | - try { |
171 | | - categorizeDef(); |
172 | | - } catch (InvalidArgumentException e) { |
173 | | - return new TypeResolution(e.getMessage()); |
174 | | - } |
175 | | - return TypeResolution.TYPE_RESOLVED; |
176 | | - }); |
| 172 | + return isString(field(), sourceText(), DEFAULT).and( |
| 173 | + Options.resolve(options, source(), SECOND, ALLOWED_OPTIONS, this::verifyOptions) |
| 174 | + ); |
177 | 175 | } |
178 | 176 |
|
179 | | - public CategorizeDef categorizeDef() { |
180 | | - Map<String, Object> optionsMap = new HashMap<>(); |
181 | | - if (options != null) { |
182 | | - Options.populateMap((MapExpression) options, optionsMap, source(), SECOND, ALLOWED_OPTIONS); |
| 177 | + private void verifyOptions(Map<String, Object> optionsMap) { |
| 178 | + if (options == null) { |
| 179 | + return; |
183 | 180 | } |
184 | | - Integer similarityThreshold = (Integer) optionsMap.get("similarity_threshold"); |
| 181 | + Integer similarityThreshold = (Integer) optionsMap.get(SIMILARITY_THRESHOLD); |
185 | 182 | if (similarityThreshold != null) { |
186 | 183 | if (similarityThreshold <= 0 || similarityThreshold > 100) { |
187 | 184 | throw new InvalidArgumentException( |
188 | 185 | format("invalid similarity threshold [{}], expecting a number between 1 and 100, inclusive", similarityThreshold) |
189 | 186 | ); |
190 | 187 | } |
191 | 188 | } |
192 | | - OutputFormat outputFormat = null; |
193 | | - String outputFormatString = (String) optionsMap.get("output_format"); |
194 | | - if (outputFormatString != null) { |
| 189 | + String outputFormat = (String) optionsMap.get(OUTPUT_FORMAT); |
| 190 | + if (outputFormat != null) { |
195 | 191 | try { |
196 | | - outputFormat = OutputFormat.valueOf(outputFormatString.toUpperCase(Locale.ROOT)); |
| 192 | + OutputFormat.valueOf(outputFormat.toUpperCase(Locale.ROOT)); |
197 | 193 | } catch (IllegalArgumentException e) { |
198 | 194 | throw new InvalidArgumentException( |
199 | | - format(null, "invalid output format [{}], expecting one of [REGEX, TOKENS]", outputFormatString) |
| 195 | + format(null, "invalid output format [{}], expecting one of [REGEX, TOKENS]", outputFormat) |
200 | 196 | ); |
201 | 197 | } |
202 | 198 | } |
| 199 | + } |
| 200 | + |
| 201 | + public CategorizeDef categorizeDef() { |
| 202 | + Map<String, Object> optionsMap = new HashMap<>(); |
| 203 | + if (options != null) { |
| 204 | + Options.populateMap((MapExpression) options, optionsMap, source(), SECOND, ALLOWED_OPTIONS); |
| 205 | + } |
| 206 | + Integer similarityThreshold = (Integer) optionsMap.get(SIMILARITY_THRESHOLD); |
| 207 | + String outputFormatString = (String) optionsMap.get(OUTPUT_FORMAT); |
| 208 | + OutputFormat outputFormat = outputFormatString == null ? null : OutputFormat.valueOf(outputFormatString.toUpperCase(Locale.ROOT)); |
203 | 209 | return new CategorizeDef( |
204 | 210 | (String) optionsMap.get("analyzer"), |
205 | 211 | outputFormat == null ? REGEX : outputFormat, |
|
0 commit comments