|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.inference.queries; |
9 | 9 |
|
| 10 | +import org.elasticsearch.action.ResolvedIndices; |
| 11 | +import org.elasticsearch.cluster.metadata.IndexMetadata; |
| 12 | +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; |
| 13 | +import org.elasticsearch.common.regex.Regex; |
| 14 | +import org.elasticsearch.common.settings.Settings; |
10 | 15 | import org.elasticsearch.features.NodeFeature; |
11 | 16 | import org.elasticsearch.index.query.DisMaxQueryBuilder; |
12 | 17 | import org.elasticsearch.index.query.MultiMatchQueryBuilder; |
13 | 18 | import org.elasticsearch.index.query.QueryBuilder; |
14 | 19 | import org.elasticsearch.index.query.QueryBuilders; |
| 20 | +import org.elasticsearch.index.search.QueryParserHelper; |
15 | 21 |
|
| 22 | +import java.util.Collection; |
16 | 23 | import java.util.HashMap; |
| 24 | +import java.util.HashSet; |
17 | 25 | import java.util.List; |
18 | 26 | import java.util.Map; |
19 | 27 | import java.util.Objects; |
20 | 28 | import java.util.Set; |
| 29 | +import java.util.stream.Collectors; |
| 30 | + |
| 31 | +import static org.elasticsearch.index.IndexSettings.DEFAULT_FIELD_SETTING; |
21 | 32 |
|
22 | 33 | public class SemanticMultiMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor { |
23 | 34 |
|
@@ -108,6 +119,118 @@ protected boolean shouldUseDefaultFields() { |
108 | 119 | return true; |
109 | 120 | } |
110 | 121 |
|
| 122 | + @Override |
| 123 | + protected InferenceIndexInformationForField resolveIndicesForFields( |
| 124 | + QueryBuilder queryBuilder, |
| 125 | + ResolvedIndices resolvedIndices, |
| 126 | + boolean resolveInferenceFieldWildcards |
| 127 | + ) { |
| 128 | + assert (queryBuilder instanceof MultiMatchQueryBuilder); |
| 129 | + MultiMatchQueryBuilder multiMatchQuery = (MultiMatchQueryBuilder) queryBuilder; |
| 130 | + |
| 131 | + // If wildcard resolution is disabled, use the simple parent class implementation |
| 132 | + if (resolveInferenceFieldWildcards == false || multiMatchQuery.resolveInferenceFieldWildcards() == false) { |
| 133 | + return super.resolveIndicesForFields(queryBuilder, resolvedIndices, resolveInferenceFieldWildcards); |
| 134 | + } |
| 135 | + |
| 136 | + return resolveIndicesForFieldsWithWildcards(queryBuilder, resolvedIndices); |
| 137 | + } |
| 138 | + |
| 139 | + private InferenceIndexInformationForField resolveIndicesForFieldsWithWildcards( |
| 140 | + QueryBuilder queryBuilder, |
| 141 | + ResolvedIndices resolvedIndices |
| 142 | + ) { |
| 143 | + Map<String, Float> fieldsWithWeights = getFieldsWithWeights(queryBuilder); |
| 144 | + Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); |
| 145 | + |
| 146 | + // Global wildcard resolution for inference fields |
| 147 | + Map<String, Float> globalInferenceFieldBoosts = new HashMap<>(); |
| 148 | + |
| 149 | + // Get all unique inference fields across all indices |
| 150 | + Set<String> allInferenceFields = indexMetadataCollection.stream() |
| 151 | + .flatMap(idx -> idx.getInferenceFields().keySet().stream()) |
| 152 | + .collect(Collectors.toSet()); |
| 153 | + |
| 154 | + // Calculate boost for each inference field based on matching patterns |
| 155 | + for (String inferenceField : allInferenceFields) { |
| 156 | + for (Map.Entry<String, Float> entry : fieldsWithWeights.entrySet()) { |
| 157 | + String pattern = entry.getKey(); |
| 158 | + Float boost = entry.getValue(); |
| 159 | + |
| 160 | + if (Regex.isMatchAllPattern(pattern) |
| 161 | + || (Regex.isSimpleMatchPattern(pattern) && Regex.simpleMatch(pattern, inferenceField)) |
| 162 | + || pattern.equals(inferenceField)) { |
| 163 | + addToFieldBoostsMap(globalInferenceFieldBoosts, inferenceField, boost); |
| 164 | + } |
| 165 | + } |
| 166 | + } |
| 167 | + |
| 168 | + // Per-index processing using pre-calculated global boosts |
| 169 | + Map<String, Map<String, InferenceFieldMetadata>> inferenceFieldsPerIndex = new HashMap<>(); |
| 170 | + Map<String, Set<String>> nonInferenceFieldsPerIndex = new HashMap<>(); |
| 171 | + Map<String, Float> allFieldBoosts = new HashMap<>(globalInferenceFieldBoosts); |
| 172 | + |
| 173 | + for (IndexMetadata indexMetadata : indexMetadataCollection) { |
| 174 | + String indexName = indexMetadata.getIndex().getName(); |
| 175 | + Map<String, InferenceFieldMetadata> indexInferenceFields = new HashMap<>(); |
| 176 | + Map<String, InferenceFieldMetadata> indexInferenceMetadata = indexMetadata.getInferenceFields(); |
| 177 | + |
| 178 | + // Handle default fields per index when no fields are specified |
| 179 | + Map<String, Float> fieldsToProcess = fieldsWithWeights; |
| 180 | + if (fieldsToProcess.isEmpty() && shouldUseDefaultFields()) { |
| 181 | + Settings settings = indexMetadata.getSettings(); |
| 182 | + List<String> defaultFields = settings.getAsList(DEFAULT_FIELD_SETTING.getKey(), DEFAULT_FIELD_SETTING.getDefault(settings)); |
| 183 | + fieldsToProcess = QueryParserHelper.parseFieldsAndWeights(defaultFields); |
| 184 | + } |
| 185 | + |
| 186 | + // Collect resolved inference fields for this index |
| 187 | + Set<String> resolvedInferenceFields = new HashSet<>(); |
| 188 | + |
| 189 | + // Add wildcard-resolved inference fields that exist in this index |
| 190 | + for (String inferenceField : globalInferenceFieldBoosts.keySet()) { |
| 191 | + if (indexInferenceMetadata.containsKey(inferenceField)) { |
| 192 | + indexInferenceFields.put(inferenceField, indexInferenceMetadata.get(inferenceField)); |
| 193 | + resolvedInferenceFields.add(inferenceField); |
| 194 | + } |
| 195 | + } |
| 196 | + |
| 197 | + // Always handle explicit inference fields (both wildcard and non-wildcard cases) |
| 198 | + for (Map.Entry<String, Float> entry : fieldsToProcess.entrySet()) { |
| 199 | + String field = entry.getKey(); |
| 200 | + Float boost = entry.getValue(); |
| 201 | + |
| 202 | + if (indexInferenceMetadata.containsKey(field)) { |
| 203 | + indexInferenceFields.put(field, indexInferenceMetadata.get(field)); |
| 204 | + resolvedInferenceFields.add(field); |
| 205 | + addToFieldBoostsMap(allFieldBoosts, field, boost); |
| 206 | + } |
| 207 | + } |
| 208 | + |
| 209 | + // Non-inference fields: all patterns minus resolved inference fields |
| 210 | + Set<String> indexNonInferenceFields = new HashSet<>(fieldsToProcess.keySet()); |
| 211 | + indexNonInferenceFields.removeAll(resolvedInferenceFields); |
| 212 | + |
| 213 | + // Store boosts for non-inference field patterns |
| 214 | + for (String nonInferenceField : indexNonInferenceFields) { |
| 215 | + addToFieldBoostsMap(allFieldBoosts, nonInferenceField, fieldsToProcess.get(nonInferenceField)); |
| 216 | + } |
| 217 | + |
| 218 | + if (indexInferenceFields.isEmpty() == false) { |
| 219 | + inferenceFieldsPerIndex.put(indexName, indexInferenceFields); |
| 220 | + } |
| 221 | + |
| 222 | + if (indexNonInferenceFields.isEmpty() == false) { |
| 223 | + nonInferenceFieldsPerIndex.put(indexName, indexNonInferenceFields); |
| 224 | + } |
| 225 | + } |
| 226 | + |
| 227 | + return new InferenceIndexInformationForField(inferenceFieldsPerIndex, nonInferenceFieldsPerIndex, allFieldBoosts); |
| 228 | + } |
| 229 | + |
| 230 | + private static void addToFieldBoostsMap(Map<String, Float> fieldBoosts, String field, Float boost) { |
| 231 | + fieldBoosts.compute(field, (k, v) -> v == null ? boost : v * boost); |
| 232 | + } |
| 233 | + |
111 | 234 | private QueryBuilder buildMultiFieldSemanticQuery( |
112 | 235 | MultiMatchQueryBuilder originalQuery, |
113 | 236 | Set<String> inferenceFields, |
|
0 commit comments