diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/common/compress/FSSTCompressBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/common/compress/FSSTCompressBenchmark.java new file mode 100644 index 0000000000000..35d78a99c56f9 --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/common/compress/FSSTCompressBenchmark.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.common.compress; + +import org.apache.lucene.codecs.compressing.CompressionMode; +import org.apache.lucene.codecs.compressing.Compressor; +import org.apache.lucene.store.ByteArrayDataOutput; +import org.apache.lucene.store.ByteBuffersDataInput; +import org.elasticsearch.common.compress.fsst.FSST; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.concurrent.TimeUnit; + +@Fork(1) +@Warmup(iterations = 2) +@Measurement(iterations = 3) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +public class FSSTCompressBenchmark { + + @Param("") + public String dataset; + + private byte[] input; + private int[] offsets; + private byte[] outBuf; + private int[] outOffsets; + + @AuxCounters(AuxCounters.Type.EVENTS) + @State(Scope.Thread) + public static class CompressionMetrics { + public double compressionRatio; + } + + private static final int MB_8 = 8 * 1024 * 1024; + + private byte[] concatenateTo8mb(byte[] contentBytes) { + byte[] bytes = new byte[MB_8 + 8]; + int i = 0; + while (i < MB_8) { + int remaining = MB_8 - i; + int len = Math.min(contentBytes.length, remaining); + System.arraycopy(contentBytes, 0, bytes, i, len); + i += len; + } + return bytes; + } + + @Setup(Level.Trial) + public void setup() throws IOException { + String content = Files.readString(Path.of(dataset), StandardCharsets.UTF_8); + + byte[] contentBytes = FSST.toBytes(content); + input = concatenateTo8mb(contentBytes); + offsets = new int[] { 0, MB_8 }; + outBuf = new byte[MB_8]; + outOffsets = new int[2]; + } + + @Benchmark + public void compressFSST(Blackhole bh, CompressionMetrics metrics) { + List sample = FSST.makeSample(input, offsets); + var symbolTable = FSST.SymbolTable.buildSymbolTable(sample); + symbolTable.compressBulk(1, input, offsets, outBuf, outOffsets); + bh.consume(outBuf); + bh.consume(outOffsets); + + int uncompressedSize = offsets[1]; + int compressedSize = outOffsets[1]; + metrics.compressionRatio = compressedSize / (double) uncompressedSize; + } + + @Benchmark + public void compressLZ4Fast(Blackhole bh, CompressionMetrics metrics) throws IOException { + int inputSize = offsets[1]; + + var dataInput = new ByteBuffersDataInput(List.of(ByteBuffer.wrap(input))); + var dataOutput = new ByteArrayDataOutput(outBuf); + + Compressor compressor = CompressionMode.FAST.newCompressor(); + compressor.compress(dataInput, dataOutput); + + long 
compressedSize = dataOutput.getPosition(); + bh.consume(dataOutput); + + metrics.compressionRatio = compressedSize / (double) inputSize; + } + + // @Benchmark + // public void compressLZ4High(Blackhole bh, CompressionMetrics metrics) throws IOException { + // int inputSize = offsets[1]; + // + // var dataInput = new ByteBuffersDataInput(List.of(ByteBuffer.wrap(input))); + // var dataOutput = new ByteArrayDataOutput(outBuf); + // + // Compressor compressor = CompressionMode.HIGH_COMPRESSION.newCompressor(); + // compressor.compress(dataInput, dataOutput); + // + // long compressedSize = dataOutput.getPosition(); + // bh.consume(dataOutput); + // + // metrics.compressionRatio = compressedSize / (double) inputSize; + // } +} diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/common/compress/FSSTDecompressBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/common/compress/FSSTDecompressBenchmark.java new file mode 100644 index 0000000000000..95ea2ff65301e --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/common/compress/FSSTDecompressBenchmark.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.common.compress; + +import org.apache.lucene.codecs.compressing.CompressionMode; +import org.apache.lucene.codecs.compressing.Compressor; +import org.apache.lucene.codecs.compressing.Decompressor; +import org.apache.lucene.store.ByteArrayDataInput; +import org.apache.lucene.store.ByteArrayDataOutput; +import org.apache.lucene.store.ByteBuffersDataInput; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.compress.fsst.FSST; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.concurrent.TimeUnit; + +@Fork(1) +@Warmup(iterations = 2) +@Measurement(iterations = 3) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +public class FSSTDecompressBenchmark { + + // @Param({ "fsst", "lz4_high", "lz4_fast" }) + @Param({ "fsst", "lz4_fast" }) + public String compressionType; + + @Param("") + public String dataset; + + // original file + private int originalSize; + private byte[] input; + private int[] offsets; + + // compressed + private byte[] outBuf; + private int[] outOffsets; + private int compressedSize; + + // decompressed + private byte[] decompressBuf; + + // fsst specific + private FSST.SymbolTable symbolTable; + + private static final int MB_8 = 8 * 1024 * 1024; + + private byte[] concatenateTo8mb(byte[] contentBytes) { + byte[] bytes = new byte[MB_8 + 8]; + int i = 0; + while (i < MB_8) { + int remaining = MB_8 - i; + int len = Math.min(contentBytes.length, remaining); + System.arraycopy(contentBytes, 0, bytes, i, len); + i += len; + } + return bytes; + } + + @Setup(Level.Trial) + public void setup() throws IOException { + String content = Files.readString(Path.of(dataset), 
StandardCharsets.UTF_8); + byte[] contentBytes = FSST.toBytes(content); + originalSize = MB_8; + input = concatenateTo8mb(contentBytes); + offsets = new int[] { 0, originalSize }; + + outBuf = new byte[input.length]; + outOffsets = new int[2]; + + decompressBuf = new byte[input.length]; + + if (compressionType.equals("fsst")) { + List sample = FSST.makeSample(input, offsets); + symbolTable = FSST.SymbolTable.buildSymbolTable(sample); + symbolTable.compressBulk(1, input, offsets, outBuf, outOffsets); + compressedSize = outOffsets[1]; + } else if (compressionType.equals("lz4_fast")) { + var dataInput = new ByteBuffersDataInput(List.of(ByteBuffer.wrap(input, 0, originalSize))); + var dataOutput = new ByteArrayDataOutput(outBuf); + Compressor compressor = CompressionMode.FAST.newCompressor(); + compressor.compress(dataInput, dataOutput); + compressedSize = dataOutput.getPosition(); + } else if (compressionType.equals("lz4_high")) { + var dataInput = new ByteBuffersDataInput(List.of(ByteBuffer.wrap(input, 0, originalSize))); + var dataOutput = new ByteArrayDataOutput(outBuf); + Compressor compressor = CompressionMode.HIGH_COMPRESSION.newCompressor(); + compressor.compress(dataInput, dataOutput); + compressedSize = dataOutput.getPosition(); + } + } + + @Benchmark + public void decompress(Blackhole bh) throws IOException { + if (compressionType.equals("fsst")) { + byte[] symbolTableBytes = symbolTable.exportToBytes(); + FSST.Decoder decoder = FSST.Decoder.readFrom(symbolTableBytes); + int decompressedLen = FSST.decompress(outBuf, 0, outOffsets[1], decoder, decompressBuf); + // assert Arrays.equals(input, 0, originalSize, decompressBuf, 0, originalSize); + bh.consume(decompressBuf); + bh.consume(decompressedLen); + } else if (compressionType.equals("lz4_fast")) { + Decompressor decompressor = CompressionMode.FAST.newDecompressor(); + var dataInput = new ByteArrayDataInput(outBuf, 0, compressedSize); + var outBytesRef = new BytesRef(decompressBuf); + decompressor.decompress(dataInput, originalSize, 0, originalSize, outBytesRef); + // assert Arrays.equals(input, 0, originalSize, outBytesRef.bytes, 0, originalSize); + bh.consume(outBytesRef); + } else if (compressionType.equals("lz4_high")) { + Decompressor decompressor = CompressionMode.HIGH_COMPRESSION.newDecompressor(); + var dataInput = new ByteArrayDataInput(outBuf, 0, compressedSize); + var outBytesRef = new BytesRef(decompressBuf); + decompressor.decompress(dataInput, originalSize, 0, originalSize, outBytesRef); + // assert Arrays.equals(input, 0, originalSize, outBytesRef.bytes, 0, originalSize); + bh.consume(outBytesRef); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/common/compress/fsst/BulkCompressBufferer.java b/server/src/main/java/org/elasticsearch/common/compress/fsst/BulkCompressBufferer.java new file mode 100644 index 0000000000000..33026f80c42dd --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/compress/fsst/BulkCompressBufferer.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.common.compress.fsst; + +import org.apache.lucene.store.DataOutput; + +import java.io.Closeable; +import java.io.IOException; + +public class BulkCompressBufferer implements Closeable { + private static final int MAX_LINES = 512; + private static final int MAX_INPUT_DATA = 128 << 10; + private static final int MAX_OUTPUT_DATA = MAX_INPUT_DATA * 2; + + final byte[] inData = new byte[MAX_INPUT_DATA + 8]; + final int[] inOffsets = new int[MAX_LINES + 1]; // 1 additional space for offset where next item would have been + byte[] outBuf = new byte[MAX_OUTPUT_DATA + 8]; + int[] outOffsets = new int[MAX_LINES + 1]; // 1 additional space for offset where next item would have been + private final DataOutput finalOutput; + private final FSST.SymbolTable st; + private final FSST.OffsetWriter offsetWriter; + private int numLines = 0; + private int inOff = 0; + + public BulkCompressBufferer(DataOutput finalOutput, FSST.SymbolTable st, FSST.OffsetWriter offsetWriter) { + this.finalOutput = finalOutput; + this.st = st; + this.offsetWriter = offsetWriter; + } + + private void addToBuffer(byte[] bytes, int offset, int length) { + System.arraycopy(bytes, offset, inData, inOff, length); + int lineIdx = numLines; + inOffsets[lineIdx] = inOff; + inOff += length; + numLines++; + } + + public void addLine(byte[] bytes, int offset, int length) throws IOException { + if (inOff + length > MAX_INPUT_DATA || numLines == MAX_LINES) { + // can't fit another + compressAndWriteBuffer(); + + if (length > MAX_INPUT_DATA) { + // new item doesn't fit by itself, so deal with it by itself + compressAndWriteSingle(bytes, offset, length); + } else { + // does fit + addToBuffer(bytes, offset, length); + } + } else { + // does fit + addToBuffer(bytes, offset, length); + } + } + + private void compressAndWriteSingle(byte[] bytes, int offset, int length) throws IOException { + assert numLines == 0 && inOff == 0; + + int off = offset; + int lenToWrite = length; + int totalOutLen = 0; + + while (lenToWrite > 0) { + int len = Math.min(lenToWrite, MAX_INPUT_DATA); + + // copy data into buffer + numLines = 1; + inOffsets[0] = off; + inOffsets[1] = off + len; + + long outLine = st.compressBulk(numLines, bytes, inOffsets, outBuf, outOffsets); + assert outLine == numLines; + long outLen = outOffsets[(int) outLine]; + totalOutLen += (int) outLen; + finalOutput.writeBytes(outBuf, 0, (int) outLen); + + off += len; + lenToWrite -= len; + + } + offsetWriter.addLen(totalOutLen); + + clear(); + } + + private void compressAndWriteBuffer() throws IOException { + assert numLines < MAX_LINES + 1; + assert inOff <= MAX_INPUT_DATA; + + // add a pseudo-offset to provide last line's length + inOffsets[numLines] = inOff; + + long outLines = st.compressBulk(numLines, inData, inOffsets, outBuf, outOffsets); + assert outLines == numLines; + long fullOutLen = outOffsets[(int) outLines]; + + finalOutput.writeBytes(outBuf, 0, (int) fullOutLen); + for (int i = 0; i < numLines; ++i) { + int len = outOffsets[i + 1] - outOffsets[i]; + offsetWriter.addLen(len); + } + + clear(); + } + + void clear() { + numLines = inOff = 0; + } + + @Override + public void close() throws IOException { + if (numLines > 0) { + compressAndWriteBuffer(); + } + clear(); + } +} diff --git a/server/src/main/java/org/elasticsearch/common/compress/fsst/FSST.java b/server/src/main/java/org/elasticsearch/common/compress/fsst/FSST.java new file mode 100644 index 0000000000000..2394dcf259143 --- /dev/null +++ 
b/server/src/main/java/org/elasticsearch/common/compress/fsst/FSST.java @@ -0,0 +1,1263 @@ +// this software is distributed under the MIT License (http://www.opensource.org/licenses/MIT): +// +// Copyright 2018-2019, CWI, TU Munich, FSU Jena +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, +// merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// +// You can contact the authors via the FSST source repository : https://github.com/cwida/fsst + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +/* + * This file is a Java port of the original library shown in the license above. + * Original C++ library: https://github.com/cwida/fsst + * + * This file contains code derived from https://github.com/cwida/fsst and + * also includes significant additions by parkertimmins. + */ + +package org.elasticsearch.common.compress.fsst; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Random; + +public class FSST { + + public static final VarHandle VH_NATIVE_LONG = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.nativeOrder()); + public static final VarHandle VH_NATIVE_INT = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.nativeOrder()); + static final int FSST_SAMPLELINE = 512; + static final int maxStrLength = 8; + static final int FSST_SAMPLETARGET = 1 << 14; + static final int FSST_SAMPLEMAXSZ = 2 * FSST_SAMPLETARGET; + static final int FSST_CODE_BITS = 9; + static final int FSST_CODE_BASE = 256; // 0x100 + static final int FSST_CODE_MAX = 1 << FSST_CODE_BITS; // 512 + static final int FSST_CODE_MASK = FSST_CODE_MAX - 1; // 511 -> 0x1FF + static final int FSST_HASH_LOG2SIZE = 10; + + // we represent codes in u16 (not u8). 
12 bits code (of which 10 are used), 4 bits length + // length is included because shortCodes can contain single-byte str during compressBulk + static final int FSST_LEN_BIT_OFFSET = 12; + static final int FSST_SHIFT = 15; + static final long FSST_ICL_FREE = ((15L << 28) | ((FSST_CODE_MASK) << 16)); + private static final int ESCAPE_BYTE = 255; + private static final long FSST_HASH_PRIME = 2971215073L; + + public static long hash(long w) { + return (((w) * FSST_HASH_PRIME) ^ (((w) * FSST_HASH_PRIME) >>> FSST_SHIFT)); + } + + static char first1(long str) { + return (char) (0xFF & str); + } + + static char first2(long str) { + return (char) str; + } + + static char first1(long[] symbols, int idx) { + return first1(getStr(symbols, idx)); + } + + static char first2(long[] symbols, int idx) { + return first2(getStr(symbols, idx)); + } + + static long getStr(long[] symbols, int idx) { + return symbols[(idx << 1) + 1]; + } + + static long getICL(long[] symbols, int idx) { + return symbols[idx << 1]; + } + + static int getLen(long[] symbols, int idx) { + return getLen(getICL(symbols, idx)); + } + + static int getLen(long icl) { + return (int) (icl >>> 28); + } + + static char getCode(long icl) { + return (char) (FSST_CODE_MASK & (icl >>> 16)); + } + + static int getIgnored(long icl) { + return (char) icl; + } + + static long removeIgnored(long fullStr, long icl) { + return fullStr & (0xFFFFFFFFFFFFFFFFL >>> getIgnored(icl)); + } + + // ignored bits: + // icl = ignoredBits:16,code:12,length:4,unused:32 -- but we avoid exposing this bit-field notation + static long toICL(int code, int len) { + return (((long) len << 28) | ((long) code << 16) | ((8 - len) * 8L)); + } + + static void set(long[] symbols, int idx, long str, int code, int len) { + symbols[(idx << 1)] = toICL(code, len); + symbols[(idx << 1) + 1] = str; + } + + static void set(long[] symbols, int idx, long str, long icl) { + symbols[(idx << 1)] = icl; + symbols[(idx << 1) + 1] = str; + } + + static void setFree(long[] symbols, int idx) { + symbols[(idx << 1)] = FSST_ICL_FREE; + symbols[(idx << 1) + 1] = 0; + } + + static void setICL(long[] symbols, int idx, int code, int len) { + symbols[(idx << 1)] = toICL(code, len); + } + + static char byteToCode(int b) { + return (char) ((1 << FSST_LEN_BIT_OFFSET) | b); + } + + public static long readLong(byte[] buf, int offset) { + return (long) VH_NATIVE_LONG.get(buf, offset); + } + + public static int readInt(byte[] buf, int offset) { + return (int) VH_NATIVE_INT.get(buf, offset); + } + + public static void writeLong(byte[] buf, int offset, long value) { + VH_NATIVE_LONG.set(buf, offset, value); + } + + @SuppressWarnings({ "fallthrough" }) + public static long readLong(byte[] str, int pos, int len) { + long res = 0; + + len = Math.min(len, 8); + switch (len) { + case 8: + res |= (str[pos + 7] & 0xFFL) << 56; + case 7: + res |= (str[pos + 6] & 0xFFL) << 48; + case 6: + res |= (str[pos + 5] & 0xFFL) << 40; + case 5: + res |= (str[pos + 4] & 0xFFL) << 32; + case 4: + res |= (str[pos + 3] & 0xFFL) << 24; + case 3: + res |= (str[pos + 2] & 0xFF) << 16; + case 2: + res |= (str[pos + 1] & 0xFF) << 8; + case 1: + res |= (str[pos] & 0xFF); + } + return res; + } + + static boolean isEscapeCode(char pos) { + return pos < FSST_CODE_BASE; + } + + public static class SymbolTable { + + static final int hashTableSize = 1 << FSST_HASH_LOG2SIZE; + static final int hashTableArraySize = hashTableSize * 2; + static final int hashTableMask = hashTableSize - 1; + + // high bits of icl (len=8,code=FSST_CODE_MASK) 
indicates free bucket + + char[] shortCodes = new char[65536]; + char[] byteCodes = new char[256]; + + // both ht and symbols contains symbols as two adjacent longs + // value at 0 is icl: + // value at 1 is symbol bytes + long[] symbols = new long[FSST_CODE_MAX * 2]; + long[] ht = new long[hashTableArraySize]; + + int nSymbols = 0; // amount of symbols in the map (max 255) + int suffixLim = FSST_CODE_MAX; // 512, codes higher than this do not have a longer suffix + char terminator; // code of 1-byte symbol, that can be used as a terminator during compression + char[] lenHisto = new char[8]; // lenHisto[x] is the amount of symbols of byte-length (x+1) in this SymbolTable + + void initialize() { + // fill in [0, 256) with single byte codes + for (int i = 0; i < 256; ++i) { + set(symbols, i, i, byteToCode(i), 1); // i is index, str, and code + } + + // fill [256, 512) with unused flag + for (int i = 256; i < FSST_CODE_MAX; ++i) { + set(symbols, i, 0, FSST_CODE_MASK, 1); + } + + // set hash table empty + for (int i = 0; i < hashTableSize; ++i) { + setFree(ht, i); + } + + // fill byteCodes[] with the pseudo code all bytes (escaped bytes) + for (int i = 0; i < 256; ++i) { + byteCodes[i] = byteToCode(i); + } + + // fill shortCodes[] with the pseudo code for the first byte of each two-byte pattern + for (int i = 0; i < 65536; ++i) { + shortCodes[i] = (char) ((1 << FSST_LEN_BIT_OFFSET) | (i & 0xFF)); // byteToCode(i & 0xFF) + } + } + + private SymbolTable() {} + + static SymbolTable build() { + var st = new SymbolTable(); + st.initialize(); + return st; + } + + static SymbolTable buildUnitialized() { + return new SymbolTable(); + } + + void clear() { + Arrays.fill(lenHisto, (char) 0); + + for (int i = FSST_CODE_BASE; i < FSST_CODE_BASE + nSymbols; i++) { + int len = getLen(symbols, i); + if (len == 1) { + char val = first1(symbols, i); + byteCodes[val] = (char) ((1 << FSST_LEN_BIT_OFFSET) | val); + } else if (len == 2) { + char val = first2(symbols, i); + shortCodes[val] = (char) ((1 << FSST_LEN_BIT_OFFSET) | (val & 0xFF)); + } else { + int idx = hashStr(getStr(symbols, i)); + setFree(ht, idx); + } + } + nSymbols = 0; // no need to clean symbols[] as no symbols are used + } + + public void copyInto(SymbolTable copy) { + copy.nSymbols = nSymbols; + copy.suffixLim = suffixLim; + copy.terminator = terminator; + System.arraycopy(lenHisto, 0, copy.lenHisto, 0, lenHisto.length); + System.arraycopy(byteCodes, 0, copy.byteCodes, 0, byteCodes.length); + System.arraycopy(shortCodes, 0, copy.shortCodes, 0, shortCodes.length); + System.arraycopy(symbols, 0, copy.symbols, 0, symbols.length); + System.arraycopy(ht, 0, copy.ht, 0, ht.length); + } + + char getShortCode(long str) { + return (char) (shortCodes[first2(str)] & FSST_CODE_MASK); + } + + // return index in hash table + public static int hashStr(long str) { + long first3 = str & 0xFFFFFF; + return (int) (hash((int) first3) & hashTableMask); + } + + // only called if string has at least 3 characters + boolean hashInsert(long str, long icl) { + // ignored prefix already removed + assert removeIgnored(str, icl) == str; + int idx = hashStr(str); + long currICL = getICL(ht, idx); + boolean taken = (currICL < FSST_ICL_FREE); + if (taken) return false; // collision in hash table + set(ht, idx, str, icl); + return true; + } + + /** + * Add an existing symbol to a symbol table. 
+ * Used after candidates have been picked via priority queue and are re-added to a symbol table + */ + boolean add(long str, long i_l) { + assert (FSST_CODE_BASE + nSymbols < FSST_CODE_MAX); + + int len = getLen(i_l); + + int code = (char) (FSST_CODE_BASE + nSymbols); + long icl = toICL(code, len); + + if (len == 1) { + byteCodes[first1(str)] = (char) (code + (1 << FSST_LEN_BIT_OFFSET)); // len=1 + } else if (len == 2) { + shortCodes[first2(str)] = (char) (code + (2 << FSST_LEN_BIT_OFFSET)); // len=2 + } else if (hashInsert(str, icl) == false) { + return false; // already in hash table + } + + set(symbols, code, str, icl); + nSymbols++; + lenHisto[len - 1]++; + return true; + } + + // Removes length prefix from 1 and 2 byte codes + char findLongestSymbol(long inStr, int inLen) { + // use default max value, but this is not used + long inICL = toICL(FSST_CODE_MAX, inLen); + + // first check the hash table + int idx = hashStr(inStr); + long icl = getICL(ht, idx); + long str = getStr(ht, idx); + // check length in case there are 0x00 bytes in the data + if (icl <= inICL && str == removeIgnored(inStr, icl)) { + return getCode(icl); + } + + if (inLen >= 2) { + char code = getShortCode(inStr); + if (code >= FSST_CODE_BASE) return code; + } + return (char) (byteCodes[first1(inStr)] & FSST_CODE_MASK); + } + + char findLongestSymbol(byte[] line, int start, int len) { + long str; + if (len >= 8) { + len = 8; + str = readLong(line, start); + } else { + str = readLong(line, start, len); + } + return findLongestSymbol(str, len); + } + + void print() { + + System.out.println("numSymbols: " + nSymbols); + System.out.println("terminator: " + terminator); + System.out.println("suffixLim: " + suffixLim); + + System.out.println("Hash table: "); + for (int i = 0; i < hashTableSize; ++i) { + long icl = getICL(ht, i); + long str = getStr(ht, i); + if (icl < FSST_ICL_FREE) { + System.out.println("idx: " + i + ", " + FSST.toString(str, icl)); + } + } + + System.out.println("Symbol table: "); + for (int i = FSST_CODE_BASE; i < FSST_CODE_BASE + nSymbols; ++i) { + long icl = getICL(symbols, i); + long str = getStr(symbols, i); + System.out.println(FSST.toString(str, icl)); + } + + System.out.println("Symbol table final: "); + for (int i = 0; i < nSymbols; ++i) { + long icl = getICL(symbols, i); + long str = getStr(symbols, i); + System.out.println(FSST.toString(str, icl)); + } + + System.out.println("Short codes: "); + for (int i = 0; i < nSymbols; ++i) { + long icl = getICL(symbols, i); + long str = getStr(symbols, i); + + if (getLen(icl) == 2) { + var strRep = fromBytes(toByteArray(str, getLen(icl))); + System.out.println("code: " + (int) getShortCode(str) + ", str: '" + strRep + "'"); + } + } + + Map counts = new HashMap<>(); + for (int i = 0; i < 65536; ++i) { + int len = shortCodes[i] >> FSST_LEN_BIT_OFFSET; + if (counts.containsKey(len)) { + counts.put(len, counts.get(len) + 1); + } else { + counts.put(len, 1); + } + } + System.out.println(counts); + } + + // before finalize(): + // - The real symbols are symbols[256..256+nSymbols>. As we may have nSymbols > 255 + // - The first 256 codes are pseudo symbols (all escaped bytes) + // + // after finalize(): + // - table layout is symbols[0..nSymbols>, with nSymbols < 256. + // - Real codes are [0,nSymbols>. 8-th bit not set. + // - Escapes in shortCodes have the 8th bit set (value: 256+255=511). 
255 because the code to be emitted is the escape byte 255 + // - symbols are grouped by length: 2,3,4,5,6,7,8, then 1 (single-byte codes last) + // the two-byte codes are split in two sections: + // - first section contains codes for symbols for which there is no longer symbol (no suffix). + // It allows an early-out during compression + // + // finally, shortCodes[] is modified to also encode all single-byte symbols + // (hence byteCodes[] is not required on a critical path anymore). + void finalizeSymbolTable() { + assert nSymbols <= 255; + + // maps original real code [0, 255] to new code + char[] newCode = new char[256]; + char[] rsum = new char[8]; + int numSymbolsLen1 = lenHisto[0]; + int byteLim = nSymbols - numSymbolsLen1; // since single-byte comes last, byteLim is index of first single-byte code + + // compute running sum of code lengths (starting offsets for each length) + rsum[0] = (char) byteLim; // byte 1-byte codes come last + rsum[1] = 0; // 2-byte start at 0 + // [0] = num 1 byte code, [2] = num 1,2 byte codes, etc + for (int i = 1; i < 7; i++) { + rsum[i + 1] = (char) (rsum[i] + lenHisto[i]); + } + + suffixLim = 0; + for (int i = 0, j = rsum[2]; i < nSymbols; i++) { + long s1_icl = getICL(symbols, FSST_CODE_BASE + i); + long s1_str = getStr(symbols, FSST_CODE_BASE + i); + int len = getLen(s1_icl); + + if (len == 2) { + boolean foundSuffix = false; + char s1_first2 = first2(s1_str); + for (int k = 0; k < nSymbols; k++) { + long s2_icl = getICL(symbols, FSST_CODE_BASE + k); + long s2_str = getStr(symbols, FSST_CODE_BASE + k); + + // test if symbol k is a suffix of symbol i + if (k != i && getLen(s2_icl) > 1 && s1_first2 == first2(s2_str)) { + foundSuffix = true; + break; + } + } + + // symbols without a larger suffix have a code < suffixLim + // opt == 0 means containing str found + newCode[i] = (char) (foundSuffix ? 
--j : suffixLim++); + } else { + // now using rsum as range start counters + newCode[i] = rsum[len - 1]++; + } + + // change code to newCode[i] and move to newCode[i] position + set(symbols, newCode[i], getStr(symbols, FSST_CODE_BASE + i), newCode[i], len); + } + + // renumber the codes in byteCodes[] + for (int i = 0; i < 256; i++) { + if ((byteCodes[i] & FSST_CODE_MASK) >= FSST_CODE_BASE) { + // & 0xFF here convert full code to real code, eg is equivalent to -256 + byteCodes[i] = (char) (newCode[byteCodes[i] & 0xFF] + (1 << FSST_LEN_BIT_OFFSET)); + } else { + byteCodes[i] = 511 + (1 << FSST_LEN_BIT_OFFSET); + } + } + + // renumber the codes in shortCodes[] + for (int i = 0; i < 65536; i++) { + if ((shortCodes[i] & FSST_CODE_MASK) >= FSST_CODE_BASE) { + // mask out the original length, but could probably use 0x3 as max length should be 2 + shortCodes[i] = (char) (newCode[shortCodes[i] & 0xFF] + (shortCodes[i] & (0xF << FSST_LEN_BIT_OFFSET))); + } else { + // if there is no code, use the single-byte code for the first byte + // if there is no single byte code, values will be 511 + shortCodes[i] = byteCodes[i & 0xFF]; + } + } + + // replace the symbols in the hash table + for (int i = 0; i < hashTableSize; i++) { + long icl = getICL(ht, i); + if (icl < FSST_ICL_FREE) { + char nc = newCode[getCode(icl) & 0xFF]; + long icl1 = getICL(symbols, nc); + long str1 = getStr(symbols, nc); + set(ht, i, str1, icl1); + } + } + } + + // aced: 1684366177 + + // find terminator: least frequent byte + public static char findTerminator(List lines) { + // find terminator: least frequent byte + char[] byteHisto = new char[256]; + for (int i = 0; i < lines.size(); i++) { + byte[] line = lines.get(i); + for (byte b : line) { + byteHisto[b & 0xFF]++; + } + } + + int minSize = FSST_SAMPLEMAXSZ; + int terminator = 256; + int i = terminator; + while (i-- > 0) { + if (byteHisto[i] <= minSize) { + terminator = (char) i; + minSize = byteHisto[i]; + } + } + + assert (terminator < 256); + return (char) terminator; + } + + private static int compressCount(int sampleFrac, List lines, SymbolTable st, Counters counters) { // returns gain + int gain = 0; + + var random = new Random(123); // use same see for reproducibility + for (var line : lines) { + int cur = 0, start = 0, end = line.length; + // TODO if there are few lines (or rather chunks), sampleFrac skipping may cause all data to be skipped + // Probably doesn't matter, since this means data is short + if (sampleFrac < 128) { + // in earlier rounds (sampleFrac < 128) we skip data in the sample (reduces overall work ~2x) + if (random.nextInt(0, 128 + 1) > sampleFrac) continue; + } + if (cur < end) { + char code2 = 255; + char code1 = st.findLongestSymbol(line, cur, end - cur); + var len = getLen(st.symbols, code1); + cur += len; + gain += len - (1 + (isEscapeCode(code1) ? 1 : 0)); + while (true) { + // count single symbol (i.e. an option is not extending it) + counters.count1Inc(code1); + + // as an alternative, consider just using the next byte.. + var len1 = getLen(st.symbols, code1); + if (len1 != 1) // .. 
but do not count single byte symbols doubly + counters.count1Inc(line[start] & 0xFF); + + if (cur == end) { + break; + } + + // now match a new symbol + start = cur; + if (cur < end - 7) { // add least 8 bytes left + long word = readLong(line, cur); + + // find existing string matching same 3 letters (or hash collision) + int idx = hashStr(word); + long icl = getICL(st.ht, idx); + long str = getStr(st.ht, idx); + + code2 = st.getShortCode(word); + word = removeIgnored(word, icl); + if ((icl < FSST_ICL_FREE) & (str == word)) { + code2 = getCode(icl); + cur += getLen(icl); + } else if (code2 >= FSST_CODE_BASE) { + cur += 2; + } else { + code2 = (char) (st.byteCodes[first1(word)] & FSST_CODE_MASK); + cur += 1; + } + } else { + code2 = st.findLongestSymbol(line, cur, end - cur); + cur += getLen(st.symbols, code2); + } + + // compute compressed output size + gain += (cur - start) - (1 + (isEscapeCode(code1) ? 1 : 0)); + + if (sampleFrac < 128) { // no need to count pairs in final round + // consider the symbol that is the concatenation of the two last symbols + counters.count2Inc(code1, code2); + + // as an alternative, consider just extending with the next byte.. + if ((cur - start) > 1) // ..but do not count single byte extensions doubly + counters.count2Inc(code1, line[start] & 0xFF); + } + code1 = code2; + } + } + } + return gain; + } + + static class QSymbol implements Comparable { + // symbol + final long icl; + final long str; + int gain; + + QSymbol(long icl, long str, int gain) { + this.icl = icl; + this.str = str; + this.gain = gain; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + QSymbol qSymbol = (QSymbol) o; + // str representation, which has been truncated to correct length + // Also include length, in case there are actual 0x00 values in one string + return str == qSymbol.str && getLen(icl) == getLen(qSymbol.icl); + } + + @Override + public int hashCode() { + return Objects.hash(str); + } + + @Override + public String toString() { + return "QSymbol{" + + "code=" + + getCode(icl) + + " (" + + (int) getCode(icl) + + ")" + + ", len=" + + getLen(icl) + + ", str=" + + fromBytes(toByteArray(str, getLen(icl))) + + '}'; + } + + @Override + public int compareTo(QSymbol o) { + // return element if higher gain, or equal gain and longer + boolean firstBetter = gain > o.gain || gain == o.gain && str > o.str; + // shouldn't compare equal elements due to deduplication + assert str != o.str; + return firstBetter ? -1 : 1; + } + } + + private static void addOrInc(int sampleFrac, Map cands, long icl, long str, int count) { + // only accepts strings which have been truncated to correct length + assert str == removeIgnored(str, icl); + + if (count < (5 * sampleFrac) / 128) return; // improves both compression speed (less candidates), but also quality!! 
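+ // note: each accepted candidate is weighted below by count * symbol length, an estimate of how many sample bytes the symbol would cover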
+ + int gain = count * getLen(icl); + QSymbol existing = cands.get(str); + if (existing == null) { + cands.put(str, new QSymbol(icl, str, gain)); + } else { + existing.gain += gain; + } + } + + record Symbol(long icl, long str) { + @Override + public String toString() { + return "QSymbol{" + + "code=" + + getCode(icl) + + " (" + + (int) getCode(icl) + + ")" + + ", len=" + + getLen(icl) + + ", str=" + + fromBytes(toByteArray(str, getLen(icl))) + + '}'; + } + } + + // TODO do this without making object + private static Symbol concat(long icl1, long str1, long icl2, long str2) { + int len1 = getLen(icl1); + int len2 = getLen(icl2); + int length = len1 + len2; + if (length > maxStrLength) length = maxStrLength; + + long icl = toICL(FSST_CODE_MASK, length); + long str = (str2 << (8 * len1)) | str1; + return new Symbol(icl, str); + } + + /** + * Use existing SymbolTable and counters to create priority queue of candidate symbols. + * The clear symbol table and add top 255 symbols from pq + */ + public static void makeTable(int sampleFrac, SymbolTable st, Counters counters) { + // hashmap of c (needed because we can generate duplicate candidates) + Map cands = new HashMap<>(); + + // artificially make terminator the most frequent symbol so it gets included + char terminator = st.nSymbols > 0 ? FSST_CODE_BASE : st.terminator; + counters.count1Set(terminator, (char) 65535); + + // add candidate symbols based on counted frequency + for (int pos1 = 0; pos1 < FSST_CODE_BASE + st.nSymbols; pos1++) { + int cnt1 = counters.count1Get(pos1); + if (cnt1 == 0) continue; + + // heuristic: promoting single-byte symbols (*8) helps reduce exception rates and increases [de]compression speed + long icl = getICL(st.symbols, pos1); + long str = getStr(st.symbols, pos1); + int len = getLen(icl); + addOrInc(sampleFrac, cands, icl, str, ((len == 1) ? 8 : 1) * cnt1); + + if (sampleFrac >= 128 || // last round we do not create new (combined) symbols + len == maxStrLength || // symbol cannot be extended + first1(str) == st.terminator) { // multi-byte symbols cannot contain the terminator byte + continue; + } + for (int pos2 = 0; pos2 < FSST_CODE_BASE + st.nSymbols; pos2++) { + int cnt2 = counters.count2Get(pos1, pos2); + if (cnt2 == 0) continue; + + // create a new symbol + long icl2 = getICL(st.symbols, pos2); + long str2 = getStr(st.symbols, pos2); + Symbol s3 = concat(icl, str, icl2, str2); + if (first1(str2) != st.terminator) // multi-byte symbols cannot contain the terminator byte + addOrInc(sampleFrac, cands, s3.icl, s3.str, cnt2); + } + } + + // insert candidates into priority queue (by gain) + var pq = new PriorityQueue(cands.size()); + pq.addAll(cands.values()); + + // Create new symbol map using best candidates + st.clear(); + while (st.nSymbols < 255 && pq.isEmpty() == false) { + QSymbol top = pq.poll(); + st.add(top.str, top.icl); + } + } + + public static SymbolTable buildSymbolTable(List lines) { + var counters = new Counters(); + var st = SymbolTable.build(); + var bestTable = SymbolTable.buildUnitialized(); + long bestGain = -FSST_SAMPLEMAXSZ; // worst case (everything exception) + int sampleFrac; + + st.terminator = findTerminator(lines); + + Counters bestCounters = new Counters(); + // we do 5 rounds (sampleFrac=8,38,68,98,128) + for (sampleFrac = 8; true; sampleFrac += 30) { + counters.clear(); + long gain = compressCount(sampleFrac, lines, st, counters); + if (gain >= bestGain) { // a new best solution! 
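+ // snapshot this round's counters and symbol table; after the loop, makeTable() rebuilds the final table from the best round before finalizeSymbolTable()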
+ counters.copyInto(bestCounters); + st.copyInto(bestTable); + bestGain = gain; + } + if (sampleFrac >= 128) break; // don't build st on last loop + makeTable(sampleFrac, st, counters); + } + makeTable(sampleFrac, bestTable, bestCounters); + bestTable.finalizeSymbolTable(); // renumber codes for more efficient compression + return bestTable; + } + + public long compressBulk( + int nlines, + byte[] data, /* input string data */ + int[] offsets, /* string offset, length is nlines+1 */ + byte[] outBuf, // output buffer, multiple lines will be within buffer + int[] outOffsets // compressed line start offsets within buffer, length known + ) { + boolean avoidBranch = false, noSuffixOpt = false; + + // if 2-byte symbols account for at least 65% percent of symbols + if (100 * lenHisto[1] > 65 * nSymbols + // and at least 95% of 2-byte symbols are have no longer symbol with matching prefix + && 100 * suffixLim > 95 * lenHisto[1]) { + // use noSuffixOpt - check shortCodes before checking hash table + noSuffixOpt = true; + + // otherwise decide if should use branch to separate between 1 and 2 byte symbols + } else if ((lenHisto[0] > 24 && lenHisto[0] < 92) + && (lenHisto[0] < 43 || lenHisto[6] + lenHisto[7] < 29) + && (lenHisto[0] < 72 || lenHisto[2] < 72)) { + avoidBranch = true; + } + + if (noSuffixOpt == false && avoidBranch) { + return compressBulk(nlines, data, offsets, outBuf, outOffsets, false, true); + } else if (noSuffixOpt && avoidBranch == false) { + return compressBulk(nlines, data, offsets, outBuf, outOffsets, true, false); + } else { + return compressBulk(nlines, data, offsets, outBuf, outOffsets, false, false); + } + } + + // optimized adaptive *scalar* compression method + public long compressBulk( + int numLines, + byte[] data, // input string data + int[] offsets, // offsets of each string values, length is one more than numLines + byte[] outBuf, // output buffer, multiple lines will be within buffer + int[] outOffsets, // compressed line start offsets within buffer, length known + boolean noSuffixOpt, + boolean avoidBranch + ) { + int outCur = 0; + int outLim = outBuf.length; + int curLine = 0; + int suffixLim = this.suffixLim; + int byteLim = this.nSymbols - this.lenHisto[0]; + + byte[] buf = new byte[512 + 8]; /* +8 sentinel is to avoid 8-byte unaligned-loads going beyond 511 out-of-bounds */ + + for (; curLine < numLines; curLine++) { + int lineLen = offsets[curLine + 1] - offsets[curLine]; + assert lineLen >= 0; + int chunkLen = 0; + int chunkStart = 0; + outOffsets[curLine] = outCur; + + // a single str/line can be in multiple chunks, but a chunk contains at most 1 str + do { + // we need to compress in chunks of 511 in order to be byte-compatible with simd-compressed FSST + chunkLen = Math.min(lineLen - chunkStart, 511); + + int remaining = outLim - outCur; + if ((2 * chunkLen + 7) > remaining) { + return curLine; // out of memory + } + + // copy the string to the 511-byte buffer + System.arraycopy(data, offsets[curLine] + chunkStart, buf, 0, chunkLen); + buf[chunkLen] = (byte) this.terminator; + + int chunkCur = 0; + // compress variant + while (chunkCur < chunkLen) { + long word = readLong(buf, chunkCur); + char code = shortCodes[first2(word)]; + if (noSuffixOpt && (code & 0xFF) < suffixLim) { + // 2 byte code without having to worry about longer matches + outBuf[outCur++] = (byte) code; + chunkCur += 2; + } else { + int idx = hashStr(word); + long icl = getICL(this.ht, idx); + long str = getStr(this.ht, idx); + outBuf[outCur + 1] = (byte) word; // speculatively write out 
escaped byte + word = removeIgnored(word, icl); + if ((icl < FSST_ICL_FREE) && str == word) { + outBuf[outCur++] = (byte) getCode(icl); + chunkCur += getLen(icl); + } else if (avoidBranch) { + // could be a 2-byte or 1-byte code, or miss + // handle everything with predication + outBuf[outCur] = (byte) code; + // if code has bit 9 set => move 2 spaces, because is escape code + outCur += 1 + ((code & FSST_CODE_BASE) >>> 8); + int symbolLen = code >>> FSST_LEN_BIT_OFFSET; + chunkCur += symbolLen; + } else if ((code & 0xFF) < byteLim) { + // 2 byte code after checking there is no longer pattern + outBuf[outCur++] = (byte) code; + chunkCur += 2; + } else { + // 1 byte code or miss. + outBuf[outCur] = (byte) code; + outCur += 1 + ((code & FSST_CODE_BASE) >>> 8); // predicated - tested with a branch, that was always worse + chunkCur++; + } + } + } + } while ((chunkStart += chunkLen) < lineLen); + } + + // set one more offset to provide last line length + outOffsets[numLines] = outCur; + return curLine; + } + + public byte[] exportToBytes() { + int totalStrLen = 0; + for (int len = 1; len <= 8; len++) { + char numWithLen = lenHisto[len - 1]; + totalStrLen += len * numWithLen; + } + + int outLen = 8 + totalStrLen; // 8 for len histo + byte[] out = new byte[outLen]; + int offset = 0; + for (char numWithLen : lenHisto) { + out[offset++] = (byte) numWithLen; + } + + int code = 0; + // current order of the str lengths in codes + for (int len : new int[] { 2, 3, 4, 5, 6, 7, 8, 1 }) { + char numWithLen = lenHisto[len - 1]; + for (int i = 0; i < numWithLen; ++i) { + long str = getStr(symbols, code); + for (int byteIdx = 0; byteIdx < len; byteIdx++) { + out[offset++] = (byte) (str >>> (8 * byteIdx)); + } + code++; + } + } + + return out; + } + } + + public static List makeSample(byte[] data, int[] offsets) { + return makeSample(data, offsets, FSST_SAMPLETARGET, FSST_SAMPLELINE); + } + + // quickly select a uniformly random set of lines such that we have between [FSST_SAMPLETARGET,FSST_SAMPLEMAXSZ) string bytes + // return list of indices within input offsets? + static List makeSample(byte[] data, int[] offsets, int sampleTargetLen, int sampleLineLen) { + List sample = new ArrayList<>(); + int totalSize = offsets[offsets.length - 1]; + if (totalSize < sampleTargetLen) { + for (int i = 0; i < offsets.length - 1; ++i) { + sample.add(Arrays.copyOfRange(data, offsets[i], offsets[i + 1])); + } + return sample; + } + + var random = new Random(456); + int numLines = offsets.length - 1; + int sampleSize = 0; + while (sampleSize < sampleTargetLen) { + int lineIdx = random.nextInt(numLines); + + // find next non-empty lines, wrapping around if necessary + int len = offsets[lineIdx + 1] - offsets[lineIdx]; + while (len == 0) { + if (++lineIdx == numLines) lineIdx = 0; + len = offsets[lineIdx + 1] - offsets[lineIdx]; + } + + if (len <= sampleLineLen) { + sample.add(Arrays.copyOfRange(data, offsets[lineIdx], offsets[lineIdx + 1])); + sampleSize += len; + } else { + int chunks = len / sampleLineLen + (len % sampleLineLen == 0 ? 0 : 1); + int chunk = random.nextInt(chunks); + int off = chunk * sampleLineLen; + int chunkLen = chunk == chunks - 1 ? 
len - off : sampleLineLen; + byte[] bytes = Arrays.copyOfRange(data, offsets[lineIdx] + off, offsets[lineIdx] + off + chunkLen); + sample.add(bytes); + sampleSize += chunkLen; + } + } + return sample; + } + + static class Counters { + char[] count1 = new char[FSST_CODE_MAX]; // array to count frequency of symbols as they occur in the sample + + char[] count2 = new char[FSST_CODE_MAX * FSST_CODE_MAX]; // array to count subsequent combinations of two symbols in the sample + + void count1Set(int pos1, char val) { + count1[pos1] = val; + } + + void count1Inc(int pos1) { + count1[pos1]++; + } + + void count2Inc(int pos1, int pos2) { + count2[(pos1 << FSST_CODE_BITS) + pos2]++; + } + + int count1Get(int pos1) { + return count1[pos1]; + } + + int count2Get(int pos1, int pos2) { + return count2[(pos1 << FSST_CODE_BITS) + pos2]; + } + + void clear() { + Arrays.fill(count1, (char) 0); + Arrays.fill(count2, (char) 0); + } + + void copyInto(Counters other) { + System.arraycopy(count1, 0, other.count1, 0, FSST_CODE_MAX); + System.arraycopy(count2, 0, other.count2, 0, FSST_CODE_MAX * FSST_CODE_MAX); + } + } + + public static class Decoder { + final byte[] lens; /* len[x] is the byte-length of the symbol x (1 < len[x] <= 8). */ + final long[] symbols; /* symbol[x] contains in LITTLE_ENDIAN the bytesequence that code x represents (0 <= x < 255). */ + + Decoder(byte[] lens, long[] symbols) { + this.lens = lens; + this.symbols = symbols; + } + + public static Decoder readFrom(byte[] exportedSymbolTable) throws IOException { + final int[] i = { 0 }; + return readFrom(() -> exportedSymbolTable[i[0]++]); + } + + public static Decoder readFrom(ByteReader in) throws IOException { + int[] lenHisto = new int[8]; + int numSymbols = 0; + for (int len = 1; len <= 8; len++) { + int numWithLen = lenHisto[len - 1] = in.readByte() & 0xFF; + numSymbols += numWithLen; + } + + byte[] lens = new byte[numSymbols]; + long[] symbols = new long[numSymbols]; + int code = 0; + for (int len : new int[] { 2, 3, 4, 5, 6, 7, 8, 1 }) { + int numWithLen = lenHisto[len - 1]; + + for (int i = 0; i < numWithLen; ++i) { + lens[code] = (byte) len; + + long symbol = 0; + for (int byteIdx = 0; byteIdx < len; ++byteIdx) { + symbol |= (in.readByte() & 0xFFL) << (byteIdx * 8); + } + symbols[code] = symbol; + code++; + } + } + + return new Decoder(lens, symbols); + } + } + + // Assumes you know length to decompress + // lenToConsume must not be longer than the compressed data length + // output must be large enough to fit the + // require that output buffer has 7 bytes more than required + // return output length + public static int decompress(byte[] in, int startOffset, int lenToConsume, Decoder decoder, byte[] output) throws IOException { + int code; + + int outIdx = 0; + int inIdx = startOffset; + int limit = startOffset + lenToConsume; + while (inIdx < limit) { + if ((code = in[inIdx++] & 0xFF) == ESCAPE_BYTE) { + output[outIdx++] = in[inIdx++]; + } else { + var symbol = decoder.symbols[code]; + var len = decoder.lens[code]; + writeLong(output, outIdx, symbol); + outIdx += len; + } + } + + return outIdx; + } + + @SuppressWarnings({ "fallthrough", "checkstyle:OneStatementPerLine" }) + public static int decompressUnrolled(byte[] in, int lenToConsume, Decoder decoder, byte[] output) throws IOException { + int posOut = 0; + long limit = lenToConsume; + int code; + int offset = 0; + while (offset + 4 <= limit) { + int nextBlock = readInt(in, offset); + int escapeMask = (nextBlock & 0x80808080) & ((((~nextBlock) & 0x7F7F7F7F) + 0x7F7F7F7F) ^ 
0x80808080); + if (escapeMask == 0) { + code = nextBlock & 0xFF; + nextBlock >>>= 8; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + code = nextBlock & 0xFF; + nextBlock >>>= 8; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + code = nextBlock & 0xFF; + nextBlock >>>= 8; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + code = nextBlock & 0xFF; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + offset += 4; + } else { + int firstEscapePos = Long.numberOfTrailingZeros((long) escapeMask) >> 3; + switch (firstEscapePos) { /* Duff's device */ + case 3: + code = nextBlock & 0xFF; + nextBlock >>>= 8; + offset++; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + // fall through + case 2: + code = nextBlock & 0xFF; + nextBlock >>>= 8; + offset++; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + // fall through + case 1: + code = nextBlock & 0xFF; + offset++; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + // fall through + case 0: /* decompress an escaped byte */ + offset += 2; + output[posOut++] = in[offset - 1]; + } + } + } + + if (offset + 2 <= limit) { + output[posOut] = in[offset + 1]; + if ((in[offset] & 0xFF) != ESCAPE_BYTE) { + code = in[offset++] & 0xFF; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + if ((in[offset] & 0xFF) != ESCAPE_BYTE) { + code = in[offset++] & 0xFF; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + } else { + offset += 2; + output[posOut++] = in[offset - 1]; + } + } else { + offset += 2; + posOut++; + } + } + if (offset < limit) { // last code cannot be an escape + code = in[offset++] & 0xFF; + writeLong(output, posOut, decoder.symbols[code]); + posOut += decoder.lens[code]; + } + + return posOut; + } + + public static byte[] toBytes(String text) { + return text.getBytes(StandardCharsets.UTF_8); + } + + public static String fromBytes(byte[] bytes) { + return new String(bytes, StandardCharsets.UTF_8); + } + + static byte[] toByteArray(long str, int len) { + byte[] arr = new byte[len]; + for (int i = 0; i < len; i++) { + arr[i] = (byte) ((int) (str >>> (8 * i))); + } + return arr; + } + + static String toString(long str, long icl) { + var strRep = fromBytes(toByteArray(str, getLen(icl))); + return "code: " + (int) getCode(icl) + ", len: " + getLen(icl) + ", str: '" + strRep + "'"; + } + + static String toString(long[] table, int idx) { + return toString(getStr(table, idx), getICL(table, idx)); + } + + static String printStr(long[] table, int idx) { + return fromBytes(toByteArray(getStr(table, idx), getLen(getICL(table, idx)))); + } + + public static void main(String[] args) throws IOException { + for (int i = 0; i < 100; i++) { + roundTrip(args[0]); + } + } + + public static void roundTrip(String fileName) throws IOException { + String content = Files.readString(Path.of(fileName), StandardCharsets.UTF_8); + + System.out.println("String length: " + content.length()); + + byte[] bytes = FSST.toBytes(content); + byte[] bytes2 = new byte[bytes.length + 8]; + System.arraycopy(bytes, 0, bytes2, 0, bytes.length); + int[] offsets = { 0, bytes.length }; + bytes = bytes2; + + byte[] outBuf = new byte[bytes.length]; + int[] outOffsets = new int[2]; + + List sample = FSST.makeSample(bytes, offsets); + var symbolTable = 
SymbolTable.buildSymbolTable(sample); + + long startComp = System.nanoTime(); + long linesCompressed = symbolTable.compressBulk(1, bytes, offsets, outBuf, outOffsets); + long endComp = System.nanoTime(); + + assert linesCompressed == 1; + long compressedLen = outOffsets[1]; + + byte[] symbolTableBytes = symbolTable.exportToBytes(); + Decoder decoder = Decoder.readFrom(symbolTableBytes); + + long startDec = System.nanoTime(); + byte[] decompressBuf = new byte[bytes.length + 8]; + var decoded = FSST.decompress(outBuf, 0, outOffsets[1], decoder, decompressBuf); + long endDec = System.nanoTime(); + + String uncompressedString = FSST.fromBytes(Arrays.copyOfRange(decompressBuf, 0, decoded)); + assert content.equals(uncompressedString); + + System.out.println("Comp Duration: " + (endComp - startComp) / 1e6 + "ms"); + System.out.println("Dec Duration: " + (endDec - startDec) / 1e6 + "ms"); + + long compressMs = endComp - startComp; + float compressMb = (float) bytes.length / (1 << 20); + double compressMbPerSec = compressMb * 1e9 / compressMs; + System.out.println("Comp rate: " + compressMbPerSec + " mb/s"); + + long decMs = endDec - startDec; + float decMb = (float) outOffsets[1] / (1 << 20); + double decMbPerSec = decMb * 1e9 / decMs; + System.out.println("Dec rate: " + decMbPerSec + " mb/s"); + + System.out.println("Original length: " + bytes.length); + System.out.println("Compressed length: " + compressedLen); + System.out.println("Compressed ratio: " + compressedLen / (float) bytes.length); + } + + public interface ByteReader { + byte readByte() throws IOException; + } + + public interface OffsetWriter { + void addLen(int len) throws IOException; + } +} diff --git a/server/src/main/java/org/elasticsearch/common/compress/fsst/ReservoirSampler.java b/server/src/main/java/org/elasticsearch/common/compress/fsst/ReservoirSampler.java new file mode 100644 index 0000000000000..2c03641e062d3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/compress/fsst/ReservoirSampler.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.common.compress.fsst; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +import static org.elasticsearch.common.compress.fsst.FSST.FSST_SAMPLELINE; +import static org.elasticsearch.common.compress.fsst.FSST.FSST_SAMPLEMAXSZ; +import static org.elasticsearch.common.compress.fsst.FSST.FSST_SAMPLETARGET; + +public class ReservoirSampler { + private static final int SAMPLE_TARGET = FSST_SAMPLETARGET; + private static final int SAMPLE_MAX = FSST_SAMPLEMAXSZ; + private static final int SAMPLE_LINE = FSST_SAMPLELINE; + private int numBytesInSample = 0; + private int numChunksSeen = 0; + private final Random random = new Random(1234); + private List sample = new ArrayList<>(); + + public List getSample() { + return sample; + } + + // The byte array is only valid during this call, thus bytes need to be deep copied + public void processLine(byte[] bytes, int offset, int length) { + if (length == 0) { + return; + } + + // iterate over the chunks + int numChunks = length / SAMPLE_LINE + (length % SAMPLE_LINE == 0 ? 0 : 1); + for (int c = 0; c < numChunks; ++c) { + numChunksSeen++; + int chunkOffset = c * SAMPLE_LINE; + int chunkLen = c == numChunks - 1 ? length - chunkOffset : SAMPLE_LINE; + + if (numBytesInSample < SAMPLE_TARGET + SAMPLE_LINE) { + // If the reservoir isn't full, just add to it. + // This will occur on startup, but also if a recent swap caused us to go below the target. + // Add a buffer of an additional sample line, so that one swap doesn't cause us to fall below target. + byte[] chunkBytes = Arrays.copyOfRange(bytes, offset + chunkOffset, offset + chunkOffset + chunkLen); + sample.add(chunkBytes); + numBytesInSample += chunkBytes.length; + } else { + int p = random.nextInt(numChunksSeen); + if (p < sample.size()) { + // swap for an existing value + byte[] toAdd = Arrays.copyOfRange(bytes, offset + chunkOffset, offset + chunkOffset + chunkLen); + byte[] toRemove = sample.get(p); + numBytesInSample -= toRemove.length; + numBytesInSample += toAdd.length; + sample.set(p, toAdd); + + // Sample could now be too small if we swapped a small chunk for a big one. 
+ // This will be rectified as the next chunk will just be added to the sample, in the if-block above + + // But if the sample is too large (from swapping big samples for small samples), + // we need to discard some + while (numBytesInSample > SAMPLE_MAX) { + numBytesInSample -= sample.removeLast().length; + } + } + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java b/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java index ecb0d6d5eb3ca..073f55a1f1f57 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java +++ b/server/src/main/java/org/elasticsearch/index/codec/PerFieldFormatSupplier.java @@ -20,6 +20,7 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.codec.bloomfilter.ES87BloomFilterPostingsFormat; import org.elasticsearch.index.codec.postings.ES812PostingsFormat; +import org.elasticsearch.index.codec.tsdb.BinaryDVCompressionMode; import org.elasticsearch.index.codec.tsdb.es819.ES819TSDBDocValuesFormat; import org.elasticsearch.index.mapper.CompletionFieldMapper; import org.elasticsearch.index.mapper.IdFieldMapper; @@ -36,6 +37,7 @@ public class PerFieldFormatSupplier { private static final DocValuesFormat docValuesFormat = new Lucene90DocValuesFormat(); private static final KnnVectorsFormat knnVectorsFormat = new Lucene99HnswVectorsFormat(); private static final ES819TSDBDocValuesFormat tsdbDocValuesFormat = new ES819TSDBDocValuesFormat(); + private static final DocValuesFormat stringDocValuesFormat = new ES819TSDBDocValuesFormat(BinaryDVCompressionMode.COMPRESSED_WITH_FSST); private static final ES812PostingsFormat es812PostingsFormat = new ES812PostingsFormat(); private static final PostingsFormat completionPostingsFormat = PostingsFormat.forName("Completion101"); @@ -105,6 +107,13 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } public DocValuesFormat getDocValuesFormatForField(String field) { + if (mapperService != null) { + Mapper mapper = mapperService.mappingLookup().getMapper(field); + if (mapper != null && mapper.typeName().equals("wildcard")) { + return stringDocValuesFormat; + } + } + if (useTSDBDocValuesFormat(field)) { return tsdbDocValuesFormat; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/tsdb/BinaryDVCompressionMode.java b/server/src/main/java/org/elasticsearch/index/codec/tsdb/BinaryDVCompressionMode.java new file mode 100644 index 0000000000000..577522c5e2ae8 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/tsdb/BinaryDVCompressionMode.java @@ -0,0 +1,31 @@ + +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.tsdb; + +public enum BinaryDVCompressionMode { + + NO_COMPRESS((byte) 0), + COMPRESSED_WITH_FSST((byte) 1); + + public final byte code; + + BinaryDVCompressionMode(byte code) { + this.code = code; + } + + public static BinaryDVCompressionMode fromMode(byte mode) { + return switch (mode) { + case 0 -> NO_COMPRESS; + case 1 -> COMPRESSED_WITH_FSST; + default -> throw new IllegalStateException("unknown compression mode [" + mode + "]"); + }; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesConsumer.java b/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesConsumer.java index 3651be472051f..b27104c74cd70 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesConsumer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesConsumer.java @@ -40,7 +40,11 @@ import org.apache.lucene.util.compress.LZ4; import org.apache.lucene.util.packed.DirectMonotonicWriter; import org.apache.lucene.util.packed.PackedInts; +import org.elasticsearch.common.compress.fsst.BulkCompressBufferer; +import org.elasticsearch.common.compress.fsst.FSST; +import org.elasticsearch.common.compress.fsst.ReservoirSampler; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.tsdb.BinaryDVCompressionMode; import org.elasticsearch.index.codec.tsdb.TSDBDocValuesEncoder; import java.io.IOException; @@ -63,6 +67,7 @@ final class ES819TSDBDocValuesConsumer extends XDocValuesConsumer { private byte[] termsDictBuffer; private final int skipIndexIntervalSize; final boolean enableOptimizedMerge; + private final BinaryDVCompressionMode binaryDVCompressionMode; ES819TSDBDocValuesConsumer( SegmentWriteState state, @@ -71,9 +76,11 @@ final class ES819TSDBDocValuesConsumer extends XDocValuesConsumer { String dataCodec, String dataExtension, String metaCodec, - String metaExtension + String metaExtension, + BinaryDVCompressionMode binaryDVCompressionMode ) throws IOException { this.termsDictBuffer = new byte[1 << 14]; + this.binaryDVCompressionMode = binaryDVCompressionMode; this.dir = state.directory; this.context = state.context; boolean success = false; @@ -150,13 +157,13 @@ private long[] writeField(FieldInfo field, TsdbDocValuesProducer valuesProducer, if (numValues > 0) { assert numDocsWithValue > 0; // Special case for maxOrd of 1, signal -1 that no blocks will be written - meta.writeInt(maxOrd != 1 ? ES819TSDBDocValuesFormat.DIRECT_MONOTONIC_BLOCK_SHIFT : -1); + meta.writeInt(maxOrd != 1 ? 
DIRECT_MONOTONIC_BLOCK_SHIFT : -1); final ByteBuffersDataOutput indexOut = new ByteBuffersDataOutput(); final DirectMonotonicWriter indexWriter = DirectMonotonicWriter.getInstance( meta, new ByteBuffersIndexOutput(indexOut, "temp-dv-index", "temp-dv-index"), 1L + ((numValues - 1) >>> ES819TSDBDocValuesFormat.NUMERIC_BLOCK_SHIFT), - ES819TSDBDocValuesFormat.DIRECT_MONOTONIC_BLOCK_SHIFT + DIRECT_MONOTONIC_BLOCK_SHIFT ); final long valuesDataOffset = data.getFilePointer(); @@ -272,7 +279,14 @@ public void mergeBinaryField(FieldInfo mergeFieldInfo, MergeState mergeState) th public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { meta.writeInt(field.number); meta.writeByte(ES819TSDBDocValuesFormat.BINARY); + meta.writeByte(binaryDVCompressionMode.code); + switch (binaryDVCompressionMode) { + case NO_COMPRESS -> doAddUncompressedBinary(field, valuesProducer); + case COMPRESSED_WITH_FSST -> doAddCompressedBinaryFSST(field, valuesProducer); + } + } + public void doAddUncompressedBinary(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { if (valuesProducer instanceof TsdbDocValuesProducer tsdbValuesProducer && tsdbValuesProducer.mergeStats.supported()) { final int numDocsWithField = tsdbValuesProducer.mergeStats.sumNumDocsWithField(); final int minLength = tsdbValuesProducer.mergeStats.minLength(); @@ -401,6 +415,199 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th } } + public void doAddCompressedBinaryFSST(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { + if (valuesProducer instanceof TsdbDocValuesProducer tsdbValuesProducer && tsdbValuesProducer.mergeStats.supported()) { + final int numDocsWithField = tsdbValuesProducer.mergeStats.sumNumDocsWithField(); + final int minLength = tsdbValuesProducer.mergeStats.minLength(); + final int maxLength = tsdbValuesProducer.mergeStats.maxLength(); + assert maxLength >= minLength : "maxLength [" + maxLength + "] < minLength [" + minLength + "]"; + + BinaryDocValues values = valuesProducer.getBinary(field); + var sampler = new ReservoirSampler(); + + // Iteration 1: minLength, maxLength, numDocs, sample + for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { + BytesRef v = values.binaryValue(); + sampler.processLine(v.bytes, v.offset, v.length); + } + + // Build encoder from sample + FSST.SymbolTable symbolTable = FSST.SymbolTable.buildSymbolTable(sampler.getSample()); + + assert numDocsWithField <= maxDoc; + + long start = data.getFilePointer(); + meta.writeLong(start); // dataOffset + + OffsetsAccumulator offsetsAccumulator = null; + DISIAccumulator disiAccumulator = null; + try { + if (numDocsWithField > 0 && numDocsWithField < maxDoc) { + disiAccumulator = new DISIAccumulator(dir, context, data, IndexedDISI.DEFAULT_DENSE_RANK_POWER); + } + offsetsAccumulator = new OffsetsAccumulator(dir, context, data, numDocsWithField); + CompressedOffsetWriter offsetWriter = new CompressedOffsetWriter(offsetsAccumulator); + + values = valuesProducer.getBinary(field); + try (var bulkCompressor = new BulkCompressBufferer(data, symbolTable, offsetWriter)) { + // Iteration 2: compress lines + for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { + BytesRef v = values.binaryValue(); + bulkCompressor.addLine(v.bytes, v.offset, v.length); + if (disiAccumulator != null) { + disiAccumulator.addDocId(doc); + } + } + } + + // Write metadata + assert numDocsWithField <= maxDoc; 
+ meta.writeLong(data.getFilePointer() - start); // dataLength + + if (numDocsWithField == 0) { + meta.writeLong(-2); // docsWithFieldOffset + meta.writeLong(0L); // docsWithFieldLength + meta.writeShort((short) -1); // jumpTableEntryCount + meta.writeByte((byte) -1); // denseRankPower + } else if (numDocsWithField == maxDoc) { + meta.writeLong(-1); // docsWithFieldOffset + meta.writeLong(0L); // docsWithFieldLength + meta.writeShort((short) -1); // jumpTableEntryCount + meta.writeByte((byte) -1); // denseRankPower + } else { + long offset = data.getFilePointer(); + meta.writeLong(offset); // docsWithFieldOffset + final short jumpTableEntryCount = disiAccumulator.build(data); + meta.writeLong(data.getFilePointer() - offset); // docsWithFieldLength + meta.writeShort(jumpTableEntryCount); + meta.writeByte(IndexedDISI.DEFAULT_DENSE_RANK_POWER); + } + + meta.writeInt(numDocsWithField); + meta.writeInt(minLength); + meta.writeInt(maxLength); + + int minCompressedLength = offsetWriter.minCompressedLength; + int maxCompressedLength = offsetWriter.maxCompressedLength; + + // add compression fields + meta.writeInt(minCompressedLength); + meta.writeInt(maxCompressedLength); + byte[] compressedSymbolTable = symbolTable.exportToBytes(); + meta.writeBytes(compressedSymbolTable, compressedSymbolTable.length); + + if (maxCompressedLength > minCompressedLength) { + offsetsAccumulator.build(meta, data); + } + } finally { + IOUtils.close(disiAccumulator, offsetsAccumulator); + } + } else { + BinaryDocValues values = valuesProducer.getBinary(field); + int numDocsWithField = 0; + int minLength = Integer.MAX_VALUE; + int maxLength = 0; + var sampler = new ReservoirSampler(); + + // Iteration 1: minLength, maxLength, numDocs, sample + for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { + numDocsWithField++; + BytesRef v = values.binaryValue(); + int length = v.length; + minLength = Math.min(length, minLength); + maxLength = Math.max(length, maxLength); + sampler.processLine(v.bytes, v.offset, v.length); + } + + // Build encoder from sample + FSST.SymbolTable symbolTable = FSST.SymbolTable.buildSymbolTable(sampler.getSample()); + + DISIAccumulator disiAccumulator = null; + OffsetsAccumulator offsetsAccumulator = null; + try { + if (numDocsWithField > 0 && numDocsWithField < maxDoc) { + disiAccumulator = new DISIAccumulator(dir, context, data, IndexedDISI.DEFAULT_DENSE_RANK_POWER); + } + offsetsAccumulator = new OffsetsAccumulator(dir, context, data, numDocsWithField); + CompressedOffsetWriter offsetWriter = new CompressedOffsetWriter(offsetsAccumulator); + + // Compress Lines + values = valuesProducer.getBinary(field); + long start = data.getFilePointer(); + meta.writeLong(start); // dataOffset + try (var bulkCompressor = new BulkCompressBufferer(data, symbolTable, offsetWriter)) { + // Iteration 2: compress lines + for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { + BytesRef v = values.binaryValue(); + bulkCompressor.addLine(v.bytes, v.offset, v.length); + if (disiAccumulator != null) { + disiAccumulator.addDocId(doc); + } + } + } + + // Write metadata + assert numDocsWithField <= maxDoc; + meta.writeLong(data.getFilePointer() - start); // dataLength + + if (numDocsWithField == 0) { + meta.writeLong(-2); // docsWithFieldOffset + meta.writeLong(0L); // docsWithFieldLength + meta.writeShort((short) -1); // jumpTableEntryCount + meta.writeByte((byte) -1); // denseRankPower + } else if (numDocsWithField == maxDoc) { + 
meta.writeLong(-1); // docsWithFieldOffset + meta.writeLong(0L); // docsWithFieldLength + meta.writeShort((short) -1); // jumpTableEntryCount + meta.writeByte((byte) -1); // denseRankPower + } else { + long offset = data.getFilePointer(); + meta.writeLong(offset); // docsWithFieldOffset + final short jumpTableEntryCount = disiAccumulator.build(data); + meta.writeLong(data.getFilePointer() - offset); // docsWithFieldLength + meta.writeShort(jumpTableEntryCount); + meta.writeByte(IndexedDISI.DEFAULT_DENSE_RANK_POWER); + } + + meta.writeInt(numDocsWithField); + meta.writeInt(minLength); + meta.writeInt(maxLength); + + int minCompressedLength = offsetWriter.minCompressedLength; + int maxCompressedLength = offsetWriter.maxCompressedLength; + + // add compression fields + meta.writeInt(minCompressedLength); + meta.writeInt(maxCompressedLength); + byte[] compressedSymbolTable = symbolTable.exportToBytes(); + meta.writeBytes(compressedSymbolTable, compressedSymbolTable.length); + + if (maxCompressedLength > minCompressedLength) { + offsetsAccumulator.build(meta, data); + } + } finally { + IOUtils.close(disiAccumulator, offsetsAccumulator); + } + } + } + + private static class CompressedOffsetWriter implements FSST.OffsetWriter { + int maxCompressedLength = 0; + int minCompressedLength = Integer.MAX_VALUE; + private final OffsetsAccumulator delegate; + + private CompressedOffsetWriter(OffsetsAccumulator delegate) { + this.delegate = delegate; + } + + public void addLen(int compressedLen) throws IOException { + assert compressedLen >= 0; + delegate.addDoc(compressedLen); + minCompressedLength = Math.min(compressedLen, minCompressedLength); + maxCompressedLength = Math.max(compressedLen, maxCompressedLength); + } + } + @Override public void addSortedField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException { meta.writeInt(field.number); @@ -646,13 +853,13 @@ private void writeSortedNumericField(FieldInfo field, TsdbDocValuesProducer valu if (numValues > numDocsWithField) { long start = data.getFilePointer(); meta.writeLong(start); - meta.writeVInt(ES819TSDBDocValuesFormat.DIRECT_MONOTONIC_BLOCK_SHIFT); + meta.writeVInt(DIRECT_MONOTONIC_BLOCK_SHIFT); final DirectMonotonicWriter addressesWriter = DirectMonotonicWriter.getInstance( meta, data, numDocsWithField + 1L, - ES819TSDBDocValuesFormat.DIRECT_MONOTONIC_BLOCK_SHIFT + DIRECT_MONOTONIC_BLOCK_SHIFT ); long addr = 0; addressesWriter.add(addr); diff --git a/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormat.java b/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormat.java index 1a937e75ad5f9..1f89e1d82d4ff 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormat.java @@ -14,6 +14,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.index.codec.tsdb.BinaryDVCompressionMode; import java.io.IOException; @@ -47,7 +48,8 @@ public class ES819TSDBDocValuesFormat extends org.apache.lucene.codecs.DocValues static final byte SORTED_NUMERIC = 4; static final int VERSION_START = 0; - static final int VERSION_CURRENT = VERSION_START; + static final int VERSION_BINARY_DV_COMPRESSION = 1; + static final int VERSION_CURRENT = VERSION_BINARY_DV_COMPRESSION; static final int TERMS_DICT_BLOCK_LZ4_SHIFT = 6; static final int 
TERMS_DICT_BLOCK_LZ4_SIZE = 1 << TERMS_DICT_BLOCK_LZ4_SHIFT; @@ -106,20 +108,30 @@ private static boolean getOptimizedMergeEnabledDefault() { final int skipIndexIntervalSize; private final boolean enableOptimizedMerge; + private final BinaryDVCompressionMode binaryDVCompressionMode; /** Default constructor. */ public ES819TSDBDocValuesFormat() { - this(DEFAULT_SKIP_INDEX_INTERVAL_SIZE, OPTIMIZED_MERGE_ENABLE_DEFAULT); + this(DEFAULT_SKIP_INDEX_INTERVAL_SIZE, OPTIMIZED_MERGE_ENABLE_DEFAULT, BinaryDVCompressionMode.NO_COMPRESS); + } + + public ES819TSDBDocValuesFormat(BinaryDVCompressionMode binaryDVCompressionMode) { + this(DEFAULT_SKIP_INDEX_INTERVAL_SIZE, OPTIMIZED_MERGE_ENABLE_DEFAULT, binaryDVCompressionMode); } /** Doc values fields format with specified skipIndexIntervalSize. */ - public ES819TSDBDocValuesFormat(int skipIndexIntervalSize, boolean enableOptimizedMerge) { + public ES819TSDBDocValuesFormat( + int skipIndexIntervalSize, + boolean enableOptimizedMerge, + BinaryDVCompressionMode binaryDVCompressionMode + ) { super(CODEC_NAME); if (skipIndexIntervalSize < 2) { throw new IllegalArgumentException("skipIndexIntervalSize must be > 1, got [" + skipIndexIntervalSize + "]"); } this.skipIndexIntervalSize = skipIndexIntervalSize; this.enableOptimizedMerge = enableOptimizedMerge; + this.binaryDVCompressionMode = binaryDVCompressionMode; } @Override @@ -131,7 +143,8 @@ public DocValuesConsumer fieldsConsumer(SegmentWriteState state) throws IOExcept DATA_CODEC, DATA_EXTENSION, META_CODEC, - META_EXTENSION + META_EXTENSION, + binaryDVCompressionMode ); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesProducer.java b/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesProducer.java index 31d65bde1be0e..7ae0987ca1714 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesProducer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesProducer.java @@ -41,7 +41,9 @@ import org.apache.lucene.util.compress.LZ4; import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.packed.PackedInts; +import org.elasticsearch.common.compress.fsst.FSST; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.index.codec.tsdb.BinaryDVCompressionMode; import org.elasticsearch.index.codec.tsdb.TSDBDocValuesEncoder; import java.io.IOException; @@ -89,7 +91,7 @@ final class ES819TSDBDocValuesProducer extends DocValuesProducer { state.segmentSuffix ); - readFields(in, state.fieldInfos); + readFields(in, state.fieldInfos, version); } catch (Throwable exception) { priorE = exception; @@ -182,6 +184,13 @@ public BinaryDocValues getBinary(FieldInfo field) throws IOException { return DocValues.emptyBinary(); } + return switch (entry.compression) { + case NO_COMPRESS -> getUncompressedBinary(entry); + case COMPRESSED_WITH_FSST -> getCompressedBinaryFSST(entry); + }; + } + + public BinaryDocValues getUncompressedBinary(BinaryEntry entry) throws IOException { final RandomAccessInput bytesSlice = data.randomAccessSlice(entry.dataOffset, entry.dataLength); if (entry.docsWithFieldOffset == -1) { @@ -256,6 +265,91 @@ public BytesRef binaryValue() throws IOException { } } + public BinaryDocValues getCompressedBinaryFSST(BinaryEntry entry) throws IOException { + + final RandomAccessInput bytesSlice = data.randomAccessSlice(entry.dataOffset, entry.dataLength); + + final int maxCompressedLength = entry.maxCompressedLength; + final int 
lengthToAllocate = entry.maxLength + 7; // uncompressed + if (entry.docsWithFieldOffset == -1) { + + // dense + if (entry.minCompressedLength == entry.maxCompressedLength) { + // fixed length + return new DenseBinaryDocValues(maxDoc) { + final BytesRef inBuf = new BytesRef(new byte[maxCompressedLength], 0, maxCompressedLength); + final BytesRef outBuf = new BytesRef(new byte[lengthToAllocate], 0, lengthToAllocate); + + @Override + public BytesRef binaryValue() throws IOException { + bytesSlice.readBytes((long) doc * maxCompressedLength, inBuf.bytes, 0, maxCompressedLength); + outBuf.length = FSST.decompress(inBuf.bytes, 0, inBuf.length, entry.decoder, outBuf.bytes); + return outBuf; + } + }; + } else { + // variable length + final RandomAccessInput addressesData = this.data.randomAccessSlice(entry.addressesOffset, entry.addressesLength); + final LongValues addresses = DirectMonotonicReader.getInstance(entry.addressesMeta, addressesData); + return new DenseBinaryDocValues(maxDoc) { + final BytesRef inBuf = new BytesRef(new byte[maxCompressedLength], 0, 0); + final BytesRef outBuf = new BytesRef(new byte[lengthToAllocate], 0, lengthToAllocate); + + @Override + public BytesRef binaryValue() throws IOException { + long startOffset = addresses.get(doc); + inBuf.length = (int) (addresses.get(doc + 1L) - startOffset); + bytesSlice.readBytes(startOffset, inBuf.bytes, 0, inBuf.length); + outBuf.length = FSST.decompress(inBuf.bytes, 0, inBuf.length, entry.decoder, outBuf.bytes); + return outBuf; + } + }; + } + } else { + // sparse + final IndexedDISI disi = new IndexedDISI( + data, + entry.docsWithFieldOffset, + entry.docsWithFieldLength, + entry.jumpTableEntryCount, + entry.denseRankPower, + entry.numDocsWithField + ); + if (entry.minCompressedLength == entry.maxCompressedLength) { + // fixed length + return new SparseBinaryDocValues(disi) { + final BytesRef inBuf = new BytesRef(new byte[maxCompressedLength], 0, entry.maxCompressedLength); + final BytesRef outBuf = new BytesRef(new byte[lengthToAllocate], 0, lengthToAllocate); + + @Override + public BytesRef binaryValue() throws IOException { + bytesSlice.readBytes((long) disi.index() * maxCompressedLength, inBuf.bytes, 0, maxCompressedLength); + outBuf.length = FSST.decompress(inBuf.bytes, 0, inBuf.length, entry.decoder, outBuf.bytes); + return outBuf; + } + }; + } else { + // variable length + final RandomAccessInput addressesData = this.data.randomAccessSlice(entry.addressesOffset, entry.addressesLength); + final LongValues addresses = DirectMonotonicReader.getInstance(entry.addressesMeta, addressesData); + return new SparseBinaryDocValues(disi) { + final BytesRef inBuf = new BytesRef(new byte[maxCompressedLength], 0, 0); + final BytesRef outBuf = new BytesRef(new byte[lengthToAllocate], 0, lengthToAllocate); + + @Override + public BytesRef binaryValue() throws IOException { + final int index = disi.index(); + long startOffset = addresses.get(index); + inBuf.length = (int) (addresses.get(index + 1L) - startOffset); + bytesSlice.readBytes(startOffset, inBuf.bytes, 0, inBuf.length); + outBuf.length = FSST.decompress(inBuf.bytes, 0, inBuf.length, entry.decoder, outBuf.bytes); + return outBuf; + } + }; + } + } + } + private abstract static class DenseBinaryDocValues extends BinaryDocValues { final int maxDoc; @@ -872,7 +966,7 @@ public void close() throws IOException { data.close(); } - private void readFields(IndexInput meta, FieldInfos infos) throws IOException { + private void readFields(IndexInput meta, FieldInfos infos, int version) throws 
IOException { for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { FieldInfo info = infos.fieldInfo(fieldNumber); if (info == null) { @@ -885,7 +979,7 @@ private void readFields(IndexInput meta, FieldInfos infos) throws IOException { if (type == ES819TSDBDocValuesFormat.NUMERIC) { numerics.put(info.number, readNumeric(meta)); } else if (type == ES819TSDBDocValuesFormat.BINARY) { - binaries.put(info.number, readBinary(meta)); + binaries.put(info.number, readBinary(meta, version)); } else if (type == ES819TSDBDocValuesFormat.SORTED) { sorted.put(info.number, readSorted(meta)); } else if (type == ES819TSDBDocValuesFormat.SORTED_SET) { @@ -942,8 +1036,14 @@ private static void readNumeric(IndexInput meta, NumericEntry entry) throws IOEx entry.denseRankPower = meta.readByte(); } - private BinaryEntry readBinary(IndexInput meta) throws IOException { - final BinaryEntry entry = new BinaryEntry(); + private BinaryEntry readBinary(IndexInput meta, int version) throws IOException { + final BinaryDVCompressionMode compression; + if (version >= ES819TSDBDocValuesFormat.VERSION_BINARY_DV_COMPRESSION) { + compression = BinaryDVCompressionMode.fromMode(meta.readByte()); + } else { + compression = BinaryDVCompressionMode.NO_COMPRESS; + } + final BinaryEntry entry = new BinaryEntry(compression); entry.dataOffset = meta.readLong(); entry.dataLength = meta.readLong(); entry.docsWithFieldOffset = meta.readLong(); @@ -953,16 +1053,34 @@ private BinaryEntry readBinary(IndexInput meta) throws IOException { entry.numDocsWithField = meta.readInt(); entry.minLength = meta.readInt(); entry.maxLength = meta.readInt(); - if (entry.minLength < entry.maxLength) { - entry.addressesOffset = meta.readLong(); - // Old count of uncompressed addresses - long numAddresses = entry.numDocsWithField + 1L; + if (compression == BinaryDVCompressionMode.COMPRESSED_WITH_FSST) { + entry.minCompressedLength = meta.readInt(); + entry.maxCompressedLength = meta.readInt(); + entry.decoder = FSST.Decoder.readFrom(meta::readByte); + if (entry.minCompressedLength < entry.maxCompressedLength) { + entry.addressesOffset = meta.readLong(); - final int blockShift = meta.readVInt(); - entry.addressesMeta = DirectMonotonicReader.loadMeta(meta, numAddresses, blockShift); - entry.addressesLength = meta.readLong(); + // Old count of uncompressed addresses + long numAddresses = entry.numDocsWithField + 1L; + + final int blockShift = meta.readVInt(); + entry.addressesMeta = DirectMonotonicReader.loadMeta(meta, numAddresses, blockShift); + entry.addressesLength = meta.readLong(); + } + } else { // NO_COMPRESS + if (entry.minLength < entry.maxLength) { + entry.addressesOffset = meta.readLong(); + + // Old count of uncompressed addresses + long numAddresses = entry.numDocsWithField + 1L; + + final int blockShift = meta.readVInt(); + entry.addressesMeta = DirectMonotonicReader.loadMeta(meta, numAddresses, blockShift); + entry.addressesLength = meta.readLong(); + } } + return entry; } @@ -1458,6 +1576,18 @@ static class BinaryEntry { long addressesOffset; long addressesLength; DirectMonotonicReader.Meta addressesMeta; + + // Compression + final BinaryDVCompressionMode compression; + + // FSST + int minCompressedLength; + int maxCompressedLength; + FSST.Decoder decoder; + + private BinaryEntry(BinaryDVCompressionMode compression) { + this.compression = compression; + } } static class SortedNumericEntry extends NumericEntry { diff --git a/server/src/test/java/org/elasticsearch/index/codec/tsdb/DocValuesCodecDuelTests.java 
b/server/src/test/java/org/elasticsearch/index/codec/tsdb/DocValuesCodecDuelTests.java index f0ce28f11a51a..4ff3b4188cb73 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/tsdb/DocValuesCodecDuelTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/tsdb/DocValuesCodecDuelTests.java @@ -58,7 +58,7 @@ public void testDuel() throws IOException { Codec codec = new Elasticsearch900Lucene101Codec() { final DocValuesFormat docValuesFormat = randomBoolean() - ? new ES819TSDBDocValuesFormat() + ? new ES819TSDBDocValuesFormat(randomFrom(BinaryDVCompressionMode.values())) : new TestES87TSDBDocValuesFormat(); @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormatTests.java index 368d6f23d0fa1..c1c43f9ad9d40 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormatTests.java @@ -28,16 +28,18 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.index.codec.Elasticsearch900Lucene101Codec; +import org.elasticsearch.index.codec.tsdb.BinaryDVCompressionMode; import org.elasticsearch.index.codec.tsdb.ES87TSDBDocValuesFormatTests; import java.util.Arrays; import java.util.Locale; +import java.util.Random; public class ES819TSDBDocValuesFormatTests extends ES87TSDBDocValuesFormatTests { private final Codec codec = new Elasticsearch900Lucene101Codec() { - final ES819TSDBDocValuesFormat docValuesFormat = new ES819TSDBDocValuesFormat(); + final ES819TSDBDocValuesFormat docValuesFormat = new ES819TSDBDocValuesFormat(randomBinaryDVCompressionMode(random())); @Override public DocValuesFormat getDocValuesFormatForField(String field) { @@ -528,4 +530,8 @@ private IndexWriterConfig getTimeSeriesIndexWriterConfig(String hostnameField, S return config; } + public static BinaryDVCompressionMode randomBinaryDVCompressionMode(Random random) { + BinaryDVCompressionMode[] values = BinaryDVCompressionMode.values(); + return values[random.nextInt(0, values.length)]; + } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormatVariableSkipIntervalTests.java b/server/src/test/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormatVariableSkipIntervalTests.java index d158236ecc7ac..74bca821ba7e1 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormatVariableSkipIntervalTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/tsdb/es819/ES819TSDBDocValuesFormatVariableSkipIntervalTests.java @@ -17,14 +17,18 @@ public class ES819TSDBDocValuesFormatVariableSkipIntervalTests extends ES87TSDBD @Override protected Codec getCodec() { + var compressionMode = ES819TSDBDocValuesFormatTests.randomBinaryDVCompressionMode(random()); // small interval size to test with many intervals - return TestUtil.alwaysDocValuesFormat(new ES819TSDBDocValuesFormat(random().nextInt(4, 16), random().nextBoolean())); + return TestUtil.alwaysDocValuesFormat( + new ES819TSDBDocValuesFormat(random().nextInt(4, 16), random().nextBoolean(), compressionMode) + ); } public void testSkipIndexIntervalSize() { + var compressionMode = ES819TSDBDocValuesFormatTests.randomBinaryDVCompressionMode(random()); IllegalArgumentException ex = expectThrows( 
IllegalArgumentException.class, - () -> new ES819TSDBDocValuesFormat(random().nextInt(Integer.MIN_VALUE, 2), random().nextBoolean()) + () -> new ES819TSDBDocValuesFormat(random().nextInt(Integer.MIN_VALUE, 2), random().nextBoolean(), compressionMode) ); assertTrue(ex.getMessage().contains("skipIndexIntervalSize must be > 1")); }
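
For reference, the pieces introduced by this patch fit together as sketched below: the consumer samples values with ReservoirSampler, trains an FSST symbol table, bulk-compresses the values, and stores the exported symbol table in meta; the producer rebuilds an FSST.Decoder from meta and decompresses per value. This is an illustrative sketch only, using the API surface exercised elsewhere in the patch (buildSymbolTable, compressBulk, exportToBytes, Decoder.readFrom, FSST.decompress); the class name, the input string, and the output-buffer sizing (2x input plus 8 bytes of slack) are assumptions for illustration, not values the patch prescribes.

import java.io.IOException;
import java.util.Arrays;

import org.elasticsearch.common.compress.fsst.FSST;
import org.elasticsearch.common.compress.fsst.ReservoirSampler;

class FsstRoundTripSketch {
    static void roundTrip(String text) throws IOException {
        byte[] value = FSST.toBytes(text);

        // Iteration 1 in the consumer: feed every doc value through the reservoir sampler.
        ReservoirSampler sampler = new ReservoirSampler();
        sampler.processLine(value, 0, value.length);

        // Train a symbol table on the sample.
        FSST.SymbolTable symbolTable = FSST.SymbolTable.buildSymbolTable(sampler.getSample());

        // Iteration 2: bulk-compress the lines; offsets[i]..offsets[i + 1] delimit line i.
        int[] offsets = { 0, value.length };
        byte[] outBuf = new byte[2 * value.length + 8]; // assumed worst-case output sizing
        int[] outOffsets = new int[2];
        long linesCompressed = symbolTable.compressBulk(1, value, offsets, outBuf, outOffsets);
        int compressedLen = outOffsets[1];

        // The producer only sees the exported symbol table (written to meta) and the
        // compressed bytes, so decode through a Decoder rebuilt from the exported form.
        FSST.Decoder decoder = FSST.Decoder.readFrom(symbolTable.exportToBytes());
        byte[] decompressBuf = new byte[value.length + 8]; // +8 slack, mirroring the round-trip test
        int decodedLen = FSST.decompress(outBuf, 0, compressedLen, decoder, decompressBuf);

        assert linesCompressed == 1;
        assert text.equals(FSST.fromBytes(Arrays.copyOfRange(decompressBuf, 0, decodedLen)));
    }
}

On the read side this rebuild happens once per segment: readBinary stores the decoder on BinaryEntry, and every binaryValue() call reuses it against a per-values scratch buffer sized from maxLength, which is why only the compressed lengths and the exported symbol table need to be written to meta.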