|
11 | 11 | import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; |
12 | 12 |
|
13 | 13 | import org.apache.lucene.util.BytesRef; |
| 14 | +import org.elasticsearch.common.lucene.BytesRefs; |
14 | 15 | import org.elasticsearch.common.unit.ByteSizeValue; |
15 | 16 | import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; |
16 | 17 | import org.elasticsearch.compute.data.Block; |
|
35 | 36 | import org.elasticsearch.xpack.esql.core.util.Holder; |
36 | 37 |
|
37 | 38 | import java.util.ArrayList; |
| 39 | +import java.util.HashSet; |
38 | 40 | import java.util.List; |
39 | 41 | import java.util.Locale; |
| 42 | +import java.util.Set; |
40 | 43 | import java.util.function.Consumer; |
41 | 44 | import java.util.stream.IntStream; |
42 | 45 | import java.util.stream.LongStream; |
@@ -1232,6 +1235,194 @@ public void testLongNull() { |
1232 | 1235 | }, blockFactory.newLongArrayVector(values, values.length).asBlock(), blockFactory.newConstantNullBlock(values.length)); |
1233 | 1236 | } |
1234 | 1237 |
|
| 1238 | + public void test2BytesRefsHighCardinalityKey() { |
| 1239 | + final Page page; |
| 1240 | + int positions1 = 10; |
| 1241 | + int positions2 = 100_000; |
| 1242 | + if (randomBoolean()) { |
| 1243 | + positions1 = 100_000; |
| 1244 | + positions2 = 10; |
| 1245 | + } |
| 1246 | + final int totalPositions = positions1 * positions2; |
| 1247 | + try ( |
| 1248 | + BytesRefBlock.Builder builder1 = blockFactory.newBytesRefBlockBuilder(totalPositions); |
| 1249 | + BytesRefBlock.Builder builder2 = blockFactory.newBytesRefBlockBuilder(totalPositions); |
| 1250 | + ) { |
| 1251 | + for (int i = 0; i < positions1; i++) { |
| 1252 | + for (int p = 0; p < positions2; p++) { |
| 1253 | + builder1.appendBytesRef(new BytesRef("abcdef" + i)); |
| 1254 | + builder2.appendBytesRef(new BytesRef("abcdef" + p)); |
| 1255 | + } |
| 1256 | + } |
| 1257 | + page = new Page(builder1.build(), builder2.build()); |
| 1258 | + } |
| 1259 | + record Output(int offset, IntBlock block, IntVector vector) implements Releasable { |
| 1260 | + @Override |
| 1261 | + public void close() { |
| 1262 | + Releasables.close(block, vector); |
| 1263 | + } |
| 1264 | + } |
| 1265 | + List<Output> output = new ArrayList<>(); |
| 1266 | + |
| 1267 | + try (BlockHash hash1 = new BytesRef2BlockHash(blockFactory, 0, 1, totalPositions);) { |
| 1268 | + hash1.add(page, new GroupingAggregatorFunction.AddInput() { |
| 1269 | + @Override |
| 1270 | + public void add(int positionOffset, IntArrayBlock groupIds) { |
| 1271 | + groupIds.incRef(); |
| 1272 | + output.add(new Output(positionOffset, groupIds, null)); |
| 1273 | + } |
| 1274 | + |
| 1275 | + @Override |
| 1276 | + public void add(int positionOffset, IntBigArrayBlock groupIds) { |
| 1277 | + groupIds.incRef(); |
| 1278 | + output.add(new Output(positionOffset, groupIds, null)); |
| 1279 | + } |
| 1280 | + |
| 1281 | + @Override |
| 1282 | + public void add(int positionOffset, IntVector groupIds) { |
| 1283 | + groupIds.incRef(); |
| 1284 | + output.add(new Output(positionOffset, null, groupIds)); |
| 1285 | + } |
| 1286 | + |
| 1287 | + @Override |
| 1288 | + public void close() { |
| 1289 | + fail("hashes should not close AddInput"); |
| 1290 | + } |
| 1291 | + }); |
| 1292 | + |
| 1293 | + Block[] keys = hash1.getKeys(); |
| 1294 | + try { |
| 1295 | + Set<String> distinctKeys = new HashSet<>(); |
| 1296 | + BytesRefBlock block0 = (BytesRefBlock) keys[0]; |
| 1297 | + BytesRefBlock block1 = (BytesRefBlock) keys[1]; |
| 1298 | + BytesRef scratch = new BytesRef(); |
| 1299 | + StringBuilder builder = new StringBuilder(); |
| 1300 | + for (int i = 0; i < totalPositions; i++) { |
| 1301 | + builder.setLength(0); |
| 1302 | + builder.append(BytesRefs.toString(block0.getBytesRef(i, scratch))); |
| 1303 | + builder.append("#"); |
| 1304 | + builder.append(BytesRefs.toString(block1.getBytesRef(i, scratch))); |
| 1305 | + distinctKeys.add(builder.toString()); |
| 1306 | + } |
| 1307 | + assertThat(distinctKeys.size(), equalTo(totalPositions)); |
| 1308 | + } finally { |
| 1309 | + Releasables.close(keys); |
| 1310 | + } |
| 1311 | + } finally { |
| 1312 | + Releasables.close(output); |
| 1313 | + page.releaseBlocks(); |
| 1314 | + } |
| 1315 | + } |
| 1316 | + |
| 1317 | + public void test2BytesRefs() { |
| 1318 | + final Page page; |
| 1319 | + final int positions = randomIntBetween(1, 1000); |
| 1320 | + final boolean generateVector = randomBoolean(); |
| 1321 | + try ( |
| 1322 | + BytesRefBlock.Builder builder1 = blockFactory.newBytesRefBlockBuilder(positions); |
| 1323 | + BytesRefBlock.Builder builder2 = blockFactory.newBytesRefBlockBuilder(positions); |
| 1324 | + ) { |
| 1325 | + List<BytesRefBlock.Builder> builders = List.of(builder1, builder2); |
| 1326 | + for (int p = 0; p < positions; p++) { |
| 1327 | + for (BytesRefBlock.Builder builder : builders) { |
| 1328 | + int valueCount = generateVector ? 1 : between(0, 3); |
| 1329 | + switch (valueCount) { |
| 1330 | + case 0 -> builder.appendNull(); |
| 1331 | + case 1 -> builder.appendBytesRef(new BytesRef(Integer.toString(between(1, 100)))); |
| 1332 | + default -> { |
| 1333 | + builder.beginPositionEntry(); |
| 1334 | + for (int v = 0; v < valueCount; v++) { |
| 1335 | + builder.appendBytesRef(new BytesRef(Integer.toString(between(1, 100)))); |
| 1336 | + } |
| 1337 | + builder.endPositionEntry(); |
| 1338 | + } |
| 1339 | + } |
| 1340 | + } |
| 1341 | + } |
| 1342 | + page = new Page(builder1.build(), builder2.build()); |
| 1343 | + } |
| 1344 | + final int emitBatchSize = between(positions, 10 * 1024); |
| 1345 | + var groupSpecs = List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF), new BlockHash.GroupSpec(1, ElementType.BYTES_REF)); |
| 1346 | + record Output(int offset, IntBlock block, IntVector vector) implements Releasable { |
| 1347 | + @Override |
| 1348 | + public void close() { |
| 1349 | + Releasables.close(block, vector); |
| 1350 | + } |
| 1351 | + } |
| 1352 | + List<Output> output1 = new ArrayList<>(); |
| 1353 | + List<Output> output2 = new ArrayList<>(); |
| 1354 | + try ( |
| 1355 | + BlockHash hash1 = new BytesRef2BlockHash(blockFactory, 0, 1, emitBatchSize); |
| 1356 | + BlockHash hash2 = new PackedValuesBlockHash(groupSpecs, blockFactory, emitBatchSize) |
| 1357 | + ) { |
| 1358 | + hash1.add(page, new GroupingAggregatorFunction.AddInput() { |
| 1359 | + @Override |
| 1360 | + public void add(int positionOffset, IntArrayBlock groupIds) { |
| 1361 | + groupIds.incRef(); |
| 1362 | + output1.add(new Output(positionOffset, groupIds, null)); |
| 1363 | + } |
| 1364 | + |
| 1365 | + @Override |
| 1366 | + public void add(int positionOffset, IntBigArrayBlock groupIds) { |
| 1367 | + groupIds.incRef(); |
| 1368 | + output1.add(new Output(positionOffset, groupIds, null)); |
| 1369 | + } |
| 1370 | + |
| 1371 | + @Override |
| 1372 | + public void add(int positionOffset, IntVector groupIds) { |
| 1373 | + groupIds.incRef(); |
| 1374 | + output1.add(new Output(positionOffset, null, groupIds)); |
| 1375 | + } |
| 1376 | + |
| 1377 | + @Override |
| 1378 | + public void close() { |
| 1379 | + fail("hashes should not close AddInput"); |
| 1380 | + } |
| 1381 | + }); |
| 1382 | + hash2.add(page, new GroupingAggregatorFunction.AddInput() { |
| 1383 | + @Override |
| 1384 | + public void add(int positionOffset, IntArrayBlock groupIds) { |
| 1385 | + groupIds.incRef(); |
| 1386 | + output2.add(new Output(positionOffset, groupIds, null)); |
| 1387 | + } |
| 1388 | + |
| 1389 | + @Override |
| 1390 | + public void add(int positionOffset, IntBigArrayBlock groupIds) { |
| 1391 | + groupIds.incRef(); |
| 1392 | + output2.add(new Output(positionOffset, groupIds, null)); |
| 1393 | + } |
| 1394 | + |
| 1395 | + @Override |
| 1396 | + public void add(int positionOffset, IntVector groupIds) { |
| 1397 | + groupIds.incRef(); |
| 1398 | + output2.add(new Output(positionOffset, null, groupIds)); |
| 1399 | + } |
| 1400 | + |
| 1401 | + @Override |
| 1402 | + public void close() { |
| 1403 | + fail("hashes should not close AddInput"); |
| 1404 | + } |
| 1405 | + }); |
| 1406 | + assertThat(output1.size(), equalTo(output2.size())); |
| 1407 | + for (int i = 0; i < output1.size(); i++) { |
| 1408 | + Output o1 = output1.get(i); |
| 1409 | + Output o2 = output2.get(i); |
| 1410 | + assertThat(o1.offset, equalTo(o2.offset)); |
| 1411 | + if (o1.vector != null) { |
| 1412 | + assertNull(o1.block); |
| 1413 | + assertThat(o1.vector, equalTo(o2.vector != null ? o2.vector : o2.block.asVector())); |
| 1414 | + } else { |
| 1415 | + assertNull(o2.vector); |
| 1416 | + assertThat(o1.block, equalTo(o2.block)); |
| 1417 | + } |
| 1418 | + } |
| 1419 | + } finally { |
| 1420 | + Releasables.close(output1); |
| 1421 | + Releasables.close(output2); |
| 1422 | + page.releaseBlocks(); |
| 1423 | + } |
| 1424 | + } |
| 1425 | + |
1235 | 1426 | public void test3BytesRefs() { |
1236 | 1427 | final Page page; |
1237 | 1428 | final int positions = randomIntBetween(1, 1000); |
@@ -1326,7 +1517,7 @@ public void close() { |
1326 | 1517 | fail("hashes should not close AddInput"); |
1327 | 1518 | } |
1328 | 1519 | }); |
1329 | | - assertThat(output1.size(), equalTo(output1.size())); |
| 1520 | + assertThat(output1.size(), equalTo(output2.size())); |
1330 | 1521 | for (int i = 0; i < output1.size(); i++) { |
1331 | 1522 | Output o1 = output1.get(i); |
1332 | 1523 | Output o2 = output2.get(i); |
|
0 commit comments