|
36 | 36 |
|
37 | 37 | import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; |
38 | 38 | import static org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; |
| 39 | +import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED; |
39 | 40 | import static org.elasticsearch.xpack.esql.expression.function.scalar.string.Chunk.ALLOWED_CHUNKING_SETTING_OPTIONS; |
40 | 41 | import static org.elasticsearch.xpack.esql.expression.function.scalar.string.Chunk.DEFAULT_CHUNKING_SETTINGS; |
41 | 42 | import static org.elasticsearch.xpack.esql.expression.function.scalar.util.ChunkUtils.chunkText; |
@@ -64,42 +65,70 @@ private static String randomWordsBetween(int min, int max) { |
64 | 65 |
|
65 | 66 | @ParametersFactory |
66 | 67 | public static Iterable<Object[]> parameters() { |
67 | | - return parameterSuppliersFromTypedDataWithDefaultChecks( |
68 | | - true, |
69 | | - List.of(new TestCaseSupplier("Chunk with defaults", List.of(DataType.KEYWORD), () -> { |
70 | | - String text = randomWordsBetween(25, 50); |
71 | | - ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(Chunk.DEFAULT_CHUNK_SIZE, 0); |
| 68 | + return parameterSuppliersFromTypedData(testCaseSuppliers()); |
| 69 | + } |
72 | 70 |
|
73 | | - List<String> chunks = chunkText(text, chunkingSettings); |
74 | | - Object expectedResult = chunks.size() == 1 |
75 | | - ? new BytesRef(chunks.get(0).trim()) |
76 | | - : chunks.stream().map(s -> new BytesRef(s.trim())).toList(); |
| 71 | + private static List<TestCaseSupplier> testCaseSuppliers() { |
| 72 | + List<TestCaseSupplier> suppliers = new ArrayList<>(); |
| 73 | + suppliers.add(createTestCaseSupplier("Chunk with defaults", DataType.KEYWORD)); |
| 74 | + suppliers.add(createTestCaseSupplier("Chunk with defaults text input", DataType.TEXT)); |
| 75 | + return addFunctionNamedParams(suppliers); |
| 76 | + } |
77 | 77 |
|
78 | | - return new TestCaseSupplier.TestCase( |
79 | | - List.of(new TestCaseSupplier.TypedData(new BytesRef(text), DataType.KEYWORD, "str")), |
80 | | - "ChunkBytesRefEvaluator[str=Attribute[channel=0], " |
81 | | - + "chunkingSettings={\"strategy\":\"sentence\",\"max_chunk_size\":300,\"sentence_overlap\":0}]", |
82 | | - DataType.KEYWORD, |
83 | | - equalTo(expectedResult) |
84 | | - ); |
85 | | - }), new TestCaseSupplier("Chunk with defaults text input", List.of(DataType.TEXT), () -> { |
| 78 | + private static TestCaseSupplier createTestCaseSupplier(String description, DataType fieldDataType) { |
| 79 | + return new TestCaseSupplier(description, List.of(fieldDataType), () -> { |
| 80 | + String text = randomWordsBetween(25, 50); |
| 81 | + ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(Chunk.DEFAULT_CHUNK_SIZE, 0); |
| 82 | + |
| 83 | + List<String> chunks = chunkText(text, chunkingSettings); |
| 84 | + Object expectedResult = chunks.size() == 1 |
| 85 | + ? new BytesRef(chunks.get(0).trim()) |
| 86 | + : chunks.stream().map(s -> new BytesRef(s.trim())).toList(); |
| 87 | + |
| 88 | + return new TestCaseSupplier.TestCase( |
| 89 | + List.of(new TestCaseSupplier.TypedData(new BytesRef(text), fieldDataType, "str")), |
| 90 | + "ChunkBytesRefEvaluator[str=Attribute[channel=0], " |
| 91 | + + "chunkingSettings={\"strategy\":\"sentence\",\"max_chunk_size\":300,\"sentence_overlap\":0}]", |
| 92 | + DataType.KEYWORD, |
| 93 | + equalTo(expectedResult) |
| 94 | + ); |
| 95 | + }); |
| 96 | + } |
| 97 | + |
| 98 | + /** |
| 99 | + * Adds function named parameters to all the test case suppliers provided |
| 100 | + */ |
| 101 | + private static List<TestCaseSupplier> addFunctionNamedParams(List<TestCaseSupplier> suppliers) { |
| 102 | + List<TestCaseSupplier> result = new ArrayList<>(suppliers); |
| 103 | + for (TestCaseSupplier supplier : suppliers) { |
| 104 | + List<DataType> dataTypes = new ArrayList<>(supplier.types()); |
| 105 | + dataTypes.add(UNSUPPORTED); |
| 106 | + result.add(new TestCaseSupplier(supplier.name() + ", with chunking_settings", dataTypes, () -> { |
86 | 107 | String text = randomWordsBetween(25, 50); |
87 | | - ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(Chunk.DEFAULT_CHUNK_SIZE, 0); |
| 108 | + int chunkSize = 25; |
| 109 | + ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(chunkSize, 0); |
88 | 110 |
|
89 | 111 | List<String> chunks = chunkText(text, chunkingSettings); |
90 | 112 | Object expectedResult = chunks.size() == 1 |
91 | 113 | ? new BytesRef(chunks.get(0).trim()) |
92 | 114 | : chunks.stream().map(s -> new BytesRef(s.trim())).toList(); |
93 | 115 |
|
| 116 | + List<TestCaseSupplier.TypedData> values = List.of( |
| 117 | + new TestCaseSupplier.TypedData(new BytesRef(text), supplier.types().get(0), "str"), |
| 118 | + new TestCaseSupplier.TypedData(createChunkingSettings(chunkingSettings), UNSUPPORTED, "chunking_settings") |
| 119 | + .forceLiteral() |
| 120 | + ); |
| 121 | + |
94 | 122 | return new TestCaseSupplier.TestCase( |
95 | | - List.of(new TestCaseSupplier.TypedData(new BytesRef(text), DataType.TEXT, "str")), |
| 123 | + values, |
96 | 124 | "ChunkBytesRefEvaluator[str=Attribute[channel=0], " |
97 | | - + "chunkingSettings={\"strategy\":\"sentence\",\"max_chunk_size\":300,\"sentence_overlap\":0}]", |
| 125 | + + "chunkingSettings={\"strategy\":\"sentence\",\"max_chunk_size\":25,\"sentence_overlap\":0}]", |
98 | 126 | DataType.KEYWORD, |
99 | 127 | equalTo(expectedResult) |
100 | 128 | ); |
101 | | - })) |
102 | | - ); |
| 129 | + })); |
| 130 | + } |
| 131 | + return result; |
103 | 132 | } |
104 | 133 |
|
105 | 134 | private static MapExpression createChunkingSettings(ChunkingSettings chunkingSettings) { |
@@ -131,6 +160,16 @@ protected Expression build(Source source, List<Expression> args) { |
131 | 160 | return new Chunk(source, args.get(0), options); |
132 | 161 | } |
133 | 162 |
|
| 163 | + @Override |
| 164 | + public void testFold() { |
| 165 | + Expression expression = buildFieldExpression(testCase); |
| 166 | + // Skip testFold if the expression is not foldable (e.g., when chunking_settings contains MapExpression) |
| 167 | + if (expression.foldable() == false) { |
| 168 | + return; |
| 169 | + } |
| 170 | + super.testFold(); |
| 171 | + } |
| 172 | + |
134 | 173 | public void testDefaults() { |
135 | 174 | // Default of 300 is huge, only one chunk returned in this case |
136 | 175 | verifyChunks(null, 1); |
|
0 commit comments