Skip to content

Commit 605b373

Browse files
committed
Add UTF-8 code-point byte-length helper for Trino-style pad loops
Adds codePointByteLengths so callers can decode UTF-8 once and directly materialize per-code-point byte widths (1..4) for padding/loop planning. Benchmark (SliceUtf8Benchmark, length=128 code points): - ascii=true: helper(byte[]) 0.696 ns/codepoint vs Trino byte[] baseline 1.020 ns/codepoint - ascii=false: helper(byte[]) 2.129 ns/codepoint vs Trino byte[] baseline 3.596 ns/codepoint
1 parent f706f30 commit 605b373

File tree

3 files changed

+336
-0
lines changed

3 files changed

+336
-0
lines changed

src/main/java/io/airlift/slice/SliceUtf8.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,6 +1666,64 @@ else if (codePoint < 0x1_0000) {
16661666
return Arrays.copyOf(codePoints, codePointCount);
16671667
}
16681668

1669+
/**
1670+
* Decodes UTF-8 and returns UTF-8 byte lengths ({@code 1..4}) for each code point.
1671+
* <p>
1672+
* Note: This method does not explicitly check for valid UTF-8, and may
1673+
* return incorrect results or throw an exception for invalid UTF-8.
1674+
*/
1675+
public static byte[] codePointByteLengths(Slice utf8)
1676+
{
1677+
return codePointByteLengths(utf8.byteArray(), utf8.byteArrayOffset(), utf8.length());
1678+
}
1679+
1680+
/**
1681+
* Decodes UTF-8 byte array range and returns UTF-8 byte lengths ({@code 1..4}) for each code point.
1682+
* <p>
1683+
* Note: This method does not explicitly check for valid UTF-8, and may
1684+
* return incorrect results or throw an exception for invalid UTF-8.
1685+
*/
1686+
public static byte[] codePointByteLengths(byte[] utf8, int offset, int length)
1687+
{
1688+
checkFromIndexSize(offset, length, utf8.length);
1689+
return codePointByteLengthsRaw(utf8, offset, length);
1690+
}
1691+
1692+
private static byte[] codePointByteLengthsRaw(byte[] utf8, int utf8Offset, int utf8Length)
1693+
{
1694+
if (utf8Length == 0) {
1695+
return new byte[0];
1696+
}
1697+
1698+
if (isAsciiRaw(utf8, utf8Offset, utf8Length)) {
1699+
byte[] lengths = new byte[utf8Length];
1700+
Arrays.fill(lengths, (byte) 1);
1701+
return lengths;
1702+
}
1703+
1704+
byte[] lengths = new byte[Math.max(8, utf8Length >>> 1)];
1705+
int codePointCount = 0;
1706+
int position = 0;
1707+
while (position < utf8Length) {
1708+
int codePointLength = lengthOfCodePointFromStartByteSafe(utf8[utf8Offset + position]);
1709+
if (codePointLength < 0 || position + codePointLength > utf8Length) {
1710+
throw new InvalidUtf8Exception("Invalid UTF-8 sequence at position " + position);
1711+
}
1712+
1713+
if (codePointCount == lengths.length) {
1714+
lengths = Arrays.copyOf(lengths, lengths.length * 2);
1715+
}
1716+
lengths[codePointCount] = (byte) codePointLength;
1717+
codePointCount++;
1718+
position += codePointLength;
1719+
}
1720+
1721+
if (codePointCount == lengths.length) {
1722+
return lengths;
1723+
}
1724+
return Arrays.copyOf(lengths, codePointCount);
1725+
}
1726+
16691727
/**
16701728
* Encodes Unicode code points into UTF-8.
16711729
*

src/test/java/io/airlift/slice/SliceUtf8Benchmark.java

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.concurrent.ThreadLocalRandom;
3434
import java.util.stream.IntStream;
3535

36+
import static io.airlift.slice.SliceUtf8.codePointByteLengths;
3637
import static io.airlift.slice.SliceUtf8.codePointToUtf8;
3738
import static io.airlift.slice.SliceUtf8.compareUtf16BE;
3839
import static io.airlift.slice.SliceUtf8.countCodePoints;
@@ -339,6 +340,125 @@ else if ((currentChar == '%') || (currentChar == '_')) {
339340
return position;
340341
}
341342

343+
@Benchmark
344+
public int benchmarkTrinoPadStringCodePointLengths(TrinoPadData data)
345+
{
346+
Slice padString = data.getPadString();
347+
int padStringLength = countCodePoints(padString);
348+
int[] padStringCounts = new int[padStringLength];
349+
for (int index = 0; index < padStringLength; index++) {
350+
padStringCounts[index] = lengthOfCodePointSafe(padString, offsetOfCodePoint(padString, index));
351+
}
352+
return checksum(padStringCounts);
353+
}
354+
355+
@Benchmark
356+
public int benchmarkTrinoPadStringCodePointLengthsSinglePass(TrinoPadData data)
357+
{
358+
Slice padString = data.getPadString();
359+
int[] padStringCounts = new int[countCodePoints(padString)];
360+
int position = 0;
361+
int index = 0;
362+
while (position < padString.length()) {
363+
int codePoint = getCodePointAt(padString, position);
364+
int codePointLength = lengthOfCodePoint(codePoint);
365+
padStringCounts[index] = codePointLength;
366+
index++;
367+
position += codePointLength;
368+
}
369+
if (index != padStringCounts.length) {
370+
throw new AssertionError();
371+
}
372+
return checksum(padStringCounts);
373+
}
374+
375+
@Benchmark
376+
public int benchmarkTrinoPadStringCodePointLengthsByteArray(TrinoPadData data)
377+
{
378+
byte[] utf8 = data.getUtf8();
379+
int baseOffset = data.getOffset();
380+
int byteLength = data.getByteLength();
381+
int[] padStringCounts = new int[countCodePoints(utf8, baseOffset, byteLength)];
382+
int position = 0;
383+
int index = 0;
384+
while (position < byteLength) {
385+
int codePoint = getCodePointAt(utf8, baseOffset, byteLength, position);
386+
int codePointLength = lengthOfCodePoint(codePoint);
387+
padStringCounts[index] = codePointLength;
388+
index++;
389+
position += codePointLength;
390+
}
391+
if (index != padStringCounts.length) {
392+
throw new AssertionError();
393+
}
394+
return checksum(padStringCounts);
395+
}
396+
397+
@Benchmark
398+
public int benchmarkTrinoPadStringCodePointLengthsSliceUtf8Helper(TrinoPadData data)
399+
{
400+
return checksum(codePointByteLengths(data.getPadString()));
401+
}
402+
403+
@Benchmark
404+
public int benchmarkTrinoPadStringCodePointLengthsSliceUtf8HelperByteArray(TrinoPadData data)
405+
{
406+
return checksum(codePointByteLengths(data.getUtf8(), data.getOffset(), data.getByteLength()));
407+
}
408+
409+
@Benchmark
410+
public Slice benchmarkTrinoDomainTranslatorPrefixRange(TrinoPrefixRangeData data)
411+
{
412+
Slice constantPrefix = data.getConstantPrefix();
413+
414+
int lastIncrementable = -1;
415+
for (int position = 0; position < constantPrefix.length(); position += lengthOfCodePoint(constantPrefix, position)) {
416+
if (getCodePointAt(constantPrefix, position) < 127) {
417+
lastIncrementable = position;
418+
}
419+
}
420+
421+
if (lastIncrementable == -1) {
422+
return Slices.EMPTY_SLICE;
423+
}
424+
425+
Slice upperBound = constantPrefix.slice(0, lastIncrementable + lengthOfCodePoint(constantPrefix, lastIncrementable)).copy();
426+
setCodePointAt(getCodePointAt(constantPrefix, lastIncrementable) + 1, upperBound, lastIncrementable);
427+
return upperBound;
428+
}
429+
430+
@Benchmark
431+
public Slice benchmarkTrinoDomainTranslatorPrefixRangeSingleDecode(TrinoPrefixRangeData data)
432+
{
433+
byte[] utf8 = data.getUtf8();
434+
int baseOffset = data.getOffset();
435+
int byteLength = data.getByteLength();
436+
Slice constantPrefix = data.getConstantPrefix();
437+
438+
int lastIncrementableOffset = -1;
439+
int lastIncrementableCodePoint = -1;
440+
int lastIncrementableLength = 0;
441+
int position = 0;
442+
while (position < byteLength) {
443+
int codePoint = getCodePointAt(utf8, baseOffset, byteLength, position);
444+
int codePointLength = lengthOfCodePoint(codePoint);
445+
if (codePoint < 127) {
446+
lastIncrementableOffset = position;
447+
lastIncrementableCodePoint = codePoint;
448+
lastIncrementableLength = codePointLength;
449+
}
450+
position += codePointLength;
451+
}
452+
453+
if (lastIncrementableOffset == -1) {
454+
return Slices.EMPTY_SLICE;
455+
}
456+
457+
Slice upperBound = constantPrefix.slice(0, lastIncrementableOffset + lastIncrementableLength).copy();
458+
setCodePointAt(lastIncrementableCodePoint + 1, upperBound, lastIncrementableOffset);
459+
return upperBound;
460+
}
461+
342462
@Benchmark
343463
public int benchmarkCompareUtf16BE(CompareData data)
344464
{
@@ -452,6 +572,24 @@ public int benchmarkCodePointToUtf8(CodePointWriteData data)
452572
return totalBytes;
453573
}
454574

575+
private static int checksum(int[] values)
576+
{
577+
int checksum = 1;
578+
for (int value : values) {
579+
checksum = (31 * checksum) ^ value;
580+
}
581+
return checksum;
582+
}
583+
584+
private static int checksum(byte[] values)
585+
{
586+
int checksum = 1;
587+
for (byte value : values) {
588+
checksum = (31 * checksum) ^ value;
589+
}
590+
return checksum;
591+
}
592+
455593
@State(Thread)
456594
public static class BenchmarkData
457595
{
@@ -814,6 +952,120 @@ public int getEscapeChar()
814952
}
815953
}
816954

955+
@State(Thread)
956+
public static class TrinoPadData
957+
{
958+
@Param("128")
959+
private int length;
960+
961+
@Param({"true", "false"})
962+
private boolean ascii;
963+
964+
private byte[] utf8;
965+
private int offset;
966+
private int byteLength;
967+
private Slice padString;
968+
969+
@Setup
970+
public void setup()
971+
{
972+
int[] codePointSet = ascii ? BenchmarkData.ASCII_CODE_POINTS : BenchmarkData.ALL_CODE_POINTS;
973+
ThreadLocalRandom random = ThreadLocalRandom.current();
974+
DynamicSliceOutput out = new DynamicSliceOutput(length * 4);
975+
for (int index = 0; index < length; index++) {
976+
int codePoint = codePointSet[random.nextInt(codePointSet.length)];
977+
out.appendBytes(new String(Character.toChars(codePoint)).getBytes(StandardCharsets.UTF_8));
978+
}
979+
980+
byte[] encoded = out.slice().getBytes();
981+
offset = 9;
982+
utf8 = new byte[offset + encoded.length + 3];
983+
System.arraycopy(encoded, 0, utf8, offset, encoded.length);
984+
byteLength = encoded.length;
985+
padString = Slices.wrappedBuffer(utf8, offset, byteLength);
986+
}
987+
988+
public byte[] getUtf8()
989+
{
990+
return utf8;
991+
}
992+
993+
public int getOffset()
994+
{
995+
return offset;
996+
}
997+
998+
public int getByteLength()
999+
{
1000+
return byteLength;
1001+
}
1002+
1003+
public Slice getPadString()
1004+
{
1005+
return padString;
1006+
}
1007+
}
1008+
1009+
@State(Thread)
1010+
public static class TrinoPrefixRangeData
1011+
{
1012+
@Param("256")
1013+
private int length;
1014+
1015+
@Param({"true", "false"})
1016+
private boolean ascii;
1017+
1018+
private byte[] utf8;
1019+
private int offset;
1020+
private int byteLength;
1021+
private Slice constantPrefix;
1022+
1023+
@Setup
1024+
public void setup()
1025+
{
1026+
int[] codePointSet = ascii ? BenchmarkData.ASCII_CODE_POINTS : BenchmarkData.ALL_CODE_POINTS;
1027+
ThreadLocalRandom random = ThreadLocalRandom.current();
1028+
1029+
int[] codePoints = new int[length];
1030+
codePoints[0] = 'a';
1031+
for (int index = 1; index < codePoints.length; index++) {
1032+
codePoints[index] = codePointSet[random.nextInt(codePointSet.length)];
1033+
}
1034+
1035+
DynamicSliceOutput out = new DynamicSliceOutput(length * 4);
1036+
for (int codePoint : codePoints) {
1037+
out.appendBytes(new String(Character.toChars(codePoint)).getBytes(StandardCharsets.UTF_8));
1038+
}
1039+
1040+
byte[] encoded = out.slice().getBytes();
1041+
offset = 13;
1042+
utf8 = new byte[offset + encoded.length + 5];
1043+
System.arraycopy(encoded, 0, utf8, offset, encoded.length);
1044+
byteLength = encoded.length;
1045+
constantPrefix = Slices.wrappedBuffer(utf8, offset, byteLength);
1046+
}
1047+
1048+
public byte[] getUtf8()
1049+
{
1050+
return utf8;
1051+
}
1052+
1053+
public int getOffset()
1054+
{
1055+
return offset;
1056+
}
1057+
1058+
public int getByteLength()
1059+
{
1060+
return byteLength;
1061+
}
1062+
1063+
public Slice getConstantPrefix()
1064+
{
1065+
return constantPrefix;
1066+
}
1067+
}
1068+
8171069
@State(Thread)
8181070
public static class CodePointWriteData
8191071
{

src/test/java/io/airlift/slice/TestSliceUtf8.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.stream.IntStream;
2626

2727
import static com.google.common.primitives.Bytes.concat;
28+
import static io.airlift.slice.SliceUtf8.codePointByteLengths;
2829
import static io.airlift.slice.SliceUtf8.codePointToUtf8;
2930
import static io.airlift.slice.SliceUtf8.compareUtf16BE;
3031
import static io.airlift.slice.SliceUtf8.countCodePoints;
@@ -266,6 +267,7 @@ public void testByteArrayOverloadsMatchSlice()
266267
assertThat(wrappedBuffer(byteArrayTarget, 0, arrayWritten)).isEqualTo(sliceTarget.slice(0, sliceWritten));
267268

268269
assertThat(toCodePoints(padded, offset, length)).isEqualTo(toCodePoints(view));
270+
assertThat(codePointByteLengths(padded, offset, length)).isEqualTo(codePointByteLengths(view));
269271
assertThat(fromCodePoints(toCodePoints(view))).isEqualTo(view);
270272
}
271273

@@ -337,6 +339,30 @@ public void testToCodePointsInvalidUtf8()
337339
.hasMessageContaining("Invalid UTF-8 sequence at position");
338340
}
339341

342+
@Test
343+
public void testCodePointByteLengths()
344+
{
345+
assertCodePointByteLengths(STRING_EMPTY);
346+
assertCodePointByteLengths(STRING_HELLO);
347+
assertCodePointByteLengths(STRING_OESTERREICH);
348+
assertCodePointByteLengths(STRING_DULIOE_DULIOE);
349+
assertCodePointByteLengths(STRING_FAITH_HOPE_LOVE);
350+
assertCodePointByteLengths(STRING_OO);
351+
assertCodePointByteLengths(STRING_ASCII_CODE_POINTS);
352+
assertCodePointByteLengths(STRING_ALL_CODE_POINTS_RANDOM);
353+
}
354+
355+
private static void assertCodePointByteLengths(String value)
356+
{
357+
Slice utf8 = utf8Slice(value);
358+
int[] codePoints = value.codePoints().toArray();
359+
byte[] expectedLengths = new byte[codePoints.length];
360+
for (int index = 0; index < codePoints.length; index++) {
361+
expectedLengths[index] = (byte) lengthOfCodePoint(codePoints[index]);
362+
}
363+
assertThat(codePointByteLengths(utf8)).isEqualTo(expectedLengths);
364+
}
365+
340366
@Test
341367
public void testFromCodePointsInvalid()
342368
{

0 commit comments

Comments
 (0)