|
14 | 14 | import org.apache.calcite.rel.type.RelDataType; |
15 | 15 | import org.apache.calcite.rel.type.RelDataTypeField; |
16 | 16 | import org.apache.calcite.rex.RexNode; |
| 17 | +import org.apache.calcite.sql.type.SqlTypeName; |
17 | 18 |
|
18 | 19 | /** |
19 | | - * Utility class for unifying schemas across multiple RelNodes. Throws an exception when type |
20 | | - * conflicts are detected. |
| 20 | + * Utility class for unifying schemas across multiple RelNodes. Supports two strategies: |
| 21 | + * |
| 22 | + * <ul> |
| 23 | + * <li>Conflict resolution (multisearch): throws on type mismatch, fills missing fields with NULL |
| 24 | + * <li>Type coercion (union): widens compatible types (e.g. INTEGER→BIGINT), falls back to VARCHAR |
| 25 | + * for incompatible types, fills missing fields with NULL |
| 26 | + * </ul> |
21 | 27 | */ |
22 | 28 | public class SchemaUnifier { |
23 | 29 |
|
@@ -147,4 +153,236 @@ RelDataType getType() { |
147 | 153 | return type; |
148 | 154 | } |
149 | 155 | } |
| 156 | + |
| 157 | + /** |
| 158 | + * Builds unified schema with type coercion for UNION command. Coerces compatible types to a |
| 159 | + * common supertype (e.g. int+float→float), falls back to VARCHAR for incompatible types, and |
| 160 | + * fills missing fields with NULL. |
| 161 | + */ |
| 162 | + public static List<RelNode> buildUnifiedSchemaWithTypeCoercion( |
| 163 | + List<RelNode> inputs, CalcitePlanContext context) { |
| 164 | + if (inputs.isEmpty() || inputs.size() == 1) { |
| 165 | + return inputs; |
| 166 | + } |
| 167 | + |
| 168 | + List<RelNode> coercedInputs = coerceUnionTypes(inputs, context); |
| 169 | + return unifySchemasForUnion(coercedInputs, context); |
| 170 | + } |
| 171 | + |
| 172 | + /** |
| 173 | + * Aligns schemas by projecting NULL for missing fields and CAST for type mismatches. Uses |
| 174 | + * force=true to clear collation traits and prevent EnumerableMergeUnion cast exception. |
| 175 | + */ |
| 176 | + private static List<RelNode> unifySchemasForUnion( |
| 177 | + List<RelNode> inputs, CalcitePlanContext context) { |
| 178 | + List<SchemaField> unifiedSchema = buildUnifiedSchemaForUnion(inputs); |
| 179 | + List<String> fieldNames = |
| 180 | + unifiedSchema.stream().map(SchemaField::getName).collect(Collectors.toList()); |
| 181 | + |
| 182 | + List<RelNode> projectedNodes = new ArrayList<>(); |
| 183 | + for (RelNode node : inputs) { |
| 184 | + List<RexNode> projection = buildProjectionForUnion(node, unifiedSchema, context); |
| 185 | + RelNode projectedNode = |
| 186 | + context.relBuilder.push(node).project(projection, fieldNames, true).build(); |
| 187 | + projectedNodes.add(projectedNode); |
| 188 | + } |
| 189 | + return projectedNodes; |
| 190 | + } |
| 191 | + |
| 192 | + private static List<SchemaField> buildUnifiedSchemaForUnion(List<RelNode> nodes) { |
| 193 | + List<SchemaField> schema = new ArrayList<>(); |
| 194 | + Map<String, RelDataType> seenFields = new HashMap<>(); |
| 195 | + |
| 196 | + for (RelNode node : nodes) { |
| 197 | + for (RelDataTypeField field : node.getRowType().getFieldList()) { |
| 198 | + if (!seenFields.containsKey(field.getName())) { |
| 199 | + schema.add(new SchemaField(field.getName(), field.getType())); |
| 200 | + seenFields.put(field.getName(), field.getType()); |
| 201 | + } |
| 202 | + } |
| 203 | + } |
| 204 | + return schema; |
| 205 | + } |
| 206 | + |
| 207 | + private static List<RexNode> buildProjectionForUnion( |
| 208 | + RelNode node, List<SchemaField> unifiedSchema, CalcitePlanContext context) { |
| 209 | + Map<String, RelDataTypeField> nodeFieldMap = |
| 210 | + node.getRowType().getFieldList().stream() |
| 211 | + .collect(Collectors.toMap(RelDataTypeField::getName, field -> field)); |
| 212 | + |
| 213 | + List<RexNode> projection = new ArrayList<>(); |
| 214 | + for (SchemaField schemaField : unifiedSchema) { |
| 215 | + RelDataTypeField nodeField = nodeFieldMap.get(schemaField.getName()); |
| 216 | + |
| 217 | + if (nodeField != null) { |
| 218 | + RexNode fieldRef = context.rexBuilder.makeInputRef(node, nodeField.getIndex()); |
| 219 | + if (!nodeField.getType().equals(schemaField.getType())) { |
| 220 | + projection.add(context.rexBuilder.makeCast(schemaField.getType(), fieldRef)); |
| 221 | + } else { |
| 222 | + projection.add(fieldRef); |
| 223 | + } |
| 224 | + } else { |
| 225 | + projection.add(context.rexBuilder.makeNullLiteral(schemaField.getType())); |
| 226 | + } |
| 227 | + } |
| 228 | + return projection; |
| 229 | + } |
| 230 | + |
| 231 | + /** Casts fields to their common supertypes across all inputs when types differ. */ |
| 232 | + private static List<RelNode> coerceUnionTypes(List<RelNode> inputs, CalcitePlanContext context) { |
| 233 | + Map<String, List<SqlTypeName>> fieldTypeMap = new HashMap<>(); |
| 234 | + for (RelNode input : inputs) { |
| 235 | + for (RelDataTypeField field : input.getRowType().getFieldList()) { |
| 236 | + String fieldName = field.getName(); |
| 237 | + SqlTypeName typeName = field.getType().getSqlTypeName(); |
| 238 | + if (typeName != null) { |
| 239 | + fieldTypeMap.computeIfAbsent(fieldName, k -> new ArrayList<>()).add(typeName); |
| 240 | + } |
| 241 | + } |
| 242 | + } |
| 243 | + |
| 244 | + Map<String, SqlTypeName> targetTypeMap = new HashMap<>(); |
| 245 | + for (Map.Entry<String, List<SqlTypeName>> entry : fieldTypeMap.entrySet()) { |
| 246 | + String fieldName = entry.getKey(); |
| 247 | + List<SqlTypeName> types = entry.getValue(); |
| 248 | + |
| 249 | + SqlTypeName commonType = types.getFirst(); |
| 250 | + for (int i = 1; i < types.size(); i++) { |
| 251 | + commonType = findCommonTypeForUnion(commonType, types.get(i)); |
| 252 | + } |
| 253 | + targetTypeMap.put(fieldName, commonType); |
| 254 | + } |
| 255 | + |
| 256 | + boolean needsCoercion = false; |
| 257 | + for (RelNode input : inputs) { |
| 258 | + for (RelDataTypeField field : input.getRowType().getFieldList()) { |
| 259 | + SqlTypeName targetType = targetTypeMap.get(field.getName()); |
| 260 | + if (targetType != null && field.getType().getSqlTypeName() != targetType) { |
| 261 | + needsCoercion = true; |
| 262 | + break; |
| 263 | + } |
| 264 | + } |
| 265 | + if (needsCoercion) break; |
| 266 | + } |
| 267 | + |
| 268 | + if (!needsCoercion) { |
| 269 | + return inputs; |
| 270 | + } |
| 271 | + |
| 272 | + List<RelNode> coercedInputs = new ArrayList<>(); |
| 273 | + for (RelNode input : inputs) { |
| 274 | + List<RexNode> projections = new ArrayList<>(); |
| 275 | + List<String> projectionNames = new ArrayList<>(); |
| 276 | + boolean needsProjection = false; |
| 277 | + |
| 278 | + for (RelDataTypeField field : input.getRowType().getFieldList()) { |
| 279 | + String fieldName = field.getName(); |
| 280 | + SqlTypeName currentType = field.getType().getSqlTypeName(); |
| 281 | + SqlTypeName targetType = targetTypeMap.get(fieldName); |
| 282 | + |
| 283 | + RexNode fieldRef = context.rexBuilder.makeInputRef(input, field.getIndex()); |
| 284 | + |
| 285 | + if (currentType != targetType && targetType != null) { |
| 286 | + projections.add(context.relBuilder.cast(fieldRef, targetType)); |
| 287 | + needsProjection = true; |
| 288 | + } else { |
| 289 | + projections.add(fieldRef); |
| 290 | + } |
| 291 | + projectionNames.add(fieldName); |
| 292 | + } |
| 293 | + |
| 294 | + if (needsProjection) { |
| 295 | + context.relBuilder.push(input); |
| 296 | + context.relBuilder.project(projections, projectionNames, true); |
| 297 | + coercedInputs.add(context.relBuilder.build()); |
| 298 | + } else { |
| 299 | + coercedInputs.add(input); |
| 300 | + } |
| 301 | + } |
| 302 | + |
| 303 | + return coercedInputs; |
| 304 | + } |
| 305 | + |
| 306 | + /** |
| 307 | + * Returns the wider type for two SqlTypeNames. Within the same family, returns the wider type |
| 308 | + * (e.g. INTEGER+BIGINT-->BIGINT). Across families, falls back to VARCHAR. |
| 309 | + */ |
| 310 | + private static SqlTypeName findCommonTypeForUnion(SqlTypeName type1, SqlTypeName type2) { |
| 311 | + if (type1 == type2) { |
| 312 | + return type1; |
| 313 | + } |
| 314 | + |
| 315 | + if (type1 == SqlTypeName.NULL) { |
| 316 | + return type2; |
| 317 | + } |
| 318 | + if (type2 == SqlTypeName.NULL) { |
| 319 | + return type1; |
| 320 | + } |
| 321 | + |
| 322 | + if (isNumericTypeForUnion(type1) && isNumericTypeForUnion(type2)) { |
| 323 | + return getWiderNumericTypeForUnion(type1, type2); |
| 324 | + } |
| 325 | + |
| 326 | + if (isStringTypeForUnion(type1) && isStringTypeForUnion(type2)) { |
| 327 | + return SqlTypeName.VARCHAR; |
| 328 | + } |
| 329 | + |
| 330 | + if (isTemporalTypeForUnion(type1) && isTemporalTypeForUnion(type2)) { |
| 331 | + return getWiderTemporalTypeForUnion(type1, type2); |
| 332 | + } |
| 333 | + |
| 334 | + return SqlTypeName.VARCHAR; |
| 335 | + } |
| 336 | + |
| 337 | + private static boolean isNumericTypeForUnion(SqlTypeName typeName) { |
| 338 | + return typeName == SqlTypeName.TINYINT |
| 339 | + || typeName == SqlTypeName.SMALLINT |
| 340 | + || typeName == SqlTypeName.INTEGER |
| 341 | + || typeName == SqlTypeName.BIGINT |
| 342 | + || typeName == SqlTypeName.FLOAT |
| 343 | + || typeName == SqlTypeName.REAL |
| 344 | + || typeName == SqlTypeName.DOUBLE |
| 345 | + || typeName == SqlTypeName.DECIMAL; |
| 346 | + } |
| 347 | + |
| 348 | + private static boolean isStringTypeForUnion(SqlTypeName typeName) { |
| 349 | + return typeName == SqlTypeName.CHAR || typeName == SqlTypeName.VARCHAR; |
| 350 | + } |
| 351 | + |
| 352 | + private static boolean isTemporalTypeForUnion(SqlTypeName typeName) { |
| 353 | + return typeName == SqlTypeName.DATE |
| 354 | + || typeName == SqlTypeName.TIMESTAMP |
| 355 | + || typeName == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; |
| 356 | + } |
| 357 | + |
| 358 | + private static SqlTypeName getWiderNumericTypeForUnion(SqlTypeName type1, SqlTypeName type2) { |
| 359 | + int rank1 = getNumericTypeRankForUnion(type1); |
| 360 | + int rank2 = getNumericTypeRankForUnion(type2); |
| 361 | + return rank1 >= rank2 ? type1 : type2; |
| 362 | + } |
| 363 | + |
| 364 | + private static int getNumericTypeRankForUnion(SqlTypeName typeName) { |
| 365 | + return switch (typeName) { |
| 366 | + case TINYINT -> 1; |
| 367 | + case SMALLINT -> 2; |
| 368 | + case INTEGER -> 3; |
| 369 | + case BIGINT -> 4; |
| 370 | + case FLOAT -> 5; |
| 371 | + case REAL -> 6; |
| 372 | + case DOUBLE -> 7; |
| 373 | + case DECIMAL -> 8; |
| 374 | + default -> 0; |
| 375 | + }; |
| 376 | + } |
| 377 | + |
| 378 | + private static SqlTypeName getWiderTemporalTypeForUnion(SqlTypeName type1, SqlTypeName type2) { |
| 379 | + if (type1 == SqlTypeName.TIMESTAMP || type2 == SqlTypeName.TIMESTAMP) { |
| 380 | + return SqlTypeName.TIMESTAMP; |
| 381 | + } |
| 382 | + if (type1 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE |
| 383 | + || type2 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE) { |
| 384 | + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; |
| 385 | + } |
| 386 | + return SqlTypeName.DATE; |
| 387 | + } |
150 | 388 | } |
0 commit comments