diff --git a/CHANGES.txt b/CHANGES.txt
index 9bca6a4bdee0..c0f0f8d4682d 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,3 +1,6 @@
+trunk?? (The current trunk is on 5.1)
+* Support ZSTD dictionary compression (CASSANDRA-17021)
+
 5.1
 * Add cqlsh autocompletion for the identity mapping feature (CASSANDRA-20021)
 * Add DDL Guardrail enabling administrators to disallow creation/modification of keyspaces with durable_writes = false (CASSANDRA-20913)
diff --git a/conf/cassandra.yaml b/conf/cassandra.yaml
index 0147adc74b6a..a50bd9db0fa5 100644
--- a/conf/cassandra.yaml
+++ b/conf/cassandra.yaml
@@ -2849,3 +2849,49 @@ storage_compatibility_mode: NONE
 # # especially in keyspaces with many tables. The splitter avoids batching tables together if they
 # # exceed other configuration parameters like bytes_per_assignment or partitions_per_assignment.
 # max_tables_per_assignment: 64
+
+# Dictionary compression settings for ZSTD dictionary-based compression
+# These settings control the automatic training and caching of compression dictionaries
+# for tables that use ZSTD dictionary compression.
+
+# How often to refresh compression dictionaries across the cluster.
+# During refresh, nodes will check for newer dictionary versions and update their caches.
+# Min unit: s
+compression_dictionary_refresh_interval: 3600s
+
+# Initial delay before starting the first dictionary refresh cycle after node startup.
+# This prevents all nodes from refreshing simultaneously when the cluster starts.
+# Min unit: s
+compression_dictionary_refresh_initial_delay: 10s
+
+# Maximum number of compression dictionaries to cache per table.
+# Each table using dictionary compression can have multiple dictionaries cached
+# (current version plus recently used versions for reading older SSTables).
+compression_dictionary_cache_size: 10
+
+# How long to keep compression dictionaries in the cache before they expire.
+# Expired dictionaries will be removed from memory but can be reloaded if needed.
+# Min unit: s
+compression_dictionary_cache_expire: 24h
+
+# Dictionary training configuration (advanced settings)
+# These settings control how compression dictionaries are trained from sample data.
+
+# Maximum size of a trained compression dictionary.
+# Larger dictionaries may provide better compression but use more memory.
+compression_dictionary_training_max_dictionary_size: 64KiB
+
+# Maximum total size of sample data to collect for dictionary training.
+# More sample data generally produces better dictionaries but takes longer to train.
+# The recommended sample size is 100x the dictionary size.
+compression_dictionary_training_max_total_sample_size: 10MiB
+
+# Enable automatic dictionary training based on sampling of write operations.
+# When enabled, the system will automatically collect samples and train new dictionaries.
+# Manual training via nodetool is always available regardless of this setting.
+compression_dictionary_training_auto_train_enabled: false
+
+# Sampling rate for automatic dictionary training, expressed as a fraction between 0 and 1.
+# A value of 0.01 means 1% of writes are sampled. Lower values reduce overhead but may
+# result in less representative sample data for dictionary training.
+compression_dictionary_training_sampling_rate: 0.01
diff --git a/conf/cassandra_latest.yaml b/conf/cassandra_latest.yaml
index f42604843146..1235d2f57a9c 100644
--- a/conf/cassandra_latest.yaml
+++ b/conf/cassandra_latest.yaml
@@ -2529,3 +2529,49 @@ storage_compatibility_mode: NONE
 # # especially in keyspaces with many tables. The splitter avoids batching tables together if they
 # # exceed other configuration parameters like bytes_per_assignment or partitions_per_assignment.
 # max_tables_per_assignment: 64
+
+# Dictionary compression settings for ZSTD dictionary-based compression
+# These settings control the automatic training and caching of compression dictionaries
+# for tables that use ZSTD dictionary compression.
+
+# How often to refresh compression dictionaries across the cluster.
+# During refresh, nodes will check for newer dictionary versions and update their caches.
+# Min unit: s
+compression_dictionary_refresh_interval: 3600s
+
+# Initial delay before starting the first dictionary refresh cycle after node startup.
+# This prevents all nodes from refreshing simultaneously when the cluster starts.
+# Min unit: s
+compression_dictionary_refresh_initial_delay: 10s
+
+# Maximum number of compression dictionaries to cache per table.
+# Each table using dictionary compression can have multiple dictionaries cached
+# (current version plus recently used versions for reading older SSTables).
+compression_dictionary_cache_size: 10
+
+# How long to keep compression dictionaries in the cache before they expire.
+# Expired dictionaries will be removed from memory but can be reloaded if needed.
+# Min unit: s
+compression_dictionary_cache_expire: 24h
+
+# Dictionary training configuration (advanced settings)
+# These settings control how compression dictionaries are trained from sample data.
+
+# Maximum size of a trained compression dictionary.
+# Larger dictionaries may provide better compression but use more memory.
+compression_dictionary_training_max_dictionary_size: 64KiB
+
+# Maximum total size of sample data to collect for dictionary training.
+# More sample data generally produces better dictionaries but takes longer to train.
+# The recommended sample size is 100x the dictionary size.
+compression_dictionary_training_max_total_sample_size: 10MiB
+
+# Enable automatic dictionary training based on sampling of write operations.
+# When enabled, the system will automatically collect samples and train new dictionaries.
+# Manual training via nodetool is always available regardless of this setting.
+compression_dictionary_training_auto_train_enabled: false
+
+# Sampling rate for automatic dictionary training, expressed as a fraction between 0 and 1.
+# A value of 0.01 means 1% of writes are sampled. Lower values reduce overhead but may
+# result in less representative sample data for dictionary training.
+compression_dictionary_training_sampling_rate: 0.01
diff --git a/doc/modules/cassandra/pages/managing/operating/compression.adoc b/doc/modules/cassandra/pages/managing/operating/compression.adoc
index f5c6f2f7aa13..967cd411d607 100644
--- a/doc/modules/cassandra/pages/managing/operating/compression.adoc
+++ b/doc/modules/cassandra/pages/managing/operating/compression.adoc
@@ -49,6 +49,8 @@ these areas (A is relatively good, F is relatively bad):
 |https://facebook.github.io/zstd/[Zstd] |`ZstdCompressor` | A- | A- | A+ | `>= 4.0`
 
+|https://facebook.github.io/zstd/[Zstd with Dictionary] |`ZstdDictionaryCompressor` | A- | A- | A++ | `>= 6.0`
+
 |http://google.github.io/snappy/[Snappy] |`SnappyCompressor` | A- | A | C | `>= 1.0`
 
 |https://zlib.net[Deflate (zlib)] |`DeflateCompressor` | C | C | A | `>= 1.0`
@@ -60,6 +62,9 @@ cycle spent. This is why it is the default choice in Cassandra.
 For storage critical applications (disk footprint), however, `Zstd`
 may be a better choice as it can get significant additional ratio to `LZ4`.
+For workloads with highly repetitive or similar data patterns,
+`ZstdDictionaryCompressor` can achieve even better compression ratios by
+training a compression dictionary on representative data samples.
 
 `Snappy` is kept for backwards compatibility and `LZ4` will typically be
 preferable.
@@ -67,6 +72,91 @@ preferable.
 
 `Deflate` is kept for backwards compatibility and `Zstd` will typically be
 preferable.
 
+== ZSTD Dictionary Compression
+
+The `ZstdDictionaryCompressor` extends standard ZSTD compression by using
+trained compression dictionaries to achieve superior compression ratios,
+particularly for workloads with repetitive or similar data patterns.
+
+=== How Dictionary Compression Works
+
+Dictionary compression improves upon standard compression by training a
+compression dictionary on representative samples of your data. This
+dictionary captures common patterns, repeated strings, and data structures,
+allowing the compressor to reference these patterns more efficiently than
+discovering them independently in each compression chunk.
+
+=== When to Use Dictionary Compression
+
+Dictionary compression is most effective for:
+
+* *Tables with similar row structures*: JSON documents, XML data, or
+repeated data schemas benefit significantly from dictionary compression.
+* *Storage-critical workloads*: When disk space savings justify the
+additional operational overhead of dictionary training and management.
+* *Large datasets with repetitive patterns*: The more similar your data,
+the better the compression ratio improvement.
+
+Dictionary compression may not be ideal for:
+
+* *Highly random or unique data*: Already-compressed data or cryptographic
+data will see minimal benefit.
+* *Small tables*: The overhead of dictionary management may outweigh the
+storage savings.
+* *Frequently changing schemas*: Schema changes may require retraining
+dictionaries to maintain optimal compression ratios.
+
+=== Dictionary Training
+
+Before dictionary compression can provide optimal results, a compression
+dictionary must be trained on representative data samples. Cassandra
+supports both manual and automatic training approaches.
+
+==== Manual Dictionary Training
+
+Use the `nodetool traincompressiondictionary` command to manually train
+a compression dictionary:
+
+[source,bash]
+----
+nodetool traincompressiondictionary <keyspace> <table>
+----
+
+The command trains a dictionary by sampling from existing SSTables. If no
+SSTables are available on disk (e.g., all data is in memtables), the command
+will automatically flush the memtable before sampling.
+
+The training process completes synchronously and displays progress information
+including sample count, sample size, and elapsed time. Training typically
+completes within minutes for most workloads.
+
+==== Automatic Dictionary Training
+
+Enable automatic training in `cassandra.yaml`:
+
+[source,yaml]
+----
+compression_dictionary_training_auto_train_enabled: true
+compression_dictionary_training_sampling_rate: 0.01 # 1% of writes
+----
+
+When enabled, Cassandra automatically samples write operations and
+trains dictionaries in the background based on the configured sampling
+rate (a fraction between 0 and 1, where 0.01 = 1% of writes).
+
+=== Dictionary Storage and Distribution
+
+Compression dictionaries are stored cluster-wide in the
+`system_distributed.compression_dictionaries` table. Each table can
+maintain multiple dictionary versions: the current dictionary for
+compressing new SSTables, plus historical dictionaries needed for
+reading older SSTables.
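+
+As an illustration of the flow described in the training sections above
+(collect representative samples, train a dictionary, then compress each
+chunk with it), the following sketch uses the zstd-jni library that
+Cassandra already bundles for its Zstd compressors. It is illustrative
+only and not part of Cassandra's API: the sample data and sizes are
+placeholders, and the zstd-jni calls are as found in its 1.5.x releases.
+Within Cassandra, the trained bytes are what end up stored in
+`system_distributed.compression_dictionaries`.
+
+[source,java]
+----
+import java.nio.charset.StandardCharsets;
+
+import com.github.luben.zstd.Zstd;
+import com.github.luben.zstd.ZstdCompressCtx;
+import com.github.luben.zstd.ZstdDictTrainer;
+
+public class DictionaryCompressionSketch
+{
+    public static void main(String[] args)
+    {
+        // Sample budget (~10 MiB) and target dictionary size (64 KiB),
+        // mirroring the defaults added to cassandra.yaml by this patch.
+        ZstdDictTrainer trainer = new ZstdDictTrainer(10 * 1024 * 1024, 64 * 1024);
+        for (int i = 0; i < 200_000; i++)
+        {
+            // Repetitive, similarly shaped rows are where dictionaries pay off.
+            String row = "{\"user_id\":" + i + ",\"status\":\"active\",\"region\":\"us-west-2\"}";
+            trainer.addSample(row.getBytes(StandardCharsets.UTF_8)); // returns false once the budget is full
+        }
+        byte[] dict = trainer.trainSamples(); // raw bytes of the trained dictionary
+
+        byte[] chunk = "{\"user_id\":42,\"status\":\"active\",\"region\":\"us-west-2\"}"
+                       .getBytes(StandardCharsets.UTF_8);
+        byte[] plain = Zstd.compress(chunk, 3);   // standard ZSTD at level 3
+
+        ZstdCompressCtx ctx = new ZstdCompressCtx();
+        ctx.setLevel(3);
+        ctx.loadDict(dict);
+        byte[] withDict = ctx.compress(chunk);    // ZSTD using the trained dictionary
+        ctx.close();
+
+        System.out.printf("plain: %d bytes, with dictionary: %d bytes%n",
+                          plain.length, withDict.length);
+    }
+}
+----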
+
+Dictionaries are identified by `dict_id`, with higher IDs representing
+newer dictionaries. Cassandra automatically refreshes dictionaries
+across the cluster based on configured intervals, and caches them
+locally to minimize lookup overhead.
+
 == Configuring Compression
 
 Compression is configured on a per-table basis as an optional argument
@@ -105,6 +195,17 @@ should be used with caution, as they require more memory. The default of
 `3` is a good choice for competing with `Deflate` ratios and `1` is a
 good choice for competing with `LZ4`.
 
+The `ZstdDictionaryCompressor` supports the same options as
+`ZstdCompressor`:
+
+* `compression_level` (default `3`): Same range and behavior as
+`ZstdCompressor`. Dictionary compression provides improved ratios at
+any compression level compared to standard ZSTD.
+
+NOTE: `ZstdDictionaryCompressor` requires a trained compression
+dictionary to achieve optimal results. See the ZSTD Dictionary
+Compression section above for training instructions.
+
 Users can set compression using the following syntax:
 
 [source,cql]
 ----
 CREATE TABLE keyspace.table (id int PRIMARY KEY)
     WITH compression = {'class': 'LZ4Compressor'};
@@ -121,6 +222,25 @@ ALTER TABLE keyspace.table
     WITH compression = {'class': 'LZ4Compressor', 'chunk_length_in_kb': 64};
 ----
 
+For dictionary compression:
+
+[source,cql]
+----
+CREATE TABLE keyspace.table (id int PRIMARY KEY)
+    WITH compression = {'class': 'ZstdDictionaryCompressor'};
+----
+
+Or with a specific compression level:
+
+[source,cql]
+----
+ALTER TABLE keyspace.table
+    WITH compression = {
+        'class': 'ZstdDictionaryCompressor',
+        'compression_level': '3'
+    };
+----
+
 Once enabled, compression can be disabled with `ALTER TABLE` setting
 `enabled` to `false`:
@@ -140,6 +260,63 @@ immediately, the operator can trigger an SSTable rewrite using
 `nodetool scrub` or `nodetool upgradesstables -a`, both of which will
 rebuild the SSTables on disk, re-compressing the data in the process.
 
+== Dictionary Compression Configuration
+
+When using `ZstdDictionaryCompressor`, several additional configuration
+options are available in `cassandra.yaml` to control dictionary
+management, caching, and training behavior.
+
+=== Dictionary Refresh Settings
+
+* `compression_dictionary_refresh_interval` (default: `3600s`): How often
+to check for and refresh compression dictionaries cluster-wide. Newly
+trained dictionaries will be picked up by all nodes within this interval.
+* `compression_dictionary_refresh_initial_delay` (default: `10s`): Initial
+delay before the first dictionary refresh check after node startup.
+
+=== Dictionary Caching
+
+* `compression_dictionary_cache_size` (default: `10`): Maximum number of
+compression dictionaries to cache per table. Higher values reduce lookup
+overhead but increase memory usage.
+* `compression_dictionary_cache_expire` (default: `24h`): Dictionary
+cache entry TTL. Expired entries are evicted and reloaded on next access.
+
+=== Training Configuration
+
+* `compression_dictionary_training_max_dictionary_size` (default: `64KiB`):
+Maximum size of trained dictionaries. Larger dictionaries can capture
+more patterns but increase memory overhead.
+* `compression_dictionary_training_max_total_sample_size` (default:
+`10MiB`): Maximum total size of sample data to collect for training.
+More sample data generally produces better dictionaries.
+* `compression_dictionary_training_auto_train_enabled` (default: `false`):
+Enable automatic background dictionary training. When enabled, Cassandra
+samples writes and trains dictionaries automatically.
+* `compression_dictionary_training_sampling_rate` (default: `0.01`):
+Fraction of writes sampled for automatic training (`0.01` = 1% of writes).
+Lower values reduce training overhead but may miss data patterns.
+
+Example configuration:
+
+[source,yaml]
+----
+# Dictionary refresh and caching
+compression_dictionary_refresh_interval: 3600s
+compression_dictionary_refresh_initial_delay: 10s
+compression_dictionary_cache_size: 10
+compression_dictionary_cache_expire: 24h
+
+# Automatic training
+compression_dictionary_training_auto_train_enabled: false
+compression_dictionary_training_sampling_rate: 0.01
+compression_dictionary_training_max_dictionary_size: 64KiB
+compression_dictionary_training_max_total_sample_size: 10MiB
+----
+
 == Other options
 
 * `crc_check_chance` (default: `1.0`): determines how likely Cassandra
@@ -186,6 +363,39 @@ correctness of data on disk, compressed tables allow the user to set
 probabilistically validate chunks on read to verify bits on disk are not
 corrupt.
+
+=== Dictionary Compression Operational Considerations
+
+When using `ZstdDictionaryCompressor`, additional operational factors
+apply:
+
+* *Dictionary Storage*: Compression dictionaries are stored in the
+`system_distributed.compression_dictionaries` table and replicated
+cluster-wide. Each table maintains current and historical dictionary
+versions.
+* *Dictionary Cache Memory*: Dictionaries are cached locally on each node
+according to `compression_dictionary_cache_size`. Memory overhead is
+typically minimal (default 64KB per dictionary × cache size).
+* *Dictionary Training Overhead*: Manual training via
+`nodetool traincompressiondictionary` samples SSTable chunk data and
+performs CPU-intensive dictionary training. Consider running training
+during off-peak hours (a JMX-based sketch follows this list).
+* *Automatic Training Impact*: When
+`compression_dictionary_training_auto_train_enabled` is true, write
+operations are sampled based on `compression_dictionary_training_sampling_rate`.
+This adds minimal overhead but should be monitored in write-intensive
+workloads.
+* *Dictionary Refresh*: The dictionary refresh process
+(`compression_dictionary_refresh_interval`) checks for new dictionaries
+cluster-wide. The default 1-hour interval balances freshness with
+overhead.
+* *SSTable Compatibility*: Each SSTable is compressed with a specific
+dictionary version. Historical dictionaries must be retained to read
+older SSTables until they are compacted with new dictionaries.
+* *Schema Changes*: Significant schema changes or data pattern shifts may
+require retraining dictionaries to maintain optimal compression ratios.
+Monitor the `SSTable Compression Ratio` via `nodetool tablestats` to
+detect degradation.
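+
+Where JMX automation is preferred over `nodetool`, the per-table training
+MBean added for dictionary compression (`CompressionDictionaryManagerMBean`)
+can be driven directly. The sketch below is illustrative only: the
+`ks.events` table, host, and port are placeholders, and
+`nodetool traincompressiondictionary` remains the documented entry point.
+
+[source,java]
+----
+import javax.management.MBeanServerConnection;
+import javax.management.ObjectName;
+import javax.management.openmbean.CompositeData;
+import javax.management.remote.JMXConnector;
+import javax.management.remote.JMXConnectorFactory;
+import javax.management.remote.JMXServiceURL;
+
+public class TrainDictionaryViaJmx
+{
+    public static void main(String[] args) throws Exception
+    {
+        // Default Cassandra JMX endpoint; adjust host/port for your deployment.
+        JMXServiceURL url = new JMXServiceURL("service:jmx:rmi:///jndi/rmi://127.0.0.1:7199/jmxrmi");
+        try (JMXConnector connector = JMXConnectorFactory.connect(url))
+        {
+            MBeanServerConnection mbs = connector.getMBeanServerConnection();
+            // Object name pattern defined by CompressionDictionaryManagerMBean in this patch.
+            ObjectName name = new ObjectName(
+                "org.apache.cassandra.db.compression:type=CompressionDictionaryManager,keyspace=ks,table=events");
+
+            // Triggers SSTable-based training (the memtable is flushed first if no SSTables exist).
+            mbs.invoke(name, "train", new Object[0], new String[0]);
+
+            // Read back the training state exposed by getTrainingState().
+            CompositeData state = (CompositeData) mbs.getAttribute(name, "TrainingState");
+            System.out.println("Training state: " + state);
+        }
+    }
+}
+----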
+ == Advanced Use Advanced users can provide their own compression class by implementing diff --git a/pylib/cqlshlib/cqlhandling.py b/pylib/cqlshlib/cqlhandling.py index cd19e39fda6b..90e552fc275d 100644 --- a/pylib/cqlshlib/cqlhandling.py +++ b/pylib/cqlshlib/cqlhandling.py @@ -44,6 +44,7 @@ class CqlParsingRuleSet(pylexotron.ParsingRuleSet): 'SnappyCompressor', 'LZ4Compressor', 'ZstdCompressor', + 'ZstdDictionaryCompressor' ) available_compaction_classes = ( diff --git a/src/java/org/apache/cassandra/config/Config.java b/src/java/org/apache/cassandra/config/Config.java index 9783e086ad13..f5f149e49135 100644 --- a/src/java/org/apache/cassandra/config/Config.java +++ b/src/java/org/apache/cassandra/config/Config.java @@ -514,6 +514,17 @@ public static class SSTableConfig public volatile DurationSpec.IntSecondsBound counter_cache_save_period = new DurationSpec.IntSecondsBound("7200s"); public volatile int counter_cache_keys_to_save = Integer.MAX_VALUE; + public volatile DurationSpec.IntSecondsBound compression_dictionary_refresh_interval = new DurationSpec.IntSecondsBound("3600s"); // 1 hour - TODO: re-assess whether daily (86400s) is more appropriate + public volatile DurationSpec.IntSecondsBound compression_dictionary_refresh_initial_delay = new DurationSpec.IntSecondsBound("10s"); // 10 seconds default + public volatile int compression_dictionary_cache_size = 10; // max dictionaries per table + public volatile DurationSpec.IntSecondsBound compression_dictionary_cache_expire = new DurationSpec.IntSecondsBound("24h"); + + // Dictionary training settings + public volatile DataStorageSpec.IntKibibytesBound compression_dictionary_training_max_dictionary_size = new DataStorageSpec.IntKibibytesBound("64KiB"); + public volatile DataStorageSpec.IntKibibytesBound compression_dictionary_training_max_total_sample_size = new DataStorageSpec.IntKibibytesBound("10MiB"); + public volatile boolean compression_dictionary_training_auto_train_enabled = false; + public volatile float compression_dictionary_training_sampling_rate = 0.01f; // samples 1% + public DataStorageSpec.LongMebibytesBound paxos_cache_size = null; public DataStorageSpec.LongMebibytesBound consensus_migration_cache_size = null; diff --git a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java index ec76193e1046..fcef1af0eb7d 100644 --- a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java +++ b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java @@ -4361,6 +4361,47 @@ public static void setCounterCacheKeysToSave(int counterCacheKeysToSave) conf.counter_cache_keys_to_save = counterCacheKeysToSave; } + public static int getCompressionDictionaryRefreshIntervalSeconds() + { + return conf.compression_dictionary_refresh_interval.toSeconds(); + } + + public static int getCompressionDictionaryRefreshInitialDelaySeconds() + { + return conf.compression_dictionary_refresh_initial_delay.toSeconds(); + } + + public static int getCompressionDictionaryCacheSize() + { + return conf.compression_dictionary_cache_size; + } + + public static int getCompressionDictionaryCacheExpireSeconds() + { + return conf.compression_dictionary_cache_expire.toSeconds(); + } + + public static int getCompressionDictionaryTrainingMaxDictionarySize() + { + return conf.compression_dictionary_training_max_dictionary_size.toBytes(); + } + + public static int getCompressionDictionaryTrainingMaxTotalSampleSize() + { + return 
conf.compression_dictionary_training_max_total_sample_size.toBytes(); + } + + public static boolean getCompressionDictionaryTrainingAutoTrainEnabled() + { + return conf.compression_dictionary_training_auto_train_enabled; + } + + + public static float getCompressionDictionaryTrainingSamplingRate() + { + return conf.compression_dictionary_training_sampling_rate; + } + public static int getStreamingKeepAlivePeriod() { return conf.streaming_keep_alive_period.toSeconds(); diff --git a/src/java/org/apache/cassandra/db/ColumnFamilyStore.java b/src/java/org/apache/cassandra/db/ColumnFamilyStore.java index 0d6cd2eb9be2..839a21d55d67 100644 --- a/src/java/org/apache/cassandra/db/ColumnFamilyStore.java +++ b/src/java/org/apache/cassandra/db/ColumnFamilyStore.java @@ -83,6 +83,7 @@ import org.apache.cassandra.db.compaction.CompactionManager; import org.apache.cassandra.db.compaction.CompactionStrategyManager; import org.apache.cassandra.db.compaction.OperationType; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.db.filter.ClusteringIndexFilter; import org.apache.cassandra.db.filter.DataLimits; import org.apache.cassandra.db.lifecycle.ILifecycleTransaction; @@ -320,6 +321,7 @@ public enum FlushReason public final TopPartitionTracker topPartitions; private final SSTableImporter sstableImporter; + private final CompressionDictionaryManager compressionDictionaryManager; private volatile boolean compactionSpaceCheck = true; @@ -390,6 +392,7 @@ public void reload(TableMetadata tableMetadata) cfs.crcCheckChance = new DefaultValue<>(tableMetadata.params.crcCheckChance); compactionStrategyManager.maybeReloadParamsFromSchema(tableMetadata.params.compaction); + compressionDictionaryManager.maybeReloadFromSchema(tableMetadata.params.compression); indexManager.reload(tableMetadata); @@ -576,6 +579,7 @@ public ColumnFamilyStore(Keyspace keyspace, streamManager = new CassandraStreamManager(this); repairManager = new CassandraTableRepairManager(this); sstableImporter = new SSTableImporter(this); + compressionDictionaryManager = new CompressionDictionaryManager(this, registerBookeeping); if (DatabaseDescriptor.isClientOrToolInitialized() || SchemaConstants.isSystemKeyspace(getKeyspaceName())) topPartitions = null; @@ -733,6 +737,8 @@ public void invalidate(boolean expectMBean, boolean dropData) invalidateCaches(); if (topPartitions != null) topPartitions.close(); + + compressionDictionaryManager.close(); } /** @@ -3420,6 +3426,12 @@ public TableMetrics getMetrics() return metric; } + @Override + public CompressionDictionaryManager compressionDictionaryManager() + { + return compressionDictionaryManager; + } + public TableId getTableId() { return metadata().id; diff --git a/src/java/org/apache/cassandra/db/compaction/CompactionManager.java b/src/java/org/apache/cassandra/db/compaction/CompactionManager.java index 3ab00acfb954..4313eb82a1d4 100644 --- a/src/java/org/apache/cassandra/db/compaction/CompactionManager.java +++ b/src/java/org/apache/cassandra/db/compaction/CompactionManager.java @@ -1796,6 +1796,7 @@ public static SSTableWriter createWriter(ColumnFamilyStore cfs, .setSerializationHeader(sstable.header) .addDefaultComponents(cfs.indexManager.listIndexGroups()) .setSecondaryIndexGroups(cfs.indexManager.listIndexGroups()) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()) .build(txn, cfs); } @@ -1836,6 +1837,7 @@ public static SSTableWriter createWriterForAntiCompaction(ColumnFamilyStore cfs, 
.setSerializationHeader(SerializationHeader.make(cfs.metadata(), sstables)) .addDefaultComponents(cfs.indexManager.listIndexGroups()) .setSecondaryIndexGroups(cfs.indexManager.listIndexGroups()) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()) .build(txn, cfs); } diff --git a/src/java/org/apache/cassandra/db/compaction/Upgrader.java b/src/java/org/apache/cassandra/db/compaction/Upgrader.java index 9e4c4dd7502b..22913a84f612 100644 --- a/src/java/org/apache/cassandra/db/compaction/Upgrader.java +++ b/src/java/org/apache/cassandra/db/compaction/Upgrader.java @@ -85,6 +85,7 @@ private SSTableWriter createCompactionWriter(StatsMetadata metadata) .setSerializationHeader(SerializationHeader.make(cfs.metadata(), Sets.newHashSet(sstable))) .addDefaultComponents(cfs.indexManager.listIndexGroups()) .setSecondaryIndexGroups(cfs.indexManager.listIndexGroups()) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()) .build(transaction, cfs); } diff --git a/src/java/org/apache/cassandra/db/compaction/unified/ShardedMultiWriter.java b/src/java/org/apache/cassandra/db/compaction/unified/ShardedMultiWriter.java index ab5465df253c..79efadbfa37f 100644 --- a/src/java/org/apache/cassandra/db/compaction/unified/ShardedMultiWriter.java +++ b/src/java/org/apache/cassandra/db/compaction/unified/ShardedMultiWriter.java @@ -118,6 +118,7 @@ private SSTableWriter createWriter(Descriptor descriptor) .setSerializationHeader(header) .addDefaultComponents(indexGroups) .setSecondaryIndexGroups(indexGroups) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()) .build(txn, cfs); } diff --git a/src/java/org/apache/cassandra/db/compaction/writers/CompactionAwareWriter.java b/src/java/org/apache/cassandra/db/compaction/writers/CompactionAwareWriter.java index ea21f7be57e0..fd1966ad35d3 100644 --- a/src/java/org/apache/cassandra/db/compaction/writers/CompactionAwareWriter.java +++ b/src/java/org/apache/cassandra/db/compaction/writers/CompactionAwareWriter.java @@ -329,6 +329,7 @@ protected long getExpectedWriteSize() .setRepairedAt(minRepairedAt) .setPendingRepair(pendingRepair) .setSecondaryIndexGroups(cfs.indexManager.listIndexGroups()) - .addDefaultComponents(cfs.indexManager.listIndexGroups()); + .addDefaultComponents(cfs.indexManager.listIndexGroups()) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()); } } diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionary.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionary.java new file mode 100644 index 000000000000..780c40ed332b --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionary.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.EOFException; +import java.io.IOException; +import java.util.Objects; +import javax.annotation.Nullable; + +import com.google.common.base.Preconditions; +import com.google.common.hash.Hasher; +import com.google.common.hash.Hashing; + +import org.apache.cassandra.cql3.UntypedResultSet; +import org.apache.cassandra.io.compress.ICompressor; +import org.apache.cassandra.io.compress.ZstdDictionaryCompressor; + +public interface CompressionDictionary extends AutoCloseable +{ + /** + * Get the dictionary id + * + * @return dictionary id + */ + DictId dictId(); + + /** + * Get the raw bytes of the compression dictionary + * + * @return raw compression dictionary + */ + byte[] rawDictionary(); + + /** + * Get the kind of the compression algorithm + * + * @return compression algorithm kind + */ + default Kind kind() + { + return dictId().kind; + } + + /** + * Write compression dictionary to file + * + * @param out file output stream + * @throws IOException on any I/O exception when writing to the file + */ + default void serialize(DataOutput out) throws IOException + { + DictId dictId = dictId(); + int ordinal = dictId.kind.ordinal(); + out.writeByte(ordinal); + out.writeLong(dictId.id); + byte[] dict = rawDictionary(); + out.writeInt(dict.length); + out.write(dict); + int checksum = calculateChecksum((byte) ordinal, dictId.id, dict); + out.writeInt(checksum); + } + + /** + * A factory method to create concrete CompressionDictionary from the file content + * + * @param input file input stream + * @param manager compression dictionary manager that caches the dictionaries + * @return compression dictionary; otherwise, null if there is no dictionary + * @throws IOException on any I/O exception when reading from the file + */ + @Nullable + static CompressionDictionary deserialize(DataInput input, @Nullable CompressionDictionaryManager manager) throws IOException + { + int kindOrdinal; + try + { + kindOrdinal = input.readByte(); + } + catch (EOFException eof) + { + // no dictionary + return null; + } + + if (kindOrdinal < 0 || kindOrdinal >= Kind.values().length) + { + throw new IOException("Invalid compression dictionary kind: " + kindOrdinal); + } + Kind kind = Kind.values()[kindOrdinal]; + long id = input.readLong(); + DictId dictId = new DictId(kind, id); + + if (manager != null) + { + CompressionDictionary dictionary = manager.get(dictId); + if (dictionary != null) + { + return dictionary; + } + } + + int length = input.readInt(); + byte[] dict = new byte[length]; + input.readFully(dict); + int checksum = input.readInt(); + int calculatedChecksum = calculateChecksum((byte) kindOrdinal, id, dict); + if (checksum != calculatedChecksum) + throw new IOException("Compression dictionary checksum does not match. 
" + + "Expected: " + checksum + "; actual: " + calculatedChecksum); + + CompressionDictionary dictionary = kind.createDictionary(dictId, dict); + + // update the dictionary manager if it exists + if (manager != null) + { + manager.add(dictionary); + } + + return dictionary; + } + + static CompressionDictionary createFromRow(UntypedResultSet.Row row) + { + String kindStr = row.getString("kind"); + long dictId = row.getLong("dict_id"); + + try + { + Kind kind = CompressionDictionary.Kind.valueOf(kindStr); + return kind.createDictionary(new DictId(kind, dictId), row.getByteArray("dict")); + } + catch (IllegalArgumentException ex) + { + throw new IllegalStateException(kindStr + " compression dictionary is not created for dict id " + dictId); + } + } + + @SuppressWarnings("UnstableApiUsage") + static int calculateChecksum(byte kindOrdinal, long dictId, byte[] dict) + { + Hasher hasher = Hashing.crc32c().newHasher(); + hasher.putByte(kindOrdinal); + hasher.putLong(dictId); + hasher.putBytes(dict); + return hasher.hash().asInt(); + } + + // Defines the compression dictionary kind, as well as serves as a factory for creating components of dictionary compression + enum Kind + { + // Order matters: the enum ordinal is serialized + ZSTD + { + public CompressionDictionary createDictionary(DictId dictId, byte[] dict) + { + return new ZstdCompressionDictionary(dictId, dict); + } + + @Override + public ICompressor createCompressor(CompressionDictionary dictionary) + { + Preconditions.checkArgument(dictionary instanceof ZstdCompressionDictionary, + "Expected dictionary to be ZstdCompressionDictionary; actual: %s", + dictionary.getClass().getSimpleName()); + return ZstdDictionaryCompressor.create((ZstdCompressionDictionary) dictionary); + } + + @Override + public ICompressionDictionaryTrainer createTrainer(String keyspaceName, + String tableName, + CompressionDictionaryTrainingConfig config, + ICompressor compressor) + { + Preconditions.checkArgument(compressor instanceof ZstdDictionaryCompressor, + "Expected compressor to be ZstdDictionaryCompressor; actual: %s", + compressor.getClass().getSimpleName()); + return new ZstdDictionaryTrainer(keyspaceName, tableName, config, ((ZstdDictionaryCompressor) compressor).compressionLevel()); + } + }; + + /** + * Creates a compression dictionary instance for this kind + * + * @param dictId the dictionary identifier + * @param dict the raw dictionary bytes + * @return a compression dictionary instance + */ + public abstract CompressionDictionary createDictionary(CompressionDictionary.DictId dictId, byte[] dict); + + /** + * Creates a dictionary compressor for this kind + * + * @param dictionary the compression dictionary to use for compression + * @return a dictionary compressor instance + */ + public abstract ICompressor createCompressor(CompressionDictionary dictionary); + + /** + * Creates a dictionary trainer for this kind + * + * @param keyspaceName the keyspace name + * @param tableName the table name + * @param config the training configuration + * @param compressor the compressor to use for training + * @return a dictionary trainer instance + */ + public abstract ICompressionDictionaryTrainer createTrainer(String keyspaceName, + String tableName, + CompressionDictionaryTrainingConfig config, + ICompressor compressor); + } + + final class DictId + { + public final Kind kind; + public final long id; // A value of negative or 0 means no dictionary + + public DictId(Kind kind, long id) + { + this.kind = kind; + this.id = id; + } + + @Override + public boolean 
equals(Object o) + { + if (!(o instanceof DictId)) return false; + DictId dictId = (DictId) o; + return id == dictId.id && kind == dictId.kind; + } + + @Override + public int hashCode() + { + return Objects.hash(kind, id); + } + + @Override + public String toString() + { + return "DictId{" + + "kind=" + kind + + ", id=" + id + + '}'; + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryCache.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryCache.java new file mode 100644 index 000000000000..7da84b437940 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryCache.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalCause; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; + +/** + * Manages caching and current dictionary state for compression dictionaries. + *

+ * This class handles: + * - Local caching of compression dictionaries with automatic cleanup + * - Managing the current active dictionary for write operations + * - Thread-safe access to cached dictionaries + */ +public class CompressionDictionaryCache implements ICompressionDictionaryCache +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryCache.class); + + private final Cache cache; + private final AtomicReference currentDictId = new AtomicReference<>(); + + public CompressionDictionaryCache() + { + this(DatabaseDescriptor.getCompressionDictionaryCacheSize(), DatabaseDescriptor.getCompressionDictionaryCacheExpireSeconds()); + } + + @VisibleForTesting + CompressionDictionaryCache(int maximumSize, int expireAfterSeconds) + { + this.cache = Caffeine.newBuilder() + .maximumSize(maximumSize) + .expireAfterAccess(Duration.ofSeconds(expireAfterSeconds)) + .removalListener((DictId dictId, + CompressionDictionary dictionary, + RemovalCause cause) -> { + // Close dictionary when evicted from cache to free native resources + // SelfRefCounted ensures dictionary won't be actually closed if still referenced by compressors + if (dictionary != null) + { + try + { + dictionary.close(); + } + catch (Exception e) + { + logger.warn("Failed to close compression dictionary {}", dictId, e); + } + } + }) + .build(); + } + + @Nullable + @Override + public CompressionDictionary getCurrent() + { + DictId dictId = currentDictId.get(); + return dictId == null ? null : get(dictId); + } + + @Nullable + @Override + public CompressionDictionary get(DictId dictId) + { + return cache.getIfPresent(dictId); + } + + @Override + public void add(@Nullable CompressionDictionary compressionDictionary) + { + if (compressionDictionary == null) + return; + + // Only update cache if not already in the cache + DictId newDictId = compressionDictionary.dictId(); + cache.get(newDictId, id -> compressionDictionary); + + // Update current dictionary if we don't have one or the new one has a higher ID (newer) + DictId currentId = currentDictId.get(); + while ((currentId == null || newDictId.id > currentId.id) + && !currentDictId.compareAndSet(currentId, newDictId)) + { + currentId = currentDictId.get(); + } + } + + @Override + public synchronized void close() + { + currentDictId.set(null); + // Invalidate cache will trigger removalListener to close all cached dictionaries, including the currentDictionary + cache.invalidateAll(); + // Force synchronous cleanup to ensure removal listener executes immediately + cache.cleanUp(); + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryEventHandler.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryEventHandler.java new file mode 100644 index 000000000000..f0193de063ea --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryEventHandler.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; +import org.apache.cassandra.schema.SystemDistributedKeyspace; +import org.apache.cassandra.tcm.ClusterMetadata; +import org.apache.cassandra.utils.FBUtilities; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; + +/** + * Handles compression dictionary events including training completion and cluster notifications. + *

+ * This class handles: + * - Broadcasting dictionary updates to cluster nodes + * - Retrieving new dictionaries when notified by other nodes + * - Managing dictionary cache updates + */ +public class CompressionDictionaryEventHandler implements ICompressionDictionaryEventHandler +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryEventHandler.class); + + private final ColumnFamilyStore cfs; + private final String keyspaceName; + private final String tableName; + private final ICompressionDictionaryCache cache; + + public CompressionDictionaryEventHandler(ColumnFamilyStore cfs, ICompressionDictionaryCache cache) + { + this.cfs = cfs; + this.keyspaceName = cfs.keyspace.getName(); + this.tableName = cfs.getTableName(); + this.cache = cache; + } + + @Override + public void onNewDictionaryTrained(CompressionDictionary.DictId dictionaryId) + { + logger.info("Notifying cluster about dictionary update for {}.{} with {}", + keyspaceName, tableName, dictionaryId); + + CompressionDictionaryUpdateMessage message = new CompressionDictionaryUpdateMessage(cfs.metadata().id, dictionaryId); + Collection allNodes = ClusterMetadata.current().directory.allJoinedEndpoints(); + // Broadcast notification using the fire-and-forget fashion + for (InetAddressAndPort node : allNodes) + { + if (node.equals(FBUtilities.getBroadcastAddressAndPort())) // skip ourself + continue; + sendNotification(node, message); + } + } + + @Override + public void onNewDictionaryAvailable(CompressionDictionary.DictId dictionaryId) + { + // Best effort to retrieve the dictionary; otherwise, the periodic task should retrieve the dictionary later + ScheduledExecutors.nonPeriodicTasks.submit(() -> { + try + { + if (!cfs.metadata().params.compression.isDictionaryCompressionEnabled()) + { + return; + } + + CompressionDictionary dictionary = SystemDistributedKeyspace.retrieveCompressionDictionary(keyspaceName, tableName, dictionaryId); + cache.add(dictionary); + } + catch (Exception e) + { + logger.warn("Failed to retrieve compression dictionary for {}.{}. {}", + keyspaceName, tableName, dictionaryId, e); + } + }); + } + + // Best effort to notify the peer regarding the new dictionary being available to pull. + // If the request fails, each peer has periodic task scheduled to pull. + private void sendNotification(InetAddressAndPort target, CompressionDictionaryUpdateMessage message) + { + logger.debug("Sending dictionary update notification for {} to {}", message.dictionaryId, target); + + Message msg = Message.out(Verb.DICTIONARY_UPDATE_REQ, message); + MessagingService.instance() + .sendWithResponse(target, msg) + .addListener(future -> { + if (future.isSuccess()) + { + logger.debug("Successfully sent dictionary update notification to {}", target); + } + else + { + logger.warn("Failed to send dictionary update notification to {}", + target, future.cause()); + } + }); + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManager.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManager.java new file mode 100644 index 000000000000..a84b4b01c159 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManager.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.Set; +import javax.annotation.Nullable; +import javax.management.openmbean.CompositeData; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.SystemDistributedKeyspace; +import org.apache.cassandra.utils.MBeanWrapper; +import org.apache.cassandra.utils.MBeanWrapper.OnException; + +public class CompressionDictionaryManager implements CompressionDictionaryManagerMBean, + ICompressionDictionaryCache, + ICompressionDictionaryEventHandler, + AutoCloseable +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryManager.class); + + private final String keyspaceName; + private final String tableName; + private final ColumnFamilyStore columnFamilyStore; + private volatile boolean mbeanRegistered; + private volatile boolean isEnabled; + + // Components + private final ICompressionDictionaryEventHandler eventHandler; + private final ICompressionDictionaryCache cache; + private final ICompressionDictionaryScheduler scheduler; + private ICompressionDictionaryTrainer trainer = null; + + public CompressionDictionaryManager(ColumnFamilyStore columnFamilyStore, boolean registerBookkeeping) + { + this.keyspaceName = columnFamilyStore.keyspace.getName(); + this.tableName = columnFamilyStore.getTableName(); + this.columnFamilyStore = columnFamilyStore; + + this.isEnabled = columnFamilyStore.metadata().params.compression.isDictionaryCompressionEnabled(); + this.cache = new CompressionDictionaryCache(); + this.eventHandler = new CompressionDictionaryEventHandler(columnFamilyStore, cache); + this.scheduler = new CompressionDictionaryScheduler(keyspaceName, tableName, cache, isEnabled); + if (isEnabled) + { + // Initialize components + this.trainer = ICompressionDictionaryTrainer.create(keyspaceName, tableName, + columnFamilyStore.metadata().params.compression, + createTrainingConfig()); + trainer.setDictionaryTrainedListener(this::handleNewDictionary); + + scheduler.scheduleRefreshTask(); + + trainer.start(false); + } + + if (registerBookkeeping && isEnabled) + { + registerMbean(); + } + } + + static String mbeanName(String keyspaceName, String tableName) + { + return MBEAN_NAME + ",keyspace=" + keyspaceName + ",table=" + tableName; + } + + public boolean isEnabled() + { + return isEnabled; + } + + /** + * Reloads dictionary management configuration when compression parameters change. + * This method enables or disables dictionary compression based on the new parameters, + * and properly manages the lifecycle of training and refresh tasks. 
+ * + * @param newParams the new compression parameters to apply + */ + public synchronized void maybeReloadFromSchema(CompressionParams newParams) + { + this.isEnabled = newParams.isDictionaryCompressionEnabled(); + scheduler.setEnabled(isEnabled); + if (isEnabled) + { + registerMbean(); + // Check if we need a new trainer due to compression parameter changes + boolean needsNewTrainer = shouldCreateNewTrainer(newParams); + + if (needsNewTrainer) + { + // Close existing trainer and create a new one + if (trainer != null) + { + try + { + trainer.close(); + } + catch (Exception e) + { + logger.warn("Failed to close existing trainer for {}.{}", keyspaceName, tableName, e); + } + } + + trainer = ICompressionDictionaryTrainer.create(keyspaceName, tableName, newParams, createTrainingConfig()); + trainer.setDictionaryTrainedListener(this::handleNewDictionary); + } + + scheduler.scheduleRefreshTask(); + + // Start trainer if it exists + if (trainer != null) + { + trainer.start(false); + } + return; + } + + // Clean up when dictionary compression is disabled + try + { + close(); + } + catch (Exception e) + { + logger.warn("Failed to close CompressionDictionaryManager on disabling " + + "dictionary-based compression for table {}.{}", keyspaceName, tableName); + } + } + + /** + * Adds a sample to the dictionary trainer for learning compression patterns. + * Samples are randomly selected to avoid bias and improve dictionary quality. + * + * @param sample the sample data to potentially add for training + */ + public void addSample(ByteBuffer sample) + { + ICompressionDictionaryTrainer dictionaryTrainer = trainer; + if (dictionaryTrainer != null && dictionaryTrainer.shouldSample()) + { + dictionaryTrainer.addSample(sample); + } + } + + @Nullable + @Override + public CompressionDictionary getCurrent() + { + return cache.getCurrent(); + } + + @Override + public CompressionDictionary get(CompressionDictionary.DictId dictId) + { + return cache.get(dictId); + } + + @Override + public void add(@Nullable CompressionDictionary compressionDictionary) + { + cache.add(compressionDictionary); + } + + @Override + public void onNewDictionaryTrained(CompressionDictionary.DictId dictionaryId) + { + eventHandler.onNewDictionaryTrained(dictionaryId); + } + + @Override + public void onNewDictionaryAvailable(CompressionDictionary.DictId dictionaryId) + { + eventHandler.onNewDictionaryAvailable(dictionaryId); + } + + @Override + public synchronized void train() + { + // Validate table supports dictionary compression + if (!isEnabled) + { + throw new UnsupportedOperationException("Table " + keyspaceName + '.' + tableName + " does not support dictionary compression"); + } + + if (trainer == null) + { + throw new IllegalStateException("Dictionary trainer is not available for table " + keyspaceName + '.' + tableName); + } + + // SSTable-based training: sample from existing SSTables + Set sstables = columnFamilyStore.getLiveSSTables(); + if (sstables.isEmpty()) + { + logger.info("No SSTables available for training in table {}.{}, flushing memtable first", + keyspaceName, tableName); + columnFamilyStore.forceBlockingFlush(ColumnFamilyStore.FlushReason.USER_FORCED); + sstables = columnFamilyStore.getLiveSSTables(); + + if (sstables.isEmpty()) + { + throw new IllegalStateException("No SSTables available for training in table " + keyspaceName + '.' 
+ tableName + " after flush"); + } + } + + logger.info("Starting SSTable-based training for {}.{} with {} SSTables", + keyspaceName, tableName, sstables.size()); + + trainer.start(true); + scheduler.scheduleSSTableBasedTraining(trainer, sstables, createTrainingConfig()); + } + + @Override + public CompositeData getTrainingState() + { + ICompressionDictionaryTrainer dictionaryTrainer = trainer; + if (dictionaryTrainer == null) + { + return TrainingState.notStarted().toCompositeData(); + } + return dictionaryTrainer.getTrainingState().toCompositeData(); + } + + /** + * Close all the resources. The method can be called multiple times. + */ + @Override + public synchronized void close() + { + unregisterMbean(); + if (trainer != null) + { + closeQuitely(trainer, "CompressionDictionaryTrainer"); + trainer = null; + } + closeQuitely(cache, "CompressionDictionaryCache"); + closeQuitely(scheduler, "CompressionDictionaryScheduler"); + } + + private void handleNewDictionary(CompressionDictionary dictionary) + { + // sequence meatters; persist the new dictionary before broadcasting to others. + storeDictionary(dictionary); + onNewDictionaryTrained(dictionary.dictId()); + } + + private CompressionDictionaryTrainingConfig createTrainingConfig() + { + CompressionParams compressionParams = columnFamilyStore.metadata().params.compression; + return CompressionDictionaryTrainingConfig + .builder() + .maxDictionarySize(DatabaseDescriptor.getCompressionDictionaryTrainingMaxDictionarySize()) + .maxTotalSampleSize(DatabaseDescriptor.getCompressionDictionaryTrainingMaxTotalSampleSize()) + .samplingRate(DatabaseDescriptor.getCompressionDictionaryTrainingSamplingRate()) + .chunkSize(compressionParams.chunkLength()) + .build(); + } + + private void storeDictionary(CompressionDictionary dictionary) + { + if (!isEnabled) + { + return; + } + + SystemDistributedKeyspace.storeCompressionDictionary(keyspaceName, tableName, dictionary); + cache.add(dictionary); + } + + /** + * Determines if a new trainer should be created based on compression parameter changes. + * A new trainer is needed when no existing trainer exists or when the existing trainer + * is not compatible with the new compression parameters. + * + * The method is (and should be) only invoked inside {@link #maybeReloadFromSchema(CompressionParams)}, + * which is guarded by synchronized. 
+ * + * @param newParams the new compression parameters + * @return true if a new trainer should be created + */ + private boolean shouldCreateNewTrainer(CompressionParams newParams) + { + if (trainer == null) + { + return true; + } + + return !trainer.isCompatibleWith(newParams); + } + + private void registerMbean() + { + if (!mbeanRegistered) + { + MBeanWrapper.instance.registerMBean(this, mbeanName(keyspaceName, tableName)); + mbeanRegistered = true; + } + } + + private void unregisterMbean() + { + if (mbeanRegistered) + { + MBeanWrapper.instance.unregisterMBean(mbeanName(keyspaceName, tableName), OnException.IGNORE); + mbeanRegistered = false; + } + } + + private void closeQuitely(AutoCloseable closeable, String objectName) + { + try + { + closeable.close(); + } + catch (Exception exception) + { + logger.warn("Failed closing {}", objectName, exception); + } + } + + @VisibleForTesting + boolean isReady() + { + return trainer != null && trainer.isReady(); + } + + @VisibleForTesting + ICompressionDictionaryTrainer trainer() + { + return trainer; + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBean.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBean.java new file mode 100644 index 000000000000..93f248619282 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBean.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import javax.management.openmbean.CompositeData; + +public interface CompressionDictionaryManagerMBean +{ + String MBEAN_NAME = "org.apache.cassandra.db.compression:type=CompressionDictionaryManager"; + + /** + * Starts training from existing SSTables for this table. + * Samples chunks from all live SSTables and trains a compression dictionary. + * If no SSTables are available, automatically flushes the memtable first. + * This operation runs synchronously and blocks until training completes. + * + * @throws UnsupportedOperationException if table doesn't support dictionary compression + * @throws IllegalStateException if no SSTables available after flush + */ + void train(); + + /** + * Gets the current training state for this table. + * Returns a snapshot of {@link TrainingState} as JMX CompositeData. 
+ * + * @return CompositeData representing {@link TrainingState} + */ + CompositeData getTrainingState(); +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryScheduler.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryScheduler.java new file mode 100644 index 000000000000..f2910b58fdc9 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryScheduler.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.schema.SystemDistributedKeyspace; +import org.apache.cassandra.utils.concurrent.Ref; + +/** + * Manages scheduled tasks for compression dictionary operations. + *

+ * This class handles: + * - Periodic refresh of dictionaries from system tables + * - Manual training task scheduling and monitoring + * - Cleanup of scheduled tasks + */ +public class CompressionDictionaryScheduler implements ICompressionDictionaryScheduler +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryScheduler.class); + + private final String keyspaceName; + private final String tableName; + private final ICompressionDictionaryCache cache; + + private volatile ScheduledFuture scheduledRefreshTask; + private volatile ScheduledFuture scheduledManualTrainingTask; + private volatile boolean isEnabled; + + public CompressionDictionaryScheduler(String keyspaceName, + String tableName, + ICompressionDictionaryCache cache, + boolean isEnabled) + { + this.keyspaceName = keyspaceName; + this.tableName = tableName; + this.cache = cache; + this.isEnabled = isEnabled; + } + + /** + * Schedules the periodic dictionary refresh task if not already scheduled. + */ + public void scheduleRefreshTask() + { + if (scheduledRefreshTask != null) + return; + + this.scheduledRefreshTask = ScheduledExecutors.scheduledTasks.scheduleWithFixedDelay( + this::refreshDictionaryFromSystemTable, + DatabaseDescriptor.getCompressionDictionaryRefreshInitialDelaySeconds(), + DatabaseDescriptor.getCompressionDictionaryRefreshIntervalSeconds(), + TimeUnit.SECONDS + ); + } + + @Override + public void scheduleSSTableBasedTraining(ICompressionDictionaryTrainer trainer, + Set sstables, + CompressionDictionaryTrainingConfig config) + { + if (scheduledManualTrainingTask != null) + { + throw new IllegalStateException("Training already in progress for table " + keyspaceName + '.' + tableName); + } + + logger.info("Starting SSTable-based dictionary training for {}.{} from {} SSTables", + keyspaceName, tableName, sstables.size()); + + // Run the SSTableSamplingTask asynchronously + // Use a dummy scheduled task to track that training is in progress + SSTableSamplingTask task = new SSTableSamplingTask(sstables, trainer, config); + ScheduledExecutors.nonPeriodicTasks.submit(task); + + // Set a placeholder task so status checks know training is in progress + scheduledManualTrainingTask = ScheduledExecutors.scheduledTasks.schedule(() -> {}, 1, TimeUnit.HOURS); + } + + /** + * Cancels the in-progress manual training task. + */ + private void cancelManualTraining() + { + ScheduledFuture future = scheduledManualTrainingTask; + if (future != null) + { + future.cancel(false); + } + scheduledManualTrainingTask = null; + } + + /** + * Sets the enabled state of the scheduler. When disabled, refresh tasks will not execute. + * + * @param enabled whether the scheduler should be enabled + */ + @Override + public void setEnabled(boolean enabled) + { + this.isEnabled = enabled; + } + + /** + * Refreshes dictionary from system table and updates the cache. + * This method is called periodically by the scheduled refresh task. 
+ */ + private void refreshDictionaryFromSystemTable() + { + try + { + if (!isEnabled) + { + return; + } + + CompressionDictionary dictionary = SystemDistributedKeyspace.retrieveLatestCompressionDictionary(keyspaceName, tableName); + cache.add(dictionary); + } + catch (Exception e) + { + logger.warn("Failed to refresh compression dictionary for {}.{}", + keyspaceName, tableName, e); + } + } + + @Override + public void close() + { + if (scheduledRefreshTask != null) + { + scheduledRefreshTask.cancel(false); + scheduledRefreshTask = null; + } + + if (scheduledManualTrainingTask != null) + { + scheduledManualTrainingTask.cancel(false); + scheduledManualTrainingTask = null; + } + } + + /** + * Task that samples chunks from existing SSTables and triggers training. + * Acquires references to SSTables to prevent them from being deleted during sampling. + */ + private class SSTableSamplingTask implements Runnable + { + private final Set sstables; + private final ICompressionDictionaryTrainer trainer; + private final CompressionDictionaryTrainingConfig config; + private final List> sstableRefs; + + private SSTableSamplingTask(Set sstables, + ICompressionDictionaryTrainer trainer, + CompressionDictionaryTrainingConfig config) + { + this.trainer = trainer; + this.config = config; + + // Acquire references to all SSTables to prevent deletion during sampling + this.sstableRefs = new ArrayList<>(); + Set referencedSSTables = new HashSet<>(); + + for (SSTableReader sstable : sstables) + { + Ref ref = sstable.tryRef(); + if (ref != null) + { + sstableRefs.add(ref); + referencedSSTables.add(sstable); + } + else + { + logger.debug("Couldn't acquire reference to SSTable {}. It may have been removed.", + sstable.descriptor); + } + } + + this.sstables = referencedSSTables; + } + + @Override + public void run() + { + try + { + if (sstables.isEmpty()) + { + logger.warn("No SSTables available for sampling in {}.{}", keyspaceName, tableName); + cancelManualTraining(); + return; + } + + logger.info("Sampling chunks from {} SSTables for {}.{}", + sstables.size(), keyspaceName, tableName); + + // Sample chunks from SSTables and add to trainer + SSTableChunkSampler.sampleFromSSTables(sstables, trainer, config); + + logger.info("Completed sampling for {}.{}, now training dictionary", + keyspaceName, tableName); + + // force=true for manual training + trainer.trainDictionaryAsync(true) + .addCallback((dictionary, throwable) -> { + cancelManualTraining(); + if (throwable != null) + { + logger.error("SSTable-based dictionary training failed for {}.{}: {}", + keyspaceName, tableName, throwable.getMessage()); + } + else + { + logger.info("SSTable-based dictionary training completed for {}.{}", + keyspaceName, tableName); + } + }); + } + catch (Exception e) + { + logger.error("Failed to sample from SSTables for {}.{}", keyspaceName, tableName, e); + cancelManualTraining(); + } + finally + { + // Release all SSTable references + for (Ref ref : sstableRefs) + { + ref.release(); + } + } + } + } + + @VisibleForTesting + ScheduledFuture scheduledManualTrainingTask() + { + return scheduledManualTrainingTask; + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfig.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfig.java new file mode 100644 index 000000000000..f580b0acfa1b --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfig.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import com.google.common.base.Preconditions; + +/** + * Configuration for dictionary training parameters. + */ +public class CompressionDictionaryTrainingConfig +{ + public final int maxDictionarySize; + public final int maxTotalSampleSize; + public final int acceptableTotalSampleSize; + public final int samplingRate; + public final int chunkSize; + + private CompressionDictionaryTrainingConfig(Builder builder) + { + this.maxDictionarySize = builder.maxDictionarySize; + this.maxTotalSampleSize = builder.maxTotalSampleSize; + this.acceptableTotalSampleSize = builder.maxTotalSampleSize / 10 * 8; + this.samplingRate = builder.samplingRate; + this.chunkSize = builder.chunkSize; + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private int maxDictionarySize = 65536; // 64KB default + private int maxTotalSampleSize = 10 * 1024 * 1024; // 10MB total + private int samplingRate = 100; // Sampling 1% + private int chunkSize = 64 * 1024; // 64KB default + + public Builder maxDictionarySize(int size) + { + this.maxDictionarySize = size; + return this; + } + + public Builder maxTotalSampleSize(int size) + { + this.maxTotalSampleSize = size; + return this; + } + + public Builder samplingRate(float samplingRate) + { + this.samplingRate = Math.round(1 / samplingRate); + return this; + } + + public Builder chunkSize(int chunkSize) + { + this.chunkSize = chunkSize; + return this; + } + + public CompressionDictionaryTrainingConfig build() + { + Preconditions.checkArgument(maxDictionarySize > 0, "maxDictionarySize must be positive"); + Preconditions.checkArgument(maxTotalSampleSize > 0, "maxTotalSampleSize must be positive"); + Preconditions.checkArgument(samplingRate > 0, "samplingRate must be positive"); + Preconditions.checkArgument(chunkSize > 0, "chunkSize must be positive"); + return new CompressionDictionaryTrainingConfig(this); + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateMessage.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateMessage.java new file mode 100644 index 000000000000..a52afdd9aea0 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateMessage.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.io.IOException; + +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.schema.TableId; + +public class CompressionDictionaryUpdateMessage +{ + public static final IVersionedSerializer serializer = new DictionaryUpdateMessageSerializer(); + + public final TableId tableId; + public final DictId dictionaryId; + + public CompressionDictionaryUpdateMessage(TableId tableId, DictId dictionaryId) + { + this.tableId = tableId; + this.dictionaryId = dictionaryId; + } + + public static class DictionaryUpdateMessageSerializer implements IVersionedSerializer + { + @Override + public void serialize(CompressionDictionaryUpdateMessage message, DataOutputPlus out, int version) throws IOException + { + TableId.serializer.serialize(message.tableId, out, version); + out.writeByte(message.dictionaryId.kind.ordinal()); + out.writeLong(message.dictionaryId.id); + } + + @Override + public CompressionDictionaryUpdateMessage deserialize(DataInputPlus in, int version) throws IOException + { + TableId tableId = TableId.serializer.deserialize(in, version); + int kindOrdinal = in.readByte(); + long dictionaryId = in.readLong(); + DictId dictId = new DictId(CompressionDictionary.Kind.values()[kindOrdinal], dictionaryId); + return new CompressionDictionaryUpdateMessage(tableId, dictId); + } + + @Override + public long serializedSize(CompressionDictionaryUpdateMessage message, int version) + { + return TableId.serializer.serializedSize(message.tableId, version) + + 1 + // byte for kind ordinal + 8; // long for dictionaryId + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateVerbHandler.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateVerbHandler.java new file mode 100644 index 000000000000..f595b173024d --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateVerbHandler.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.net.IVerbHandler; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.schema.Schema; + +public class CompressionDictionaryUpdateVerbHandler implements IVerbHandler +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryUpdateVerbHandler.class); + public static final CompressionDictionaryUpdateVerbHandler instance = new CompressionDictionaryUpdateVerbHandler(); + + private CompressionDictionaryUpdateVerbHandler() {} + + @Override + public void doVerb(Message message) + { + CompressionDictionaryUpdateMessage payload = message.payload; + + try + { + ColumnFamilyStore cfs = Schema.instance.getColumnFamilyStoreInstance(payload.tableId); + if (cfs == null) + { + logger.warn("Received dictionary update for unknown table with tableId {}", payload.tableId); + return; + } + + logger.debug("Received dictionary update notification for {}.{} with dictionaryId {}", + cfs.keyspace, cfs.name, payload.dictionaryId); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + manager.onNewDictionaryAvailable(payload.dictionaryId); + } + catch (Exception e) + { + logger.error("Failed to process dictionary update notification for tableId {}", + payload.tableId, e); + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryCache.java b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryCache.java new file mode 100644 index 000000000000..c2d12caeb43b --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryCache.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import javax.annotation.Nullable; + +/** + * Interface for managing compression dictionary caching and current dictionary state. + *

+ * Implementations handle: + * - Local caching of compression dictionaries with automatic cleanup + * - Managing the current active dictionary for write operations + * - Thread-safe access to cached dictionaries + */ +public interface ICompressionDictionaryCache extends AutoCloseable +{ + /** + * Gets the current active compression dictionary. + * + * @return the current compression dictionary, or null if no dictionary is available + */ + @Nullable + CompressionDictionary getCurrent(); + + /** + * Retrieves a specific compression dictionary by its identifier. + * + * @param dictId the dictionary identifier to look up + * @return the compression dictionary with the given identifier, or null if not found in cache + */ + @Nullable + CompressionDictionary get(CompressionDictionary.DictId dictId); + + /** + * Stores a compression dictionary in the local cache and updates the current dictionary if the new one is newer. + * + * @param compressionDictionary the compression dictionary to cache, may be null + */ + void add(@Nullable CompressionDictionary compressionDictionary); +} diff --git a/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryEventHandler.java b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryEventHandler.java new file mode 100644 index 000000000000..4ed2f8f2ad07 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryEventHandler.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +public interface ICompressionDictionaryEventHandler +{ + /** + * Invoked when a new dictionary is trained + * @param dictionaryId dictionary id + */ + void onNewDictionaryTrained(CompressionDictionary.DictId dictionaryId); + + /** + * Invoked when {@link CompressionDictionaryUpdateMessage} is received indicating + * a dictionary is trained and local node should retrieve the specified dictionary + * @param dictionaryId dictionary id + */ + void onNewDictionaryAvailable(CompressionDictionary.DictId dictionaryId); +} diff --git a/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryScheduler.java b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryScheduler.java new file mode 100644 index 000000000000..96b1ef06d4a6 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryScheduler.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.Set; + +import org.apache.cassandra.io.sstable.format.SSTableReader; + +/** + * Interface for managing scheduled tasks for compression dictionary operations. + *

+ * Implementations handle: + * - Periodic refresh of dictionaries from system tables + * - Manual training task scheduling and monitoring + * - Cleanup of scheduled tasks + */ +public interface ICompressionDictionaryScheduler extends AutoCloseable +{ + /** + * Schedules the periodic dictionary refresh task if not already scheduled. + */ + void scheduleRefreshTask(); + + /** + * Schedules SSTable-based training that samples from existing SSTables. + * + * @param trainer the trainer to use + * @param sstables the set of SSTables to sample from + * @param config the training configuration + * @throws IllegalStateException if training is already in progress + */ + void scheduleSSTableBasedTraining(ICompressionDictionaryTrainer trainer, + Set sstables, + CompressionDictionaryTrainingConfig config); + + /** + * Sets the enabled state of the scheduler. When disabled, refresh tasks will not execute. + * + * @param enabled whether the scheduler should be enabled + */ + void setEnabled(boolean enabled); +} diff --git a/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryTrainer.java b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryTrainer.java new file mode 100644 index 000000000000..64810711ff73 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryTrainer.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.function.Consumer; + +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.utils.concurrent.Future; +import org.apache.cassandra.io.compress.ICompressor; +import org.apache.cassandra.io.compress.IDictionaryCompressor; +import org.apache.cassandra.schema.CompressionParams; + +/** + * Interface for training compression dictionaries from sample data. + *

+ * Implementations handle: + * - Sample collection and management + * - Dictionary training lifecycle + * - Asynchronous training execution + * - Training status tracking + */ +public interface ICompressionDictionaryTrainer extends AutoCloseable +{ + /** + * Starts the trainer for collecting samples. + * + * @param manualTraining true if this is manual training, false for automatic + * @return true if the trainer is started; otherwise false. The trainer is started + * in any of those conditions: 1. trainer closed; 2. not requested for + * either manual or auto training; 3. failed to start + */ + boolean start(boolean manualTraining); + + /** + * @return true if the trainer is ready to take a new sample; otherwise, false + */ + boolean shouldSample(); + + /** + * Adds a sample to the training dataset. + * + * @param sample the sample data to add for training + */ + void addSample(ByteBuffer sample); + + /** + * Trains and produces a compression dictionary from collected samples synchronously. + * + * @param force force the dictionary training even if there are not enough samples; + * otherwise, dictionary training won't start if the trainer is not ready + * @return the trained compression dictionary + */ + CompressionDictionary trainDictionary(boolean force); + + /** + * Trains and produces a compression dictionary from collected samples asynchronously. + * + * @param force force the dictionary training even if there are not enough samples + * @return Future that completes when training is done + */ + default Future trainDictionaryAsync(boolean force) + { + return ScheduledExecutors.nonPeriodicTasks.submit(() -> trainDictionary(force)); + } + + /** + * @return true if enough samples have been collected for training + */ + boolean isReady(); + + /** + * Clears all collected samples and resets trainer state. + */ + void reset(); + + /** + * Gets the current training state including status, progress, and failure details. + * + * @return the current training state as an atomic snapshot + */ + TrainingState getTrainingState(); + + /** + * @return the compression algorithm kind this trainer supports + */ + CompressionDictionary.Kind kind(); + + /** + * Determines if this trainer is compatible with the given compression parameters. + * This method allows the trainer to decide whether it can continue operating + * with new compression parameters or if a new trainer instance is needed. + * + * @param newParams the new compression parameters to check compatibility against + * @return true if this trainer is compatible with the new parameters, false otherwise + */ + boolean isCompatibleWith(CompressionParams newParams); + + /** + * Sets the listener for dictionary training events. + * + * @param listener the listener to be notified when dictionaries are trained, null to remove listener + */ + void setDictionaryTrainedListener(Consumer listener); + + /** + * Updates the sampling rate for this trainer. + * + * @param newSamplingRate the new sampling rate. For exmaple, 1 = sample every time (100%), + * 2 = expect sample 1/2 of data (50%), n = expect sample 1/n of data + */ + void updateSamplingRate(int newSamplingRate); + + /** + * Factory method to create appropriate trainer based on compression parameters. 
+ * + * @param keyspaceName the keyspace name for logging + * @param tableName the table name for logging + * @param params the compression parameters + * @param config the training configuration + * @return a dictionary trainer for the specified compression algorithm + * @throws IllegalArgumentException if no dictionary trainer is available for the compression algorithm + */ + static ICompressionDictionaryTrainer create(String keyspaceName, + String tableName, + CompressionParams params, + CompressionDictionaryTrainingConfig config) + { + ICompressor compressor = params.getSstableCompressor(); + if (!(compressor instanceof IDictionaryCompressor)) + { + throw new IllegalArgumentException("Compressor does not support dictionary training: " + params.getSstableCompressor()); + } + + IDictionaryCompressor dictionaryCompressor = (IDictionaryCompressor) compressor; + return dictionaryCompressor.acceptableDictionaryKind().createTrainer(keyspaceName, tableName, config, compressor); + } + + enum TrainingStatus + { + NOT_STARTED, + SAMPLING, + TRAINING, + COMPLETED, + FAILED; + } +} diff --git a/src/java/org/apache/cassandra/db/compression/SSTableChunkSampler.java b/src/java/org/apache/cassandra/db/compression/SSTableChunkSampler.java new file mode 100644 index 000000000000..d65070062c23 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/SSTableChunkSampler.java @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.io.compress.CompressionMetadata; +import org.apache.cassandra.io.compress.ICompressor; +import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.io.util.ChannelProxy; +import org.apache.cassandra.utils.ChecksumType; + +/** + * Samples uncompressed chunks from existing SSTables for dictionary training. + * Uses random sampling to locate the chunk offsets to avoid sequential scanning while ensuring representative samples. + * Supports both compressed and uncompressed SSTables. + */ +public class SSTableChunkSampler +{ + private static final Logger logger = LoggerFactory.getLogger(SSTableChunkSampler.class); + + /** + * Information about an SSTable and its chunks for sampling. 
+ */ + static class SSTableChunkInfo + { + final SSTableReader sstable; + final CompressionMetadata metadata; // null for uncompressed + final long chunkCount; + final long dataLength; + final int chunkSize; + final boolean isCompressed; + + SSTableChunkInfo(SSTableReader sstable, CompressionDictionaryTrainingConfig config) + { + this.sstable = sstable; + this.isCompressed = sstable.compression; + + if (isCompressed) + { + this.metadata = sstable.getCompressionMetadata(); + this.dataLength = metadata.dataLength; + this.chunkSize = metadata.chunkLength(); + // Use the logical chunk count from metadata (each offset is 8 bytes) + this.chunkCount = metadata.chunkOffsetsSize >> 3; + } + else + { + this.metadata = null; + this.dataLength = sstable.uncompressedLength(); + this.chunkSize = config.chunkSize; + // Calculate number of chunks for uncompressed: dataLength divided by chunkSize, rounded up + this.chunkCount = (dataLength + chunkSize - 1) / chunkSize; + } + } + } + + /** + * Samples chunks from existing SSTables and adds them to the trainer. + * Uses two-level sampling to avoid memory issues with large datasets: + * 1. Select SSTables (potentially all, weighted by size) + * 2. For each SSTable, randomly select specific chunks to sample + * + * @param sstables the set of SSTables to sample from + * @param trainer the trainer to add samples to + * @param config the training configuration with sample size limits + */ + public static void sampleFromSSTables(Set sstables, + ICompressionDictionaryTrainer trainer, + CompressionDictionaryTrainingConfig config) throws IOException + { + if (sstables.isEmpty()) + { + throw new IllegalArgumentException("No SSTables provided for sampling"); + } + + TrainingStatus status = trainer.getTrainingState().status; + if (status != TrainingStatus.SAMPLING) + { + throw new IllegalStateException("Trainer is not ready to accept samples. Current status: " + status); + } + + // Build metadata for all SSTables + List sstableInfos = buildSSTableInfos(sstables, config); + long totalChunks = sstableInfos.stream().mapToLong(info -> info.chunkCount).sum(); + + // Calculate how many chunks to sample + long targetChunkCount = calculateTargetChunkCount(sstableInfos, totalChunks, config); + + logger.debug("Target chunk count for sampling: {} (max sample size: {} bytes)", + targetChunkCount, config.maxTotalSampleSize); + + // Sample chunks from each SSTable + SamplingStats stats = sampleChunksFromSSTables(sstableInfos, totalChunks, targetChunkCount, trainer, config); + + logger.info("Completed sampling: {} chunks, total size: {} bytes", stats.sampleCount, stats.totalSampleSize); + } + + /** + * Builds SSTableChunkInfo objects for all SSTables and logs statistics. + */ + static List buildSSTableInfos(Set sstables, + CompressionDictionaryTrainingConfig config) + { + List sstableInfos = new ArrayList<>(); + long totalChunks = 0; + int compressedCount = 0; + int uncompressedCount = 0; + + for (SSTableReader sstable : sstables) + { + SSTableChunkInfo info = new SSTableChunkInfo(sstable, config); + sstableInfos.add(info); + totalChunks += info.chunkCount; + + if (info.isCompressed) + compressedCount++; + else + uncompressedCount++; + } + + logger.info("Sampling from {} SSTables ({} compressed, {} uncompressed) with {} total chunks", + sstableInfos.size(), compressedCount, uncompressedCount, totalChunks); + + return sstableInfos; + } + + /** + * Calculates the target number of chunks to sample based on available data and constraints. 
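As a quick worked example of this calculation (using the builder defaults shown later in this patch, which is an assumption about the table rather than part of the method): with a 10 MiB sample budget and SSTables whose chunks average 64 KiB, averageChunkSize is 65536 bytes, so targetChunkCount = 10 * 1024 * 1024 / 65536 = 160 chunks.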
+ */ + static long calculateTargetChunkCount(List sstableInfos, + long totalChunks, + CompressionDictionaryTrainingConfig config) + { + long totalDataSize = sstableInfos.stream().mapToLong(info -> info.dataLength).sum(); + int averageChunkSize = totalDataSize > 0 ? (int) (totalDataSize / totalChunks) : config.chunkSize; + return config.maxTotalSampleSize / averageChunkSize; + } + + /** + * Result of sampling operation containing statistics. + */ + static class SamplingStats + { + final long sampleCount; + final long totalSampleSize; + + SamplingStats(long sampleCount, long totalSampleSize) + { + this.sampleCount = sampleCount; + this.totalSampleSize = totalSampleSize; + } + } + + /** + * Samples chunks from all SSTables proportionally to their size. + * Each SSTable contributes samples in proportion to its chunk count relative to the total. + * Stops early if either the target chunk count or max total sample size limit is reached. + *

+ * For example, + *

+     * Given:
+     *   - SSTable A: 40 chunks, chunkSize=64KiB (40% of total)
+     *   - SSTable B: 60 chunks, chunkSize=64KiB (60% of total)
+     *   - Target chunk count: 100
+     *   - Max total sample size: 5MiB
+     *
+     * Result:
+     *   - Sample 32 chunks from A (5MiB / 64KiB * 0.4 = 32 chunks)
+     *   - Sample 48 chunks from B (5MiB / 64KiB * 0.6 = 48 chunks)
+     *   - Total sampled: 80 chunks (stopped due to size limit, not target)
+     * 
+ */ + static SamplingStats sampleChunksFromSSTables(List sstableInfos, + long totalChunks, + long targetChunkCount, + ICompressionDictionaryTrainer trainer, + CompressionDictionaryTrainingConfig config) throws IOException + { + long totalSampleSize = 0; + long sampleCount = 0; + + for (SSTableChunkInfo info : sstableInfos) + { + if (sampleCount >= targetChunkCount || totalSampleSize >= config.maxTotalSampleSize) + { + break; + } + + // Calculate how many chunks to sample from this SSTable (proportional to its size) + long remainingTarget = Math.min(targetChunkCount - sampleCount, (config.maxTotalSampleSize - totalSampleSize) / info.chunkSize); + long chunksFromThisSSTable = Math.min((targetChunkCount * info.chunkCount) / totalChunks, remainingTarget); + + if (chunksFromThisSSTable <= 0) + { + continue; + } + + // Sample chunks from this SSTable + SamplingStats sstableStats = sampleChunksFromSSTable(info, chunksFromThisSSTable, trainer, config); + totalSampleSize += sstableStats.totalSampleSize; + sampleCount += sstableStats.sampleCount; + + if (sampleCount % 100 == 0) + { + logger.debug("Sampled {} chunks, total size: {} bytes", sampleCount, totalSampleSize); + } + } + + return new SamplingStats(sampleCount, totalSampleSize); + } + + /** + * Samples a specified number of chunks from a single SSTable. + */ + static SamplingStats sampleChunksFromSSTable(SSTableChunkInfo info, + long chunksToSample, + ICompressionDictionaryTrainer trainer, + CompressionDictionaryTrainingConfig config) throws IOException + { + long totalSampleSize = 0; + long sampleCount = 0; + + // Generate random chunk indices for this SSTable (without building full list) + Set selectedIndices = selectRandomChunkIndices(info.chunkCount, chunksToSample); + + // Sample the selected chunks + for (long chunkIndex : selectedIndices) + { + if (totalSampleSize >= config.maxTotalSampleSize) + { + logger.debug("Reached max total sample size limit"); + break; + } + + long position = chunkIndex * info.chunkSize; + ByteBuffer chunk = readChunk(info, position); + + // Check if adding this sample would exceed the max total sample size + if (totalSampleSize + chunk.remaining() > config.maxTotalSampleSize) + { + logger.debug("Next chunk would exceed max total sample size limit"); + break; + } + + trainer.addSample(chunk); + totalSampleSize += chunk.remaining(); + sampleCount++; + } + + return new SamplingStats(sampleCount, totalSampleSize); + } + + /** + * Selects random chunk indices. + * + * @param totalChunks the total number of chunks available + * @param count the number of chunks to select + * @return set of randomly selected chunk indices + */ + static Set selectRandomChunkIndices(long totalChunks, long count) + { + // If we need to sample more than half, it's more efficient to select what to exclude + if (count > totalChunks / 2) + { + long excludeCount = totalChunks - count; + Set toExclude = floydRandomSampling(totalChunks, excludeCount); + Set selected = new HashSet<>(); + // Add all indices except those in toExclude + for (long i = 0; i < totalChunks; i++) + { + if (!toExclude.contains(i)) + { + selected.add(i); + } + } + return selected; + } + else + { + return floydRandomSampling(totalChunks, count); + } + } + + /** + * Floyd's algorithm for random sampling without replacement. + * Efficiently selects a random subset by iterating only through the sample size, not the total population. + * Guarantees no duplication. 
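Because Floyd's trick is easy to misread, here is a tiny standalone re-implementation (illustrative only, not the patch's method; it assumes samples <= total) that makes the invariant visible: each iteration adds exactly one new index, so the result always contains `samples` distinct values from `[0, total)`.

```java
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;

public class FloydSamplingDemo
{
    // Pick `samples` distinct indices from [0, total) without materialising the full index range.
    static Set<Long> floydSample(long total, long samples)
    {
        Set<Long> picked = new HashSet<>();
        for (long i = total - samples; i < total; i++)
        {
            long candidate = ThreadLocalRandom.current().nextLong(i + 1); // uniform in [0, i]
            // On a collision, keep the upper bound i itself; i was not eligible in any earlier
            // iteration, so it cannot already be in the set, and the size still grows by one.
            if (!picked.add(candidate))
                picked.add(i);
        }
        return picked;
    }

    public static void main(String[] args)
    {
        Set<Long> s = floydSample(1_000, 10);
        System.out.println(s.size() + " distinct indices: " + s); // always reports 10 distinct indices
    }
}
```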
+ * + * @param total the total number of items available + * @param samples the number of items to select + * @return set of randomly selected indices + * @see Floyd's Sampling Algorithm + */ + static Set floydRandomSampling(long total, long samples) + { + Set set = new HashSet<>(); + long requested = Math.min(total, samples); + for (long i = total - requested; i < total; i++) + { + long randomIndex = ThreadLocalRandom.current().nextLong(i + 1); + if (!set.add(randomIndex)) + { + set.add(i); + } + } + return set; + } + + /** + * Reads a chunk from an SSTable at the given position. + * Handles both compressed and uncompressed SSTables. + * + * @param sstableInfo the SSTable info + * @param position the position to read from + * @return the chunk data (uncompressed if source was compressed) + * @throws IOException if reading or decompression fails + */ + static ByteBuffer readChunk(SSTableChunkInfo sstableInfo, long position) throws IOException + { + if (sstableInfo.isCompressed) + { + return readAndDecompressChunk(sstableInfo, position); + } + else + { + return readUncompressedChunk(sstableInfo, position); + } + } + + /** + * Reads and decompresses a single chunk from a compressed SSTable. + * + * @param sstableInfo the SSTable info + * @param position the uncompressed position (will be mapped to chunk) + * @return the uncompressed chunk data + * @throws IOException if reading or decompression fails + */ + static ByteBuffer readAndDecompressChunk(SSTableChunkInfo sstableInfo, long position) throws IOException + { + CompressionMetadata metadata = sstableInfo.metadata; + CompressionMetadata.Chunk chunk = metadata.chunkFor(position); + + // Read the compressed chunk from disk + ChannelProxy channel = sstableInfo.sstable.getDataChannel(); + + // Allocate buffer for compressed data + checksum + int compressedLength = chunk.length; + ByteBuffer compressed = ByteBuffer.allocateDirect(compressedLength + Integer.BYTES); + + int read = channel.read(compressed, chunk.offset); + if (read != compressedLength + Integer.BYTES) + { + throw new IOException(String.format("Expected to read %d bytes but got %d", + compressedLength + Integer.BYTES, read)); + } + + compressed.flip(); + compressed.limit(compressedLength); + + // Verify checksum + int expectedChecksum = (int) ChecksumType.CRC32.of(compressed); + compressed.limit(compressedLength + Integer.BYTES); + int actualChecksum = compressed.getInt(compressedLength); + + if (expectedChecksum != actualChecksum) + { + throw new IOException(String.format("Checksum mismatch for chunk at position %d in SSTable %s (expected: %d, actual: %d)", + position, sstableInfo.sstable, expectedChecksum, actualChecksum)); + } + + // Reset for decompression + compressed.position(0).limit(compressedLength); + + // Decompress the chunk + ICompressor compressor = metadata.compressor(); + ByteBuffer uncompressed = ByteBuffer.allocateDirect(metadata.chunkLength()); + + compressor.uncompress(compressed, uncompressed); + uncompressed.flip(); + return uncompressed; + } + + /** + * Reads a chunk directly from an uncompressed SSTable. 
+ * + * @param sstableInfo the SSTable info + * @param position the position to read from + * @return the chunk data + * @throws IOException if reading fails + */ + static ByteBuffer readUncompressedChunk(SSTableChunkInfo sstableInfo, long position) throws IOException + { + ChannelProxy channel = sstableInfo.sstable.getDataChannel(); + + // Calculate how much to read (might be less than chunkSize at end of file) + long remainingData = sstableInfo.dataLength - position; + int readSize = (int) Math.min(sstableInfo.chunkSize, remainingData); + + if (readSize <= 0) + { + throw new IOException(String.format("Invalid read size %d at position %d (dataLength: %d) for SSTable %s", + readSize, position, sstableInfo.dataLength, sstableInfo.sstable)); + } + + ByteBuffer buffer = ByteBuffer.allocateDirect(readSize); + int read = channel.read(buffer, position); + + if (read != readSize) + { + throw new IOException(String.format("Expected to read %d bytes but got %d", readSize, read)); + } + + buffer.flip(); + return buffer; + } +} diff --git a/src/java/org/apache/cassandra/db/compression/TrainingState.java b/src/java/org/apache/cassandra/db/compression/TrainingState.java new file mode 100644 index 000000000000..279683c4f2f8 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/TrainingState.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.HashMap; +import java.util.Map; +import javax.management.openmbean.CompositeData; +import javax.management.openmbean.CompositeDataSupport; +import javax.management.openmbean.CompositeType; +import javax.management.openmbean.OpenDataException; +import javax.management.openmbean.OpenType; +import javax.management.openmbean.SimpleType; + +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; + +/** + * Represents the current state of compression dictionary training. + * This class encapsulates training status, progress information, and failure details in a single snapshot. 
+ */ +public class TrainingState +{ + public final TrainingStatus status; + public final String failureMessage; // null unless status is FAILED + public final long sampleCount; + public final long totalSampleSize; + + // JMX CompositeData support + private static final String[] ITEM_NAMES = new String[]{ "status", + "failure_message", + "sample_count", + "total_sample_size" }; + + private static final String[] ITEM_DESC = new String[]{ "current training status", + "failure message if training failed, null otherwise", + "number of samples collected", + "total size of samples collected in bytes" }; + + private static final OpenType[] ITEM_TYPES; + + public static final CompositeType COMPOSITE_TYPE; + + static + { + try + { + ITEM_TYPES = new OpenType[]{ SimpleType.STRING, + SimpleType.STRING, + SimpleType.LONG, + SimpleType.LONG }; + + COMPOSITE_TYPE = new CompositeType(TrainingState.class.getName(), + "TrainingState", + ITEM_NAMES, + ITEM_DESC, + ITEM_TYPES); + } + catch (OpenDataException e) + { + throw new RuntimeException(e); + } + } + + private TrainingState(TrainingStatus status, String failureMessage, long sampleCount, long totalSampleSize) + { + this.status = status; + this.failureMessage = failureMessage; + this.sampleCount = sampleCount; + this.totalSampleSize = totalSampleSize; + } + + // Factory methods for clarity and type safety + public static TrainingState notStarted() + { + return new TrainingState(TrainingStatus.NOT_STARTED, null, 0, 0); + } + + public static TrainingState sampling(long samples, long totalSize) + { + return new TrainingState(TrainingStatus.SAMPLING, null, samples, totalSize); + } + + public static TrainingState training(long samples, long totalSize) + { + return new TrainingState(TrainingStatus.TRAINING, null, samples, totalSize); + } + + public static TrainingState completed(long samples, long totalSize) + { + return new TrainingState(TrainingStatus.COMPLETED, null, samples, totalSize); + } + + public static TrainingState failed(String message, long samples, long totalSize) + { + return new TrainingState(TrainingStatus.FAILED, message, samples, totalSize); + } + + public boolean isFailed() + { + return status == TrainingStatus.FAILED; + } + + public boolean isCompleted() + { + return status == TrainingStatus.COMPLETED; + } + + public TrainingStatus getStatus() + { + return status; + } + + public long getSampleCount() + { + return sampleCount; + } + + public long getTotalSampleSize() + { + return totalSampleSize; + } + + // returns null unless status is FAILED + public String getFailureMessage() + { + return isFailed() ? failureMessage : null; + } + + // JMX CompositeData conversion methods + + /** + * Converts this TrainingState to JMX CompositeData format. + * + * @return CompositeData representation of this training state + */ + public CompositeData toCompositeData() + { + Map valueMap = new HashMap<>(); + valueMap.put(ITEM_NAMES[0], status.toString()); + valueMap.put(ITEM_NAMES[1], getFailureMessage()); + valueMap.put(ITEM_NAMES[2], sampleCount); + valueMap.put(ITEM_NAMES[3], totalSampleSize); + + try + { + return new CompositeDataSupport(COMPOSITE_TYPE, valueMap); + } + catch (final OpenDataException e) + { + throw new RuntimeException(e); + } + } + + /** + * Converts JMX CompositeData back to TrainingState. 
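A small usage sketch of that round trip (the numbers are made up for illustration): the node flattens a snapshot into CompositeData for JMX, and a client rebuilds an equivalent TrainingState from it using only the factory methods defined below.

```java
import javax.management.openmbean.CompositeData;

import org.apache.cassandra.db.compression.TrainingState;

public class TrainingStateRoundTrip
{
    public static void main(String[] args)
    {
        TrainingState before = TrainingState.sampling(120, 7_864_320L); // 120 samples, ~7.5 MiB collected
        CompositeData wire = before.toCompositeData();                  // what a JMX client receives
        TrainingState after = TrainingState.fromCompositeData(wire);    // client-side reconstruction

        // Prints: TrainingState{status=SAMPLING, samples=120, totalSize=7864320}
        System.out.println(after);
    }
}
```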
+ * + * @param data the CompositeData to convert + * @return TrainingState reconstructed from the CompositeData + */ + public static TrainingState fromCompositeData(final CompositeData data) + { + assert data.getCompositeType().equals(COMPOSITE_TYPE); + + final Object[] values = data.getAll(ITEM_NAMES); + + TrainingStatus status = TrainingStatus.valueOf((String) values[0]); + String failureMessage = (String) values[1]; + long sampleCount = (Long) values[2]; + long totalSampleSize = (Long) values[3]; + + // Reconstruct TrainingState based on status + switch (status) + { + case NOT_STARTED: + return TrainingState.notStarted(); + case SAMPLING: + return TrainingState.sampling(sampleCount, totalSampleSize); + case TRAINING: + return TrainingState.training(sampleCount, totalSampleSize); + case COMPLETED: + return TrainingState.completed(sampleCount, totalSampleSize); + case FAILED: + return TrainingState.failed(failureMessage, sampleCount, totalSampleSize); + default: + throw new IllegalStateException("Unknown training status: " + status); + } + } + + @Override + public String toString() + { + StringBuilder sb = new StringBuilder("TrainingState{status="); + sb.append(status); + sb.append(", samples=").append(sampleCount); + sb.append(", totalSize=").append(totalSampleSize); + if (isFailed() && failureMessage != null) + { + sb.append(", failure='").append(failureMessage).append('\''); + } + sb.append('}'); + return sb.toString(); + } +} diff --git a/src/java/org/apache/cassandra/db/compression/ZstdCompressionDictionary.java b/src/java/org/apache/cassandra/db/compression/ZstdCompressionDictionary.java new file mode 100644 index 000000000000..a30d011ade96 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ZstdCompressionDictionary.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.luben.zstd.ZstdDictCompress; +import com.github.luben.zstd.ZstdDictDecompress; +import org.apache.cassandra.io.compress.ZstdCompressorBase; +import org.apache.cassandra.utils.concurrent.Ref; +import org.apache.cassandra.utils.concurrent.RefCounted; +import org.apache.cassandra.utils.concurrent.SelfRefCounted; + +public class ZstdCompressionDictionary implements CompressionDictionary, SelfRefCounted +{ + private static final Logger logger = LoggerFactory.getLogger(ZstdCompressionDictionary.class); + + private final DictId dictId; + private final byte[] rawDictionary; + // One ZstdDictDecompress and multiple ZstdDictCompress (per level) can be derived from the same raw dictionary content + private final ConcurrentHashMap zstdDictCompressPerLevel = new ConcurrentHashMap<>(); + private volatile ZstdDictDecompress dictDecompress; + private final AtomicBoolean closed = new AtomicBoolean(false); + private final Ref selfRef; + + public ZstdCompressionDictionary(DictId dictId, byte[] rawDictionary) + { + this.dictId = dictId; + this.rawDictionary = rawDictionary; + this.selfRef = new Ref<>(this, new Tidy(zstdDictCompressPerLevel, dictDecompress)); + } + + @Override + public DictId dictId() + { + return dictId; + } + + @Override + public Kind kind() + { + return Kind.ZSTD; + } + + @Override + public byte[] rawDictionary() + { + return rawDictionary; + } + + @Override + public boolean equals(Object o) + { + if (!(o instanceof ZstdCompressionDictionary)) return false; + ZstdCompressionDictionary that = (ZstdCompressionDictionary) o; + return Objects.equals(dictId, that.dictId); + } + + @Override + public int hashCode() + { + return dictId.hashCode(); + } + + /** + * Get a pre-processed compression tables that is optimized for compression. + * It is derived/computed from dictionary bytes. + * The internal data structure is different from the tables for decompression. + * + * @param compressionLevel compression level to create the compression table + * @return ZstdDictCompress + */ + public ZstdDictCompress dictionaryForCompression(int compressionLevel) + { + if (closed.get()) + throw new IllegalStateException("Dictionary has been closed. " + dictId); + + ZstdCompressorBase.validateCompressionLevel(compressionLevel); + + return zstdDictCompressPerLevel.computeIfAbsent(compressionLevel, level -> { + if (closed.get()) + throw new IllegalStateException("Dictionary has been closed"); + return new ZstdDictCompress(rawDictionary, level); + }); + } + + /** + * Get a pre-processed decompression tables that is optimized for decompression. + * It is derived/computed from dictionary bytes. + * The internal data structure is different from the tables for compression. 
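For context on how these derived objects are consumed, the sketch below assumes zstd-jni's ZstdCompressCtx/ZstdDecompressCtx API; the real call sites live in ZstdDictionaryCompressor, which is not part of this hunk, so treat this as an illustration of the library contract rather than the patch's code path. Context lifecycle and closing are elided.

```java
import com.github.luben.zstd.ZstdCompressCtx;
import com.github.luben.zstd.ZstdDecompressCtx;

import org.apache.cassandra.db.compression.ZstdCompressionDictionary;

public class DictionaryUsageSketch
{
    // Compress and then decompress one chunk with the per-level / shared dictionary objects.
    static byte[] roundTrip(ZstdCompressionDictionary dict, byte[] chunk, int level)
    {
        ZstdCompressCtx cctx = new ZstdCompressCtx();
        cctx.loadDict(dict.dictionaryForCompression(level));   // cached per compression level
        byte[] compressed = cctx.compress(chunk);

        ZstdDecompressCtx dctx = new ZstdDecompressCtx();
        dctx.loadDict(dict.dictionaryForDecompression());      // lazily created, shared decompress tables
        return dctx.decompress(compressed, chunk.length);
    }
}
```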
+ * + * @return ZstdDictDecompress + */ + public ZstdDictDecompress dictionaryForDecompression() + { + if (closed.get()) + throw new IllegalStateException("Dictionary has been closed"); + + ZstdDictDecompress result = dictDecompress; + if (result != null) + return result; + + synchronized (this) + { + if (closed.get()) + throw new IllegalStateException("Dictionary has been closed"); + + result = dictDecompress; + if (result == null) + { + result = new ZstdDictDecompress(rawDictionary); + dictDecompress = result; + } + return result; + } + } + + @Override + public Ref tryRef() + { + return selfRef.tryRef(); + } + + @Override + public Ref selfRef() + { + return selfRef; + } + + @Override + public Ref ref() + { + return selfRef.ref(); + } + + @Override + public void close() + { + if (closed.compareAndSet(false, true)) + { + selfRef.release(); + } + } + + private static class Tidy implements RefCounted.Tidy + { + private final ConcurrentHashMap zstdDictCompressPerLevel; + private volatile ZstdDictDecompress dictDecompress; + + Tidy(ConcurrentHashMap zstdDictCompressPerLevel, ZstdDictDecompress dictDecompress) + { + this.zstdDictCompressPerLevel = zstdDictCompressPerLevel; + this.dictDecompress = dictDecompress; + } + + @Override + public void tidy() + { + // Close all compression dictionaries + for (ZstdDictCompress compressDict : zstdDictCompressPerLevel.values()) + { + try + { + compressDict.close(); + } + catch (Exception e) + { + // Log but don't fail - continue closing other resources + logger.warn("Failed to close ZstdDictCompress", e); + } + } + zstdDictCompressPerLevel.clear(); + + // Close decompression dictionary + ZstdDictDecompress decompressDict = dictDecompress; + if (decompressDict != null) + { + try + { + decompressDict.close(); + } + catch (Exception e) + { + logger.warn("Failed to close ZstdDictDecompress", e); + } + dictDecompress = null; + } + } + + @Override + public String name() + { + return ZstdCompressionDictionary.class.getSimpleName(); + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/ZstdDictionaryTrainer.java b/src/java/org/apache/cassandra/db/compression/ZstdDictionaryTrainer.java new file mode 100644 index 000000000000..2b74903f80aa --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ZstdDictionaryTrainer.java @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.luben.zstd.Zstd; +import com.github.luben.zstd.ZstdDictTrainer; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.io.compress.IDictionaryCompressor; +import org.apache.cassandra.io.compress.ZstdDictionaryCompressor; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.utils.Clock; + +/** + * Zstd implementation of dictionary trainer with lifecycle management. + */ +public class ZstdDictionaryTrainer implements ICompressionDictionaryTrainer +{ + private static final Logger logger = LoggerFactory.getLogger(ZstdDictionaryTrainer.class); + + private final String keyspaceName; + private final String tableName; + private final CompressionDictionaryTrainingConfig config; + private final AtomicLong totalSampleSize; + private final AtomicLong sampleCount; + private final int compressionLevel; // optimal if using the same level for training as when compressing. + + // Sampling rate can be updated during training + private volatile int samplingRate; + + // Minimum number of samples required by ZSTD library + private static final int MIN_SAMPLES_REQUIRED = 10; + + private volatile Consumer dictionaryTrainedListener; + // TODO: manage the samples in this class for auto-train (follow-up). The ZstdDictTrainer cannot be re-used for multiple training runs. 
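To make the intended call sequence concrete, here is a hedged sketch of one manual-training pass through the ICompressionDictionaryTrainer contract. The caller shown is invented for illustration; in the patch, CompressionDictionaryManager, the scheduler, and SSTableChunkSampler drive these calls, and error handling is elided.

```java
import java.nio.ByteBuffer;
import java.util.List;

import org.apache.cassandra.db.compression.CompressionDictionary;
import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer;

public class TrainerLifecycleSketch
{
    static CompressionDictionary trainOnce(ICompressionDictionaryTrainer trainer, List<ByteBuffer> chunks)
    {
        if (!trainer.start(true))                 // manual training: resets state and moves to SAMPLING
            throw new IllegalStateException("trainer could not be started");

        for (ByteBuffer chunk : chunks)
            trainer.addSample(chunk);             // bytes are copied into the underlying ZstdDictTrainer

        // force=true mirrors the manual path: train even below the sample-size threshold,
        // though the library's minimum sample count is still enforced.
        return trainer.trainDictionary(true);
    }
}
```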
+ private ZstdDictTrainer zstdTrainer; + private volatile boolean closed = false; + private volatile TrainingStatus currentTrainingStatus; + private volatile String failureMessage; + + public ZstdDictionaryTrainer(String keyspaceName, String tableName, + CompressionDictionaryTrainingConfig config, + int compressionLevel) + { + this.keyspaceName = keyspaceName; + this.tableName = tableName; + this.config = config; + this.totalSampleSize = new AtomicLong(0); + this.sampleCount = new AtomicLong(0); + this.compressionLevel = compressionLevel; + this.samplingRate = config.samplingRate; + this.currentTrainingStatus = TrainingStatus.NOT_STARTED; + } + + @Override + public boolean shouldSample() + { + return zstdTrainer != null && ThreadLocalRandom.current().nextInt(samplingRate) == 0; + } + + @Override + public void addSample(ByteBuffer sample) + { + if (closed || sample == null || !sample.hasRemaining() || zstdTrainer == null) + return; + + byte[] sampleBytes = new byte[sample.remaining()]; + sample.duplicate().get(sampleBytes); + + if (zstdTrainer.addSample(sampleBytes)) + { + // Update the totalSampleSize and sampleCount if the sample is added + totalSampleSize.addAndGet(sampleBytes.length); + sampleCount.incrementAndGet(); + } + } + + @Override + public CompressionDictionary trainDictionary(boolean force) + { + boolean isReady = isReady(); + if (!force && !isReady) + { + failureMessage = "Trainer is not ready"; + currentTrainingStatus = TrainingStatus.FAILED; + throw new IllegalStateException(failureMessage); + } + + long currentSampleCount = sampleCount.get(); + if (currentSampleCount < MIN_SAMPLES_REQUIRED) // minimum samples should be required even if force training + { + failureMessage = String.format("Insufficient samples for training: %d (minimum required: %d)", + currentSampleCount, MIN_SAMPLES_REQUIRED); + currentTrainingStatus = TrainingStatus.FAILED; + throw new IllegalStateException(failureMessage); + } + + currentTrainingStatus = TrainingStatus.TRAINING; + failureMessage = null; // Clear any previous failure message + try + { + logger.debug("Training with sample count: {}, sample size: {}, isReady: {}", + currentSampleCount, totalSampleSize.get(), isReady); + byte[] dictBytes = zstdTrainer.trainSamples(); + long zstdDictId = Zstd.getDictIdFromDict(dictBytes); + DictId dictId = new DictId(Kind.ZSTD, makeDictionaryId(Clock.Global.currentTimeMillis(), zstdDictId)); + currentTrainingStatus = TrainingStatus.COMPLETED; + logger.debug("New dictionary is trained with {}", dictId); + CompressionDictionary dictionary = new ZstdCompressionDictionary(dictId, dictBytes); + notifyDictionaryTrainedListener(dictionary); + return dictionary; + } + catch (Exception e) + { + failureMessage = "Failed to train Zstd dictionary: " + e.getMessage(); + currentTrainingStatus = TrainingStatus.FAILED; + throw new RuntimeException(failureMessage, e); + } + } + + @Override + public boolean isReady() + { + return currentTrainingStatus != TrainingStatus.TRAINING + && !closed + && zstdTrainer != null + && totalSampleSize.get() >= config.acceptableTotalSampleSize + && sampleCount.get() > MIN_SAMPLES_REQUIRED; + } + + @Override + public TrainingState getTrainingState() + { + long currentSampleCount = sampleCount.get(); + long currentTotalSampleSize = totalSampleSize.get(); + + switch (currentTrainingStatus) + { + case NOT_STARTED: + return TrainingState.notStarted(); + case SAMPLING: + return TrainingState.sampling(currentSampleCount, currentTotalSampleSize); + case TRAINING: + return 
TrainingState.training(currentSampleCount, currentTotalSampleSize); + case COMPLETED: + return TrainingState.completed(currentSampleCount, currentTotalSampleSize); + case FAILED: + return TrainingState.failed(failureMessage, currentSampleCount, currentTotalSampleSize); + default: + throw new IllegalStateException("Unknown training status: " + currentTrainingStatus); + } + } + + @Override + public boolean start(boolean manualTraining) + { + if (closed || !(manualTraining || shouldAutoStartTraining())) + return false; + + try + { + // reset on starting; a new zstdTrainer instance is created during reset + reset(); + logger.info("Started dictionary training for {}.{}", keyspaceName, tableName); + currentTrainingStatus = TrainingStatus.SAMPLING; + failureMessage = null; // Clear any previous failure message + return true; + } + catch (Exception e) + { + logger.warn("Failed to create ZstdDictTrainer for {}.{}", keyspaceName, tableName, e); + failureMessage = "Failed to create ZstdDictTrainer: " + e.getMessage(); + currentTrainingStatus = TrainingStatus.FAILED; + } + return false; + } + + /** + * Determines if training should auto-start based on configuration. + */ + private boolean shouldAutoStartTraining() + { + return DatabaseDescriptor.getCompressionDictionaryTrainingAutoTrainEnabled(); + } + + @Override + public void reset() + { + if (closed) + { + return; + } + + currentTrainingStatus = TrainingStatus.NOT_STARTED; + synchronized (this) + { + totalSampleSize.set(0); + sampleCount.set(0); + zstdTrainer = new ZstdDictTrainer(config.maxTotalSampleSize, config.maxDictionarySize, compressionLevel); + } + } + + @Override + public Kind kind() + { + return Kind.ZSTD; + } + + @Override + public void setDictionaryTrainedListener(Consumer listener) + { + this.dictionaryTrainedListener = listener; + } + + @Override + public void updateSamplingRate(int newSamplingRate) + { + if (newSamplingRate <= 0) + { + throw new IllegalArgumentException("Sampling rate must be positive, got: " + newSamplingRate); + } + this.samplingRate = newSamplingRate; + logger.debug("Updated sampling rate to {} for {}.{}", newSamplingRate, keyspaceName, tableName); + } + + /** + * Notifies the registered listener that a dictionary has been trained. 
+ * + * @param dictionary the newly trained dictionary + */ + private void notifyDictionaryTrainedListener(CompressionDictionary dictionary) + { + Consumer listener = this.dictionaryTrainedListener; + if (listener != null) + { + try + { + listener.accept(dictionary); + } + catch (Exception e) + { + logger.warn("Error notifying dictionary trained listener for {}.{}", keyspaceName, tableName, e); + } + } + } + + @Override + public boolean isCompatibleWith(CompressionParams newParams) + { + if (!newParams.isDictionaryCompressionEnabled()) + { + return false; + } + + IDictionaryCompressor newCompressor = (IDictionaryCompressor) newParams.getSstableCompressor(); + + // Check if the compressor type is compatible with this trainer + if (newCompressor.acceptableDictionaryKind() != Kind.ZSTD) + { + return false; + } + + ZstdDictionaryCompressor zstdDictionaryCompressor = (ZstdDictionaryCompressor) newCompressor; + // For Zstd compressors, check if compression level matches + return this.compressionLevel == zstdDictionaryCompressor.compressionLevel(); + } + + @Override + public void close() + { + if (closed) + return; + + closed = true; + currentTrainingStatus = TrainingStatus.NOT_STARTED; + + synchronized (this) + { + // Permanent shutdown: clear all state and prevent restart + totalSampleSize.set(0); + sampleCount.set(0); + zstdTrainer = null; + } + + logger.info("Permanently closed dictionary trainer for {}.{}", keyspaceName, tableName); + } + + /** + * Creates a monotonically increasing dictionary ID by combining timestamp and dictionary ID. + *

+ * The resulting dictionary ID has the following structure: + * - Upper 32 bits: timestamp in minutes (signed int) + * - Lower 32 bits: Zstd dictionary ID (unsigned int, passed as long due to Java limitations) + *

+ * This ensures dictionary IDs are monotonically increasing over time, which helps to identify + * the latest dictionary. + *

+ * The implementation assumes that dictionaries are trained far less often than once per + * minute, which any healthy system should satisfy. When multiple dictionaries are trained + * within the same minute (only possible with manual training), there is no correctness concern, + * since the dictionary is attached to the SSTables, but it does incur a performance hit from + * keeping too many dictionaries around; such a scenario should therefore be avoided where possible. + * + * @param currentTimeMillis the current time in milliseconds + * @param dictId dictionary ID (unsigned 32-bit value represented as long) + * @return combined dictionary ID that is monotonically increasing over time + */ + static long makeDictionaryId(long currentTimeMillis, long dictId) + { + // timestamp in minutes since Unix epoch. Good until year 6053 + long timestampMinutes = currentTimeMillis / 1000 / 60; + // Convert timestamp to long and shift to upper 32 bits + long combined = timestampMinutes << 32; + + // Add the unsigned int (already as long) to lower 32 bits + combined |= (dictId & 0xFFFFFFFFL); + + return combined; + } + + @VisibleForTesting + Object trainer() + { + return zstdTrainer; + } +} diff --git a/src/java/org/apache/cassandra/io/compress/CompressedSequentialWriter.java b/src/java/org/apache/cassandra/io/compress/CompressedSequentialWriter.java index ea6448f8182c..1f4cd517b021 100644 --- a/src/java/org/apache/cassandra/io/compress/CompressedSequentialWriter.java +++ b/src/java/org/apache/cassandra/io/compress/CompressedSequentialWriter.java @@ -24,7 +24,10 @@ import java.nio.channels.Channels; import java.util.Optional; import java.util.zip.CRC32; +import javax.annotation.Nullable; +import org.apache.cassandra.db.compression.CompressionDictionary; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.FSReadError; import org.apache.cassandra.io.FSWriteError; import org.apache.cassandra.io.sstable.CorruptSSTableException; @@ -61,11 +64,24 @@ public class CompressedSequentialWriter extends SequentialWriter private long uncompressedSize = 0, compressedSize = 0; private final MetadataCollector sstableMetadataCollector; + private final CompressionDictionaryManager compressionDictionaryManager; private final ByteBuffer crcCheckBuffer = ByteBuffer.allocate(4); private final Optional digestFile; private final int maxCompressedLength; + private final boolean isDictionaryEnabled; + + public CompressedSequentialWriter(File file, + File offsetsFile, + File digestFile, + SequentialWriterOption option, + CompressionParams parameters, + MetadataCollector sstableMetadataCollector) + { + this(file, offsetsFile, digestFile, option, parameters, sstableMetadataCollector, null); + } + /** * Create CompressedSequentialWriter without digest file.
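To make the dictionary ID packing in ZstdDictionaryTrainer.makeDictionaryId above concrete, the following small, self-contained sketch mirrors its arithmetic and shows why a dictionary trained one minute later always sorts higher regardless of the raw Zstd dictionary ID. The class and method names used here (DictionaryIdPackingExample, packDictionaryId, and the unpack helpers) are illustrative stand-ins only and are not part of this patch.

    // Illustrative sketch only: mirrors the bit packing performed by makeDictionaryId above.
    // None of these names exist in the Cassandra code base.
    public final class DictionaryIdPackingExample
    {
        static long packDictionaryId(long currentTimeMillis, long zstdDictId)
        {
            long timestampMinutes = currentTimeMillis / 1000 / 60;        // minutes since the Unix epoch
            return (timestampMinutes << 32) | (zstdDictId & 0xFFFFFFFFL); // upper 32 bits: time, lower 32 bits: Zstd dict id
        }

        static long timestampMinutes(long packed) { return packed >>> 32; }
        static long zstdDictId(long packed)       { return packed & 0xFFFFFFFFL; }

        public static void main(String[] args)
        {
            long t1 = 1_700_000_000_000L;                    // some wall-clock time, in milliseconds
            long t2 = t1 + 60_000;                           // one minute later
            long older = packDictionaryId(t1, 0xFFFF_FFFFL); // largest possible unsigned Zstd dict id
            long newer = packDictionaryId(t2, 1L);           // trained a minute later, tiny Zstd dict id

            // The later dictionary always compares higher, no matter what the raw Zstd ids are,
            // and the original Zstd id can still be recovered from the lower 32 bits.
            assert newer > older;
            assert zstdDictId(older) == 0xFFFF_FFFFL;
            assert timestampMinutes(newer) - timestampMinutes(older) == 1;
        }
    }

Because the comparison works directly on the packed long, callers never need to decode the two halves to find the latest dictionary.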
@@ -74,15 +90,17 @@ public class CompressedSequentialWriter extends SequentialWriter * @param offsetsFile File to write compression metadata * @param digestFile File to write digest * @param option Write option (buffer size and type will be set the same as compression params) - * @param parameters Compression mparameters + * @param parameters Compression parameters * @param sstableMetadataCollector Metadata collector + * @param compressionDictionaryManager manages the compression dictionary; may be null when absent */ public CompressedSequentialWriter(File file, File offsetsFile, File digestFile, SequentialWriterOption option, CompressionParams parameters, - MetadataCollector sstableMetadataCollector) + MetadataCollector sstableMetadataCollector, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { super(file, SequentialWriterOption.newBuilder() .bufferSize(option.bufferSize()) @@ -91,7 +109,7 @@ public CompressedSequentialWriter(File file, .bufferType(parameters.getSstableCompressor().preferredBufferType()) .finishOnClose(option.finishOnClose()) .build()); - this.compressor = parameters.getSstableCompressor(); + ICompressor compressor = parameters.getSstableCompressor(); this.digestFile = Optional.ofNullable(digestFile); // buffer for compression should be the same size as buffer itself @@ -99,8 +117,28 @@ public CompressedSequentialWriter(File file, maxCompressedLength = parameters.maxCompressedLength(); + // Note that we cannot rely on the compressor type to tell whether dictionary compression is enabled, + // because the `CompressionParams` passed to this method may have been rewritten at the callsite, `DataComponent.buildWriter`. + // See CASSANDRA-15379 for details regarding that optimization. + // Meanwhile, as long as dictionary-based compression is enabled, we want to collect samples. + this.isDictionaryEnabled = compressionDictionaryManager != null && compressionDictionaryManager.isEnabled(); + + CompressionDictionary compressionDictionary = compressionDictionaryManager == null ? null : compressionDictionaryManager.getCurrent(); + if (compressionDictionary != null && compressor instanceof IDictionaryCompressor) + { + compressor = ((IDictionaryCompressor) compressor).getOrCopyWithDictionary(compressionDictionary); + } + else + { + // We are most likely on the sstable flush path, where an LZ4 compressor (or another non-dictionary + // compressor) was picked. In this case, we disable the compression dictionary, i.e. do not attach + // the dictionary bytes to the CompressionInfo component.
+ compressionDictionary = null; + } + this.compressor = compressor; + this.compressionDictionaryManager = compressionDictionaryManager; /* Index File (-CompressionInfo.db component) and it's header */ - metadataWriter = CompressionMetadata.Writer.open(parameters, offsetsFile); + metadataWriter = CompressionMetadata.Writer.open(parameters, offsetsFile, compressionDictionary); this.sstableMetadataCollector = sstableMetadataCollector; crcMetadata = new ChecksumWriter(new DataOutputStream(Channels.newOutputStream(channel))); @@ -145,6 +183,13 @@ protected void flushData() { // compressing data with buffer re-use buffer.flip(); + + // Collect sample for dictionary training before compression + if (isDictionaryEnabled) + { + compressionDictionaryManager.addSample(buffer.duplicate()); + } + compressed.clear(); compressor.compress(buffer, compressed); } @@ -440,4 +485,4 @@ public CompressedFileWriterMark(long chunkOffset, long uncDataOffset, int validB this.nextChunkIndex = nextChunkIndex; } } -} \ No newline at end of file +} diff --git a/src/java/org/apache/cassandra/io/compress/CompressionMetadata.java b/src/java/org/apache/cassandra/io/compress/CompressionMetadata.java index d5f5f05655e9..f49734a6597e 100644 --- a/src/java/org/apache/cassandra/io/compress/CompressionMetadata.java +++ b/src/java/org/apache/cassandra/io/compress/CompressionMetadata.java @@ -27,11 +27,14 @@ import java.util.Map; import java.util.SortedSet; import java.util.TreeSet; +import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; import com.google.common.primitives.Longs; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.db.compression.CompressionDictionary; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.io.FSReadError; import org.apache.cassandra.io.FSWriteError; @@ -62,16 +65,31 @@ public class CompressionMetadata extends WrappedSharedCloseable public final long dataLength; public final long compressedFileLength; private final Memory chunkOffsets; - private final long chunkOffsetsSize; + public final long chunkOffsetsSize; public final File chunksIndexFile; public final CompressionParams parameters; + @Nullable // null when no dictionary + private final CompressionDictionary compressionDictionary; + private volatile ICompressor resolvedCompressor; @VisibleForTesting - public static CompressionMetadata open(File chunksIndexFile, long compressedLength, boolean hasMaxCompressedSize) + public static CompressionMetadata open(File chunksIndexFile, + long compressedLength, + boolean hasMaxCompressedSize) + { + return open(chunksIndexFile, compressedLength, hasMaxCompressedSize, null); + } + + @VisibleForTesting + public static CompressionMetadata open(File chunksIndexFile, + long compressedLength, + boolean hasMaxCompressedSize, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { CompressionParams parameters; long dataLength; Memory chunkOffsets; + CompressionDictionary compressionDictionary; try (FileInputStreamPlus stream = chunksIndexFile.newInputStream()) { @@ -99,6 +117,7 @@ public static CompressionMetadata open(File chunksIndexFile, long compressedLeng dataLength = stream.readLong(); chunkOffsets = readChunkOffsets(stream); + compressionDictionary = CompressionDictionary.deserialize(stream, compressionDictionaryManager); } catch (FileNotFoundException | NoSuchFileException e) { @@ -109,7 +128,9 @@ public static 
CompressionMetadata open(File chunksIndexFile, long compressedLeng throw new CorruptSSTableException(e, chunksIndexFile); } - return new CompressionMetadata(chunksIndexFile, parameters, chunkOffsets, chunkOffsets.size(), dataLength, compressedLength); + return new CompressionMetadata(chunksIndexFile, parameters, + chunkOffsets, chunkOffsets.size(), dataLength, + compressedLength, compressionDictionary); } // do not call this constructor directly, unless used in testing @@ -119,7 +140,8 @@ public CompressionMetadata(File chunksIndexFile, Memory chunkOffsets, long chunkOffsetsSize, long dataLength, - long compressedFileLength) + long compressedFileLength, + CompressionDictionary compressionDictionary) { super(chunkOffsets); this.chunksIndexFile = chunksIndexFile; @@ -128,6 +150,7 @@ public CompressionMetadata(File chunksIndexFile, this.compressedFileLength = compressedFileLength; this.chunkOffsets = chunkOffsets; this.chunkOffsetsSize = chunkOffsetsSize; + this.compressionDictionary = compressionDictionary; } private CompressionMetadata(CompressionMetadata copy) @@ -139,11 +162,46 @@ private CompressionMetadata(CompressionMetadata copy) this.compressedFileLength = copy.compressedFileLength; this.chunkOffsets = copy.chunkOffsets; this.chunkOffsetsSize = copy.chunkOffsetsSize; + this.compressionDictionary = copy.compressionDictionary; } public ICompressor compressor() { - return parameters.getSstableCompressor(); + // classic double-checked locking to call resolveCompressor method just once per CompressionMetadata object + ICompressor result = resolvedCompressor; + if (result != null) + return result; + + synchronized (this) + { + result = resolvedCompressor; + if (result == null) + { + result = resolveCompressor(parameters.getSstableCompressor(), compressionDictionary); + resolvedCompressor = result; + } + return result; + } + } + + static ICompressor resolveCompressor(ICompressor compressor, CompressionDictionary dictionary) + { + if (dictionary == null) + return compressor; + + // When the attached dictionary can be consumed by the current dictionary compressor + if (compressor instanceof IDictionaryCompressor) + { + IDictionaryCompressor dictionaryCompressor = (IDictionaryCompressor) compressor; + if (dictionaryCompressor.canConsumeDictionary(dictionary)) + return dictionaryCompressor.getOrCopyWithDictionary(dictionary); + } + + // When the current compressor is not compatible with the dictionary. It could happen in the read path when: + // 1. The current compressor is not a dictionary compressor, but there is dictionary attached + // 2. The current dictionary compressor is a different type, e.g. 
table schema is changed + // In those cases, we should get the compatible dictionary compressor based on the dictionary + return dictionary.kind().createCompressor(dictionary); } public int chunkLength() @@ -349,16 +407,21 @@ public static class Writer extends Transactional.AbstractTransactional implement // provided by user when setDescriptor private long dataLength, chunkCount; + @Nullable + private CompressionDictionary compressionDictionary; - private Writer(CompressionParams parameters, File file) + private Writer(CompressionParams parameters, File file, CompressionDictionary compressionDictionary) { this.parameters = parameters; this.file = file; + this.compressionDictionary = compressionDictionary; } - public static Writer open(CompressionParams parameters, File file) + public static Writer open(CompressionParams parameters, + File file, + CompressionDictionary compressionDictionary) { - return new Writer(parameters, file); + return new Writer(parameters, file, compressionDictionary); } public void addOffset(long offset) @@ -397,6 +460,21 @@ private void writeHeader(DataOutput out, long dataLength, int chunks) } } + private void writeCompressionDictionary(DataOutput out) + { + if (compressionDictionary == null) + return; + + try + { + compressionDictionary.serialize(out); + } + catch (IOException e) + { + throw new FSWriteError(e, file); + } + } + // we've written everything; wire up some final metadata state public Writer finalizeLength(long dataLength, int chunkCount) { @@ -426,6 +504,7 @@ public void doPrepare() for (int i = 0; i < count; i++) out.writeLong(offsets.getLong(i * 8L)); + writeCompressionDictionary(out); out.flush(); out.sync(); } @@ -453,7 +532,9 @@ public CompressionMetadata open(long dataLength, long compressedLength) if (tCount < this.count) compressedLength = tOffsets.getLong(tCount * 8L); - return new CompressionMetadata(file, parameters, tOffsets, tCount * 8L, dataLength, compressedLength); + return new CompressionMetadata(file, parameters, + tOffsets, tCount * 8L, dataLength, + compressedLength, compressionDictionary); } /** diff --git a/src/java/org/apache/cassandra/io/compress/ICompressor.java b/src/java/org/apache/cassandra/io/compress/ICompressor.java index fd6a104431b3..950ae03e3de4 100644 --- a/src/java/org/apache/cassandra/io/compress/ICompressor.java +++ b/src/java/org/apache/cassandra/io/compress/ICompressor.java @@ -37,6 +37,11 @@ enum Uses { FAST_COMPRESSION } + /** + * Get the maximum compressed size in the worst case scenario + * @param chunkLength input data (chunk) size + * @return compressed size upper bound in the worse case + */ public int initialCompressedBufferLength(int chunkLength); public int uncompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws IOException; diff --git a/src/java/org/apache/cassandra/io/compress/IDictionaryCompressor.java b/src/java/org/apache/cassandra/io/compress/IDictionaryCompressor.java new file mode 100644 index 000000000000..f6f7681f74ab --- /dev/null +++ b/src/java/org/apache/cassandra/io/compress/IDictionaryCompressor.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.io.compress; + +import org.apache.cassandra.db.compression.CompressionDictionary; + +/** + * Interface for compressors that support dictionary-based compression. + *
+ * Dictionary compressors can use pre-trained compression dictionaries to achieve + * better compression ratios, especially for small data chunks that are similar + * to the training data used to create the dictionary. + * + * @param the specific type of compression dictionary this compressor supports + */ +public interface IDictionaryCompressor +{ + /** + * Returns a compressor instance configured with the specified compression dictionary. + *
+ * This method may return the same instance if it already uses the given dictionary, + * or create a new instance configured with the dictionary. The implementation should + * be efficient and avoid unnecessary object creation when possible. + * + * @param compressionDictionary the dictionary to use for compression/decompression + * @return a compressor instance that will use the specified dictionary + */ + ICompressor getOrCopyWithDictionary(T compressionDictionary); + + /** + * Returns the kind of compression dictionary that this compressor can accept. + *
+ * This is used to validate dictionary compatibility before attempting to use + * a dictionary with this compressor. Only dictionaries of the returned kind + * should be passed to {@link #getOrCopyWithDictionary(CompressionDictionary)}. + * + * @return the compression dictionary kind supported by this compressor + */ + CompressionDictionary.Kind acceptableDictionaryKind(); + + /** + * Checks whether this compressor can use the given compression dictionary. + *
+ * The default implementation compares the dictionary's kind with the kind + * returned by {@link #acceptableDictionaryKind()}. Compressor implementations + * may override this method to provide more sophisticated compatibility checks. + * + * @param dictionary the compression dictionary to check for compatibility + * @return true if this compressor can use the dictionary, false otherwise + */ + default boolean canConsumeDictionary(CompressionDictionary dictionary) + { + return dictionary.kind() == acceptableDictionaryKind(); + } +} diff --git a/src/java/org/apache/cassandra/io/compress/ZstdCompressor.java b/src/java/org/apache/cassandra/io/compress/ZstdCompressor.java index c86db26c8621..9327de4c139c 100644 --- a/src/java/org/apache/cassandra/io/compress/ZstdCompressor.java +++ b/src/java/org/apache/cassandra/io/compress/ZstdCompressor.java @@ -18,72 +18,31 @@ package org.apache.cassandra.io.compress; -import java.io.IOException; -import java.nio.ByteBuffer; import java.util.Collections; -import java.util.HashSet; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableSet; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.github.luben.zstd.Zstd; - /** * ZSTD Compressor */ -public class ZstdCompressor implements ICompressor +public class ZstdCompressor extends ZstdCompressorBase implements ICompressor { - private static final Logger logger = LoggerFactory.getLogger(ZstdCompressor.class); - - // These might change with the version of Zstd we're using - public static final int FAST_COMPRESSION_LEVEL = Zstd.minCompressionLevel(); - public static final int BEST_COMPRESSION_LEVEL = Zstd.maxCompressionLevel(); - - // Compressor Defaults - public static final int DEFAULT_COMPRESSION_LEVEL = 3; - private static final boolean ENABLE_CHECKSUM_FLAG = true; - - @VisibleForTesting - public static final String COMPRESSION_LEVEL_OPTION_NAME = "compression_level"; - private static final ConcurrentHashMap instances = new ConcurrentHashMap<>(); - private final int compressionLevel; - private final Set recommendedUses; - /** * Create a Zstd compressor with the given options + * Invoked by {@link org.apache.cassandra.schema.CompressionParams#createCompressor} via reflection * - * @param options - * @return + * @param options compression options + * @return ZstdCompressor */ public static ZstdCompressor create(Map options) { int level = getOrDefaultCompressionLevel(options); - - if (!isValid(level)) - throw new IllegalArgumentException(String.format("%s=%d is invalid", COMPRESSION_LEVEL_OPTION_NAME, level)); - + validateCompressionLevel(level); return getOrCreate(level); } - /** - * Private constructor - * - * @param compressionLevel - */ - private ZstdCompressor(int compressionLevel) - { - this.compressionLevel = compressionLevel; - this.recommendedUses = ImmutableSet.of(Uses.GENERAL); - logger.trace("Creating Zstd Compressor with compression level={}", compressionLevel); - } - /** * Get a cached instance or return a new one * @@ -92,157 +51,16 @@ private ZstdCompressor(int compressionLevel) */ public static ZstdCompressor getOrCreate(int level) { - return instances.computeIfAbsent(level, l -> new ZstdCompressor(level)); - } - - /** - * Get initial compressed buffer length - * - * @param chunkLength - * @return - */ - @Override - public int initialCompressedBufferLength(int chunkLength) - { - return (int) Zstd.compressBound(chunkLength); - } - - /** - 
* Decompress data using arrays - * - * @param input - * @param inputOffset - * @param inputLength - * @param output - * @param outputOffset - * @return - * @throws IOException - */ - @Override - public int uncompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) - throws IOException - { - long dsz = Zstd.decompressByteArray(output, outputOffset, output.length - outputOffset, - input, inputOffset, inputLength); - - if (Zstd.isError(dsz)) - throw new IOException(String.format("Decompression failed due to %s", Zstd.getErrorName(dsz))); - - return (int) dsz; - } - - /** - * Decompress data via ByteBuffers - * - * @param input - * @param output - * @throws IOException - */ - @Override - public void uncompress(ByteBuffer input, ByteBuffer output) throws IOException - { - try - { - Zstd.decompress(output, input); - } catch (Exception e) - { - throw new IOException("Decompression failed", e); - } - } - - /** - * Compress using ByteBuffers - * - * @param input - * @param output - * @throws IOException - */ - @Override - public void compress(ByteBuffer input, ByteBuffer output) throws IOException - { - try - { - Zstd.compress(output, input, compressionLevel, ENABLE_CHECKSUM_FLAG); - } catch (Exception e) - { - throw new IOException("Compression failed", e); - } - } - - /** - * Check if the given compression level is valid. This can be a negative value as well. - * - * @param level - * @return - */ - private static boolean isValid(int level) - { - return (level >= FAST_COMPRESSION_LEVEL && level <= BEST_COMPRESSION_LEVEL); + return instances.computeIfAbsent(level, ZstdCompressor::new); } /** - * Parse the compression options - * - * @param options - * @return - */ - private static int getOrDefaultCompressionLevel(Map options) - { - if (options == null) - return DEFAULT_COMPRESSION_LEVEL; - - String val = options.get(COMPRESSION_LEVEL_OPTION_NAME); - - if (val == null) - return DEFAULT_COMPRESSION_LEVEL; - - return Integer.valueOf(val); - } - - /** - * Return the preferred BufferType - * - * @return - */ - @Override - public BufferType preferredBufferType() - { - return BufferType.OFF_HEAP; - } - - /** - * Check whether the given BufferType is supported - * - * @param bufferType - * @return - */ - @Override - public boolean supports(BufferType bufferType) - { - return bufferType == BufferType.OFF_HEAP; - } - - /** - * Lists the supported options by this compressor + * Private constructor * - * @return + * @param compressionLevel */ - @Override - public Set supportedOptions() - { - return new HashSet<>(Collections.singletonList(COMPRESSION_LEVEL_OPTION_NAME)); - } - - - @VisibleForTesting - public int getCompressionLevel() - { - return compressionLevel; - } - - @Override - public Set recommendedUses() + private ZstdCompressor(int compressionLevel) { - return recommendedUses; + super(compressionLevel, Collections.singleton(COMPRESSION_LEVEL_OPTION_NAME)); } } diff --git a/src/java/org/apache/cassandra/io/compress/ZstdCompressorBase.java b/src/java/org/apache/cassandra/io/compress/ZstdCompressorBase.java new file mode 100644 index 000000000000..bbe278cbf43c --- /dev/null +++ b/src/java/org/apache/cassandra/io/compress/ZstdCompressorBase.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.io.compress; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; +import java.util.Set; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.luben.zstd.Zstd; + +public abstract class ZstdCompressorBase implements ICompressor +{ + // These might change with the version of Zstd we're using + public static final int FAST_COMPRESSION_LEVEL = Zstd.minCompressionLevel(); + public static final int BEST_COMPRESSION_LEVEL = Zstd.maxCompressionLevel(); + + // Compressor Defaults + public static final int DEFAULT_COMPRESSION_LEVEL = 3; + public static final boolean ENABLE_CHECKSUM_FLAG = true; + + // Compressor option names + public static final String COMPRESSION_LEVEL_OPTION_NAME = "compression_level"; + + protected final Logger logger = LoggerFactory.getLogger(getClass()); + + private final int compressionLevel; + private final Set recommendedUses; + private final Set supportedOptions; + + protected ZstdCompressorBase(int compressionLevel, Set supportedOptions) + { + this.compressionLevel = compressionLevel; + this.supportedOptions = Collections.unmodifiableSet(supportedOptions); + this.recommendedUses = Set.of(ICompressor.Uses.GENERAL); + logger.trace("Creating Zstd Compressor with compression level={}", compressionLevel); + } + + @Override + public int initialCompressedBufferLength(int chunkLength) + { + return (int) Zstd.compressBound(chunkLength); + } + + @Override + public BufferType preferredBufferType() + { + return BufferType.OFF_HEAP; + } + + @Override + public boolean supports(BufferType bufferType) + { + return bufferType == BufferType.OFF_HEAP; + } + + @Override + public Set recommendedUses() + { + return recommendedUses; + } + + @VisibleForTesting + public int compressionLevel() + { + return compressionLevel; + } + + @Override + public Set supportedOptions() + { + return supportedOptions; + } + + /** + * Decompress data using arrays + * + * @param input + * @param inputOffset + * @param inputLength + * @param output + * @param outputOffset + * @return + * @throws IOException + */ + @Override + public int uncompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) + throws IOException + { + long dsz; + try + { + dsz = Zstd.decompressByteArray(output, outputOffset, output.length - outputOffset, + input, inputOffset, inputLength); + } + catch (Exception e) + { + throw new IOException("Decompression failed", e); + } + + if (Zstd.isError(dsz)) + throw new IOException("Decompression failed due to " + Zstd.getErrorName(dsz)); + + return (int) dsz; + } + + /** + * Decompress data via ByteBuffers + * + * @param input + * @param output + * @throws IOException + */ + @Override + public void uncompress(ByteBuffer input, ByteBuffer output) throws IOException + { + try + { + Zstd.decompress(output, input); + } 
catch (Exception e) + { + throw new IOException("Decompression failed", e); + } + } + + /** + * Compress using ByteBuffers + * + * @param input + * @param output + * @throws IOException + */ + @Override + public void compress(ByteBuffer input, ByteBuffer output) throws IOException + { + try + { + Zstd.compress(output, input, compressionLevel(), ENABLE_CHECKSUM_FLAG); + } catch (Exception e) + { + throw new IOException("Compression failed", e); + } + } + + /** + * Check if the given compression level is valid. This can be a negative value as well. + * + * @param level compression level + */ + public static void validateCompressionLevel(int level) + { + if (level < FAST_COMPRESSION_LEVEL || level > BEST_COMPRESSION_LEVEL) + { + throw new IllegalArgumentException(String.format("%s=%d is invalid", COMPRESSION_LEVEL_OPTION_NAME, level)); + } + } + + /** + * Get the supplied compression level; otherwise, use the default + * + * @param options compression options + * @return compression level + */ + public static int getOrDefaultCompressionLevel(Map options) + { + if (options == null) + return DEFAULT_COMPRESSION_LEVEL; + + String val = options.get(COMPRESSION_LEVEL_OPTION_NAME); + + if (val == null) + return DEFAULT_COMPRESSION_LEVEL; + + return Integer.parseInt(val); + } +} diff --git a/src/java/org/apache/cassandra/io/compress/ZstdDictionaryCompressor.java b/src/java/org/apache/cassandra/io/compress/ZstdDictionaryCompressor.java new file mode 100644 index 000000000000..3ba5841aaa54 --- /dev/null +++ b/src/java/org/apache/cassandra/io/compress/ZstdDictionaryCompressor.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.io.compress; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalCause; +import com.github.luben.zstd.Zstd; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.ZstdCompressionDictionary; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.utils.concurrent.Ref; + +import javax.annotation.Nullable; + +public class ZstdDictionaryCompressor extends ZstdCompressorBase implements ICompressor, IDictionaryCompressor +{ + private static final ConcurrentHashMap instancesPerLevel = new ConcurrentHashMap<>(); + private static final Cache instancePerDict = + Caffeine.newBuilder() + .maximumSize(DatabaseDescriptor.getCompressionDictionaryCacheSize()) + .expireAfterAccess(Duration.ofSeconds(DatabaseDescriptor.getCompressionDictionaryCacheExpireSeconds())) + .removalListener((ZstdCompressionDictionary dictionary, + ZstdDictionaryCompressor compressor, + RemovalCause cause) -> { + // Release dictionary reference when compressor is evicted from cache + if (compressor != null && compressor.dictionaryRef != null) + { + compressor.dictionaryRef.release(); + } + }) + .build(); + + // dictionary and its ref are null, when they are absent. + // In this case, the compressor falls back to be the same as ZstdCompressor + @Nullable + private final ZstdCompressionDictionary dictionary; + @Nullable + private final Ref dictionaryRef; + + /** + * Create a ZstdDictionaryCompressor with the given options + * Invoked by {@link org.apache.cassandra.schema.CompressionParams#createCompressor} via reflection + * + * @param options compression options + * @return ZstdDictionaryCompressor + */ + public static ZstdDictionaryCompressor create(Map options) + { + int level = getOrDefaultCompressionLevel(options); + validateCompressionLevel(level); + return getOrCreate(level, null); + } + + // Constructor used to create the compressor for reading the sstable; the compression level is not relevant + public static ZstdDictionaryCompressor create(ZstdCompressionDictionary dictionary) + { + return getOrCreate(DEFAULT_COMPRESSION_LEVEL, dictionary); + } + + private static ZstdDictionaryCompressor getOrCreate(int level, ZstdCompressionDictionary dictionary) + { + if (dictionary == null) + { + return instancesPerLevel.computeIfAbsent(level, ZstdDictionaryCompressor::new); + } + + return instancePerDict.get(dictionary, dict -> { + // Get a reference to the dictionary when creating new compressor + Ref ref = dict != null ? 
dict.tryRef() : null; + if (ref == null && dict != null) + { + // Dictionary is being closed, cannot create compressor + throw new IllegalStateException("Dictionary is being closed"); + } + return new ZstdDictionaryCompressor(level, dictionary, ref); + }); + } + + private ZstdDictionaryCompressor(int level) + { + this(level, null, null); + } + + private ZstdDictionaryCompressor(int level, ZstdCompressionDictionary dictionary, Ref dictionaryRef) + { + super(level, Set.of(COMPRESSION_LEVEL_OPTION_NAME)); + this.dictionary = dictionary; + this.dictionaryRef = dictionaryRef; + } + + @Override + public ZstdDictionaryCompressor getOrCopyWithDictionary(ZstdCompressionDictionary compressionDictionary) + { + return getOrCreate(compressionLevel(), compressionDictionary); + } + + @Override + public Kind acceptableDictionaryKind() + { + return Kind.ZSTD; + } + + @Override + public int uncompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws IOException + { + // fallback to non-dict zstd compressor + if (dictionary == null) + { + return super.uncompress(input, inputOffset, inputLength, output, outputOffset); + } + + int dsz; + try + { + dsz = (int) Zstd.decompressFastDict(output, outputOffset, + input, inputOffset, inputLength, + dictionary.dictionaryForDecompression()); + } + catch (Exception e) + { + throw new IOException("Decompression failed", e); + } + + if (Zstd.isError(dsz)) + throw new IOException("Decompression failed due to " + Zstd.getErrorName(dsz)); + + return dsz; + } + + @Override + public void uncompress(ByteBuffer input, ByteBuffer output) throws IOException + { + if (dictionary == null) + { + super.uncompress(input, output); + return; + } + + try + { + // Zstd compressors expect only direct bytebuffer. See ZstdCompressorBase.preferredBufferType and supports + int decompressedSize = (int) Zstd.decompressDirectByteBufferFastDict(output, output.position(), output.limit() - output.position(), + input, input.position(), input.limit() - input.position(), + dictionary.dictionaryForDecompression()); + output.position(output.position() + decompressedSize); + input.position(input.limit()); + } + catch (Exception e) + { + throw new IOException("Decompression failed", e); + } + } + + @Override + public void compress(ByteBuffer input, ByteBuffer output) throws IOException + { + if (dictionary == null) + { + super.compress(input, output); + return; + } + + try + { + // Zstd compressors expect only direct bytebuffer. 
See ZstdCompressorBase.preferredBufferType and supports + int compressedSize = (int) Zstd.compressDirectByteBufferFastDict(output, output.position(), output.limit() - output.position(), + input, input.position(), input.limit() - input.position(), + dictionary.dictionaryForCompression(compressionLevel())); + output.position(output.position() + compressedSize); + input.position(input.limit()); + } + catch (Exception e) + { + throw new IOException("Compression failed", e); + } + } + + @VisibleForTesting + ZstdCompressionDictionary dictionary() + { + return dictionary; + } + + @VisibleForTesting + public static void invalidateCache() + { + instancePerDict.invalidateAll(); + } +} diff --git a/src/java/org/apache/cassandra/io/sstable/SSTable.java b/src/java/org/apache/cassandra/io/sstable/SSTable.java index 14c7af6cd5c8..0ee4de1897ec 100644 --- a/src/java/org/apache/cassandra/io/sstable/SSTable.java +++ b/src/java/org/apache/cassandra/io/sstable/SSTable.java @@ -43,6 +43,7 @@ import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.sstable.format.SSTableFormat; import org.apache.cassandra.io.sstable.format.SSTableFormat.Components; import org.apache.cassandra.io.sstable.format.TOCComponent; @@ -369,6 +370,8 @@ public interface Owner OpOrder.Barrier newReadOrderingBarrier(); TableMetrics getMetrics(); + + CompressionDictionaryManager compressionDictionaryManager(); } /** diff --git a/src/java/org/apache/cassandra/io/sstable/SimpleSSTableMultiWriter.java b/src/java/org/apache/cassandra/io/sstable/SimpleSSTableMultiWriter.java index 6c7b7b6d7fe2..fa541d075c1d 100644 --- a/src/java/org/apache/cassandra/io/sstable/SimpleSSTableMultiWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/SimpleSSTableMultiWriter.java @@ -23,6 +23,7 @@ import org.apache.cassandra.db.SerializationHeader; import org.apache.cassandra.db.commitlog.CommitLogPosition; import org.apache.cassandra.db.commitlog.IntervalSet; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.db.lifecycle.ILifecycleTransaction; import org.apache.cassandra.db.rows.UnfilteredRowIterator; import org.apache.cassandra.index.Index; @@ -122,17 +123,21 @@ public static SSTableMultiWriter create(Descriptor descriptor, MetadataCollector metadataCollector = new MetadataCollector(metadata.get().comparator) .commitLogIntervals(commitLogPositions != null ? commitLogPositions : IntervalSet.empty()) .sstableLevel(sstableLevel); - SSTableWriter writer = descriptor.getFormat().getWriterFactory().builder(descriptor) - .setKeyCount(keyCount) - .setRepairedAt(repairedAt) - .setPendingRepair(pendingRepair) - .setTransientSSTable(isTransient) - .setTableMetadataRef(metadata) - .setMetadataCollector(metadataCollector) - .setSerializationHeader(header) - .addDefaultComponents(indexGroups) - .setSecondaryIndexGroups(indexGroups) - .build(txn, owner); + CompressionDictionaryManager compressionDictionaryManager = owner == null ? 
null : owner.compressionDictionaryManager(); + SSTableWriter writer = descriptor.getFormat() + .getWriterFactory() + .builder(descriptor) + .setKeyCount(keyCount) + .setRepairedAt(repairedAt) + .setPendingRepair(pendingRepair) + .setTransientSSTable(isTransient) + .setTableMetadataRef(metadata) + .setMetadataCollector(metadataCollector) + .setSerializationHeader(header) + .addDefaultComponents(indexGroups) + .setSecondaryIndexGroups(indexGroups) + .setCompressionDictionaryManager(compressionDictionaryManager) + .build(txn, owner); return new SimpleSSTableMultiWriter(writer, txn); } } diff --git a/src/java/org/apache/cassandra/io/sstable/format/CompressionInfoComponent.java b/src/java/org/apache/cassandra/io/sstable/format/CompressionInfoComponent.java index 0e24fa991d72..abb0658c9d33 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/CompressionInfoComponent.java +++ b/src/java/org/apache/cassandra/io/sstable/format/CompressionInfoComponent.java @@ -22,6 +22,9 @@ import java.nio.file.NoSuchFileException; import java.util.Set; +import javax.annotation.Nullable; + +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.FSReadError; import org.apache.cassandra.io.compress.CompressionMetadata; import org.apache.cassandra.io.sstable.Component; @@ -32,27 +35,31 @@ public class CompressionInfoComponent { - public static CompressionMetadata maybeLoad(Descriptor descriptor, Set components) + public static CompressionMetadata maybeLoad(Descriptor descriptor, Set components, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { if (components.contains(Components.COMPRESSION_INFO)) - return load(descriptor); + return load(descriptor, compressionDictionaryManager); return null; } - public static CompressionMetadata loadIfExists(Descriptor descriptor) + public static CompressionMetadata loadIfExists(Descriptor descriptor, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { if (descriptor.fileFor(Components.COMPRESSION_INFO).exists()) - return load(descriptor); + return load(descriptor, compressionDictionaryManager); return null; } - public static CompressionMetadata load(Descriptor descriptor) + public static CompressionMetadata load(Descriptor descriptor, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { return CompressionMetadata.open(descriptor.fileFor(Components.COMPRESSION_INFO), descriptor.fileFor(Components.DATA).length(), - descriptor.version.hasMaxCompressedLength()); + descriptor.version.hasMaxCompressedLength(), + compressionDictionaryManager); } /** diff --git a/src/java/org/apache/cassandra/io/sstable/format/DataComponent.java b/src/java/org/apache/cassandra/io/sstable/format/DataComponent.java index 9367cb444d80..69528628d348 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/DataComponent.java +++ b/src/java/org/apache/cassandra/io/sstable/format/DataComponent.java @@ -20,6 +20,7 @@ import org.apache.cassandra.config.Config.FlushCompression; import org.apache.cassandra.db.compaction.OperationType; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.compress.CompressedSequentialWriter; import org.apache.cassandra.io.compress.ICompressor; import org.apache.cassandra.io.sstable.Descriptor; @@ -38,7 +39,8 @@ public static SequentialWriter buildWriter(Descriptor descriptor, SequentialWriterOption options, MetadataCollector metadataCollector, OperationType operationType, - FlushCompression flushCompression) + 
FlushCompression flushCompression, + CompressionDictionaryManager compressionDictionaryManager) { if (metadata.params.compression.isEnabled()) { @@ -49,7 +51,8 @@ public static SequentialWriter buildWriter(Descriptor descriptor, descriptor.fileFor(Components.DIGEST), options, compressionParams, - metadataCollector); + metadataCollector, + compressionDictionaryManager); } else { diff --git a/src/java/org/apache/cassandra/io/sstable/format/SSTableWriter.java b/src/java/org/apache/cassandra/io/sstable/format/SSTableWriter.java index 073936f62961..c09459b37d40 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/SSTableWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/format/SSTableWriter.java @@ -28,6 +28,7 @@ import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; +import javax.annotation.Nullable; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -37,6 +38,7 @@ import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.SerializationHeader; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.db.lifecycle.ILifecycleTransaction; import org.apache.cassandra.db.rows.UnfilteredRowIterator; import org.apache.cassandra.dht.AbstractBounds; @@ -443,6 +445,8 @@ public abstract static class Builder indexGroups; + @Nullable + private CompressionDictionaryManager compressionDictionaryManager; public B setMetadataCollector(MetadataCollector metadataCollector) { @@ -521,6 +525,18 @@ public B setSecondaryIndexGroups(Collection indexGroups) return (B) this; } + public B setCompressionDictionaryManager(CompressionDictionaryManager compressionDictionaryManager) + { + this.compressionDictionaryManager = compressionDictionaryManager; + return (B) this; + } + + @Nullable + public CompressionDictionaryManager getCompressionDictionaryManager() + { + return compressionDictionaryManager; + } + public MetadataCollector getMetadataCollector() { return metadataCollector; diff --git a/src/java/org/apache/cassandra/io/sstable/format/big/BigFormat.java b/src/java/org/apache/cassandra/io/sstable/format/big/BigFormat.java index e987506c40f8..e6b60c2a0656 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/big/BigFormat.java +++ b/src/java/org/apache/cassandra/io/sstable/format/big/BigFormat.java @@ -433,7 +433,8 @@ public BigTableWriter.Builder builder(Descriptor descriptor) static class BigVersion extends Version { - public static final String current_version = DatabaseDescriptor.getStorageCompatibilityMode().isBefore(5) ? "nb" : "oa"; + public static final String current_version = DatabaseDescriptor.getStorageCompatibilityMode().isBefore(5) ? "nb" : + DatabaseDescriptor.getStorageCompatibilityMode().isBefore(6) ? 
"oa" : "pa"; public static final String earliest_supported_version = "ma"; // ma (3.0.0): swap bf hash order @@ -448,6 +449,7 @@ static class BigVersion extends Version // oa (5.0): improved min/max, partition level deletion presence marker, key range (CASSANDRA-18134) // Long deletionTime to prevent TTL overflow // token space coverage + // pa (6.0): compression dictionary metadata in CompressionInfo component // // NOTE: When adding a new version: // - Please add it to LegacySSTableTest diff --git a/src/java/org/apache/cassandra/io/sstable/format/big/BigSSTableReaderLoadingBuilder.java b/src/java/org/apache/cassandra/io/sstable/format/big/BigSSTableReaderLoadingBuilder.java index 84e02217d565..3557b0d80227 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/big/BigSSTableReaderLoadingBuilder.java +++ b/src/java/org/apache/cassandra/io/sstable/format/big/BigSSTableReaderLoadingBuilder.java @@ -26,6 +26,7 @@ import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.SerializationHeader; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.compress.CompressionMetadata; import org.apache.cassandra.io.sstable.Downsampling; import org.apache.cassandra.io.sstable.KeyReader; @@ -137,7 +138,8 @@ protected void openComponents(BigTableReader.Builder builder, SSTable.Owner owne } } - try (CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components)) + CompressionDictionaryManager compressionDictionaryManager = owner == null ? null : owner.compressionDictionaryManager(); + try (CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components, compressionDictionaryManager)) { builder.setDataFile(dataFileBuilder(builder.getStatsMetadata()) .withCompressionMetadata(compressionMetadata) diff --git a/src/java/org/apache/cassandra/io/sstable/format/big/BigTableWriter.java b/src/java/org/apache/cassandra/io/sstable/format/big/BigTableWriter.java index 3233ca4c0633..6cf01e7fbcb6 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/big/BigTableWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/format/big/BigTableWriter.java @@ -389,7 +389,8 @@ protected SequentialWriter openDataWriter() getIOOptions().writerOptions, getMetadataCollector(), ensuringInBuildInternalContext(operationType), - getIOOptions().flushCompression); + getIOOptions().flushCompression, + getCompressionDictionaryManager()); this.dataWriterOpened = true; return dataWriter; } diff --git a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiFormat.java b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiFormat.java index e7703c6a0612..8c6dbf1f1f62 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiFormat.java +++ b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiFormat.java @@ -286,11 +286,12 @@ public long estimateSize(SSTableWriter.SSTableSizeParameters parameters) static class BtiVersion extends Version { - public static final String current_version = "da"; + public static final String current_version = "ea"; public static final String earliest_supported_version = "da"; // versions aa-cz are not supported in OSS - // da (5.0): initial version of the BIT format + // da (5.0): initial version of the BTI format + // ea (6.0): compression dictionary metadata in CompressionInfo component // NOTE: when adding a new version, please add that to LegacySSTableTest, too. 
private final boolean isLatestVersion; diff --git a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableReaderLoadingBuilder.java b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableReaderLoadingBuilder.java index fa408adc5d0e..8f47e89a9a96 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableReaderLoadingBuilder.java +++ b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableReaderLoadingBuilder.java @@ -23,7 +23,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.io.compress.CompressionMetadata; import org.apache.cassandra.io.sstable.KeyReader; @@ -38,6 +40,7 @@ import org.apache.cassandra.io.sstable.metadata.ValidationMetadata; import org.apache.cassandra.io.util.FileHandle; import org.apache.cassandra.metrics.TableMetrics; +import org.apache.cassandra.schema.Schema; import org.apache.cassandra.utils.FilterFactory; import org.apache.cassandra.utils.IFilter; import org.apache.cassandra.utils.Throwables; @@ -68,8 +71,15 @@ private KeyReader createKeyReader(StatsMetadata statsMetadata) throws IOExceptio { checkNotNull(statsMetadata); + ColumnFamilyStore cfs = Schema.instance.getColumnFamilyStoreInstance(tableMetadataRef.id); + CompressionDictionaryManager compressionDictionaryManager = null; + if (cfs != null) + { + compressionDictionaryManager = cfs.compressionDictionaryManager(); + } + try (PartitionIndex index = PartitionIndex.load(partitionIndexFileBuilder(), tableMetadataRef.getLocal().partitioner, false); - CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components); + CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components, compressionDictionaryManager); FileHandle dFile = dataFileBuilder(statsMetadata).withCompressionMetadata(compressionMetadata) .withCrcCheckChance(() -> tableMetadataRef.getLocal().params.crcCheckChance) .complete(); @@ -131,7 +141,7 @@ protected void openComponents(BtiTableReader.Builder builder, SSTable.Owner owne } } - try (CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components)) + try (CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components, owner == null ? 
null : owner.compressionDictionaryManager())) { builder.setDataFile(dataFileBuilder(builder.getStatsMetadata()) .withCompressionMetadata(compressionMetadata) diff --git a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableWriter.java b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableWriter.java index 074c5c17085c..0179497f0988 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableWriter.java @@ -334,7 +334,8 @@ protected SequentialWriter openDataWriter() getIOOptions().writerOptions, getMetadataCollector(), ensuringInBuildInternalContext(operationType), - getIOOptions().flushCompression); + getIOOptions().flushCompression, + getCompressionDictionaryManager()); } @Override diff --git a/src/java/org/apache/cassandra/net/MessagingService.java b/src/java/org/apache/cassandra/net/MessagingService.java index a636005beef5..129d5114dd2b 100644 --- a/src/java/org/apache/cassandra/net/MessagingService.java +++ b/src/java/org/apache/cassandra/net/MessagingService.java @@ -478,21 +478,28 @@ public void respond(V response, Message message) public Future sendWithResponse(InetAddressAndPort to, Message msg) { Promise future = AsyncPromise.uncancellable(); - MessagingService.instance().sendWithCallback(msg, to, - new RequestCallback() - { - @Override - public void onResponse(Message msg) - { - future.setSuccess(msg.payload); - } - - @Override - public void onFailure(InetAddressAndPort from, RequestFailure failure) - { - future.setFailure(new RuntimeException(failure.toString())); - } - }); + RequestCallback callback = new RequestCallback() + { + @Override + public void onResponse(Message msg) + { + future.setSuccess(msg.payload); + } + + @Override + public void onFailure(InetAddressAndPort from, RequestFailure failure) + { + future.setFailure(new RuntimeException(failure.toString())); + } + }; + try + { + MessagingService.instance().sendWithCallback(msg, to, callback); + } + catch (Throwable e) // catch any exception thrown while sending the message and wrap it in the future for unified exception handling + { + future.setFailure(e); + } return future; } diff --git a/src/java/org/apache/cassandra/net/Verb.java b/src/java/org/apache/cassandra/net/Verb.java index d24c9e64adff..14b78405a51d 100644 --- a/src/java/org/apache/cassandra/net/Verb.java +++ b/src/java/org/apache/cassandra/net/Verb.java @@ -44,6 +44,8 @@ import org.apache.cassandra.db.TruncateRequest; import org.apache.cassandra.db.TruncateResponse; import org.apache.cassandra.db.TruncateVerbHandler; +import org.apache.cassandra.db.compression.CompressionDictionaryUpdateMessage; +import org.apache.cassandra.db.compression.CompressionDictionaryUpdateVerbHandler; import org.apache.cassandra.exceptions.RequestFailure; import org.apache.cassandra.gms.GossipDigestAck; import org.apache.cassandra.gms.GossipDigestAck2; @@ -372,6 +374,9 @@ public enum Verb ACCORD_FETCH_TOPOLOGY_RSP (169, P0, shortTimeout, FETCH_METADATA, () -> accordEmbedded(FetchTopologies.responseSerializer), RESPONSE_HANDLER), ACCORD_FETCH_TOPOLOGY_REQ (170, P0, shortTimeout, FETCH_METADATA, () -> accordEmbedded(FetchTopologies.serializer), () -> FetchTopologies.handler, ACCORD_FETCH_TOPOLOGY_RSP), + DICTIONARY_UPDATE_RSP (171, P1, rpcTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, RESPONSE_HANDLER ), + DICTIONARY_UPDATE_REQ (172, P1, rpcTimeout, MISC, () -> CompressionDictionaryUpdateMessage.serializer, () -> CompressionDictionaryUpdateVerbHandler.instance,
DICTIONARY_UPDATE_RSP ), + // generic failure response FAILURE_RSP (99, P0, noTimeout, REQUEST_RESPONSE, () -> RequestFailure.serializer, RESPONSE_HANDLER ), @@ -679,4 +684,4 @@ class VerbTimeouts class ResponseHandlerSupplier { static final Supplier> RESPONSE_HANDLER = () -> ResponseVerbHandler.instance; -} \ No newline at end of file +} diff --git a/src/java/org/apache/cassandra/schema/CompressionParams.java b/src/java/org/apache/cassandra/schema/CompressionParams.java index 0e7c3da13ab0..fdf184e94d96 100644 --- a/src/java/org/apache/cassandra/schema/CompressionParams.java +++ b/src/java/org/apache/cassandra/schema/CompressionParams.java @@ -160,15 +160,31 @@ public static CompressionParams lz4(int chunkLength, int maxCompressedLength) return new CompressionParams(LZ4Compressor.create(Collections.emptyMap()), chunkLength, maxCompressedLength, calcMinCompressRatio(chunkLength, maxCompressedLength), Collections.emptyMap()); } + @VisibleForTesting public static CompressionParams zstd() { - return zstd(DEFAULT_CHUNK_LENGTH); + return zstd(DEFAULT_CHUNK_LENGTH, false); } + @VisibleForTesting public static CompressionParams zstd(Integer chunkLength) { - ZstdCompressor compressor = ZstdCompressor.create(Collections.emptyMap()); - return new CompressionParams(compressor, chunkLength, Integer.MAX_VALUE, DEFAULT_MIN_COMPRESS_RATIO, Collections.emptyMap()); + return zstd(chunkLength, false); + } + + @VisibleForTesting + public static CompressionParams zstd(Integer chunkLength, boolean useDictionary) + { + return zstd(chunkLength, useDictionary, Collections.emptyMap()); + } + + @VisibleForTesting + public static CompressionParams zstd(Integer chunkLength, boolean useDictionary, Map options) + { + ICompressor compressor = useDictionary + ? ZstdDictionaryCompressor.create(options) + : ZstdCompressor.create(options); + return new CompressionParams(compressor, chunkLength, Integer.MAX_VALUE, DEFAULT_MIN_COMPRESS_RATIO, options); } @VisibleForTesting @@ -223,6 +239,18 @@ public boolean isEnabled() return sstableCompressor != null; } + /** + * Checks if dictionary compression is enabled for this configuration. + * Dictionary compression is enabled when both compression is enabled and + * the compressor supports dictionary-based compression. + * + * @return {@code true} if dictionary compression is enabled, {@code false} otherwise. + */ + public boolean isDictionaryCompressionEnabled() + { + return isEnabled() && sstableCompressor instanceof IDictionaryCompressor; + } + /** * Returns the SSTable compressor. * @return the SSTable compressor or {@code null} if compression is disabled. 
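To make the intended usage of the new CompressionParams and nodetool plumbing concrete, here is a minimal operator-facing sketch. It is illustrative only: ks.events is a hypothetical table, and the compression_level option shown simply mirrors the option already passed to ZstdDictionaryCompressor.create() elsewhere in this patch.

    ALTER TABLE ks.events
      WITH compression = {'class': 'ZstdDictionaryCompressor', 'compression_level': '3'};

    nodetool traincompressiondictionary ks events

Once training completes, SSTables written from that point on are compressed against the latest dictionary, and the dictionary metadata is recorded in the CompressionInfo component (hence the pa/ea SSTable version bumps above), so older SSTables remain readable with the dictionary version they reference.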
diff --git a/src/java/org/apache/cassandra/schema/SystemDistributedKeyspace.java b/src/java/org/apache/cassandra/schema/SystemDistributedKeyspace.java index d50621a3a15c..f67715ee5e0f 100644 --- a/src/java/org/apache/cassandra/schema/SystemDistributedKeyspace.java +++ b/src/java/org/apache/cassandra/schema/SystemDistributedKeyspace.java @@ -30,6 +30,8 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; + import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableMap; @@ -49,6 +51,7 @@ import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.db.compression.CompressionDictionary; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.repair.CommonRange; import org.apache.cassandra.repair.messages.RepairOption; @@ -56,7 +59,6 @@ import org.apache.cassandra.utils.TimeUUID; import static java.lang.String.format; - import static org.apache.cassandra.utils.ByteBufferUtil.bytes; public final class SystemDistributedKeyspace @@ -83,10 +85,11 @@ private SystemDistributedKeyspace() * gen 5: add ttl and TWCS to repair_history tables * gen 6: add denylist table * gen 7: add auto_repair_history and auto_repair_priority tables for AutoRepair feature + * gen 8: add compression_dictionaries for dictionary-based compression algorithms (e.g. zstd) * * // TODO: TCM - how do we evolve these tables? */ - public static final long GENERATION = 7; + public static final long GENERATION = 8; public static final String REPAIR_HISTORY = "repair_history"; @@ -100,7 +103,12 @@ private SystemDistributedKeyspace() public static final String AUTO_REPAIR_PRIORITY = "auto_repair_priority"; - public static final Set TABLE_NAMES = ImmutableSet.of(REPAIR_HISTORY, PARENT_REPAIR_HISTORY, VIEW_BUILD_STATUS, PARTITION_DENYLIST_TABLE, AUTO_REPAIR_HISTORY, AUTO_REPAIR_PRIORITY); + public static final String COMPRESSION_DICTIONARIES = "compression_dictionaries"; + + public static final Set TABLE_NAMES = ImmutableSet.of(REPAIR_HISTORY, PARENT_REPAIR_HISTORY, + VIEW_BUILD_STATUS, PARTITION_DENYLIST_TABLE, + AUTO_REPAIR_HISTORY, AUTO_REPAIR_PRIORITY, + COMPRESSION_DICTIONARIES); public static final String REPAIR_HISTORY_CQL = "CREATE TABLE IF NOT EXISTS %s (" + "keyspace_name text," @@ -185,6 +193,18 @@ private SystemDistributedKeyspace() private static final TableMetadata AutoRepairPriorityTable = parse(AUTO_REPAIR_PRIORITY, "Auto repair priority for each group", AUTO_REPAIR_PRIORITY_CQL).build(); + public static final String COMPRESSION_DICTIONARIES_CQL = "CREATE TABLE IF NOT EXISTS %s (" + + "keyspace_name text," + + "table_name text," + + "kind text," + + "dict_id bigint," + + "dict blob," + + "PRIMARY KEY ((keyspace_name, table_name), dict_id)) " + + "WITH CLUSTERING ORDER BY (dict_id DESC)"; // in order to retrieve the latest dictionary; the contract is the newer the dictionary the larger the dict_id + + private static final TableMetadata CompressionDictionariesTable = + parse(COMPRESSION_DICTIONARIES, "Compression dictionaries for applicable tables", COMPRESSION_DICTIONARIES_CQL).build(); + private static TableMetadata.Builder parse(String table, String description, String cql) { return CreateTableStatement.parse(format(cql, table), SchemaConstants.DISTRIBUTED_KEYSPACE_NAME) @@ -197,7 +217,10 @@ public static KeyspaceMetadata metadata() { return 
KeyspaceMetadata.create(SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, KeyspaceParams.simple(Math.max(DEFAULT_RF, DatabaseDescriptor.getDefaultKeyspaceRF())), - Tables.of(RepairHistory, ParentRepairHistory, ViewBuildStatus, PartitionDenylistTable, AutoRepairHistoryTable, AutoRepairPriorityTable)); + Tables.of(RepairHistory, ParentRepairHistory, + ViewBuildStatus, PartitionDenylistTable, + AutoRepairHistoryTable, AutoRepairPriorityTable, + CompressionDictionariesTable)); } public static void startParentRepair(TimeUUID parent_id, String keyspaceName, String[] cfnames, RepairOption options) @@ -382,20 +405,97 @@ public static void setViewRemoved(String keyspaceName, String viewName) forceBlockingFlush(VIEW_BUILD_STATUS, ColumnFamilyStore.FlushReason.INTERNALLY_FORCED); } - private static void processSilent(String fmtQry, String... values) + /** + * Stores a compression dictionary for a given keyspace and table in the distributed system keyspace. + * + * @param keyspaceName the keyspace name to associate with the dictionary + * @param tableName the table name to associate with the dictionary + * @param dictionary the compression dictionary to store + */ + public static void storeCompressionDictionary(String keyspaceName, String tableName, CompressionDictionary dictionary) + { + String query = "INSERT INTO %s.%s (keyspace_name, table_name, kind, dict_id, dict) VALUES ('%s', '%s', '%s', %s, ?)"; + String fmtQuery = format(query, + SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, + COMPRESSION_DICTIONARIES, + keyspaceName, + tableName, + dictionary.kind(), + dictionary.dictId().id); + noThrow(fmtQuery, + () -> QueryProcessor.process(fmtQuery, ConsistencyLevel.ONE, + Collections.singletonList(ByteBuffer.wrap(dictionary.rawDictionary())))); + } + + /** + * Retrieves the latest compression dictionary for a given keyspace and table. + * + * @param keyspaceName the keyspace name to retrieve the dictionary for + * @param tableName the table name to retrieve the dictionary for + * @return the latest compression dictionary for the specified keyspace and table, + * or null if no dictionary exists or if an error occurs during retrieval + */ + @Nullable + public static CompressionDictionary retrieveLatestCompressionDictionary(String keyspaceName, String tableName) + { + String query = "SELECT kind, dict_id, dict FROM %s.%s WHERE keyspace_name='%s' AND table_name='%s' LIMIT 1"; + String fmtQuery = format(query, SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, COMPRESSION_DICTIONARIES, keyspaceName, tableName); + try + { + UntypedResultSet.Row row = QueryProcessor.execute(fmtQuery, ConsistencyLevel.ONE).one(); + return CompressionDictionary.createFromRow(row); + } + catch (Exception e) + { + return null; + } + } + + /** + * Retrieves a specific compression dictionary for a given keyspace and table. 
+ * + * @param keyspaceName the keyspace name to retrieve the dictionary for + * @param tableName the table name to retrieve the dictionary for + * @param dictionaryId the dictionary id to retrieve the dictionary for + * @return the compression dictionary identified by the specified keyspace, table and dictionaryId, + * or null if no dictionary exists or if an error occurs during retrieval + */ + public static CompressionDictionary retrieveCompressionDictionary(String keyspaceName, String tableName, CompressionDictionary.DictId dictionaryId) { + String query = "SELECT kind, dict_id, dict FROM %s.%s WHERE keyspace_name='%s' AND table_name='%s' AND dict_id=%s"; + String fmtQuery = format(query, SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, COMPRESSION_DICTIONARIES, keyspaceName, tableName, dictionaryId.id); try { + UntypedResultSet.Row row = QueryProcessor.execute(fmtQuery, ConsistencyLevel.ONE).one(); + return CompressionDictionary.createFromRow(row); + } + catch (Exception e) + { + return null; + } + } + + private static void processSilent(String fmtQry, String... values) + { + noThrow(fmtQry, () -> { List valueList = new ArrayList<>(values.length); for (String v : values) { valueList.add(bytes(v)); } QueryProcessor.process(fmtQry, ConsistencyLevel.ANY, valueList); + }); + } + + private static void noThrow(String fmtQry, Runnable queryExec) + { + try + { + queryExec.run(); } catch (Throwable t) { - logger.error("Error executing query "+fmtQry, t); + logger.error("Error executing query " + fmtQry, t); } } diff --git a/src/java/org/apache/cassandra/tools/NodeProbe.java b/src/java/org/apache/cassandra/tools/NodeProbe.java index b59655134aaa..67680391d076 100644 --- a/src/java/org/apache/cassandra/tools/NodeProbe.java +++ b/src/java/org/apache/cassandra/tools/NodeProbe.java @@ -43,6 +43,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; +import javax.management.InstanceNotFoundException; import javax.management.JMX; import javax.management.MBeanServerConnection; import javax.management.MalformedObjectNameException; @@ -88,10 +89,12 @@ import org.apache.cassandra.db.ColumnFamilyStoreMBean; import org.apache.cassandra.db.compaction.CompactionManager; import org.apache.cassandra.db.compaction.CompactionManagerMBean; -import org.apache.cassandra.db.virtual.CIDRFilteringMetricsTable; -import org.apache.cassandra.db.virtual.CIDRFilteringMetricsTableMBean; +import org.apache.cassandra.db.compression.CompressionDictionaryManagerMBean; +import org.apache.cassandra.db.compression.TrainingState; import org.apache.cassandra.db.guardrails.Guardrails; import org.apache.cassandra.db.guardrails.GuardrailsMBean; +import org.apache.cassandra.db.virtual.CIDRFilteringMetricsTable; +import org.apache.cassandra.db.virtual.CIDRFilteringMetricsTableMBean; import org.apache.cassandra.fql.FullQueryLoggerOptions; import org.apache.cassandra.fql.FullQueryLoggerOptionsCompositeData; import org.apache.cassandra.gms.FailureDetector; @@ -2682,6 +2685,67 @@ public void setMixedMajorVersionRepairEnabled(boolean enabled) { autoRepairProxy.setMixedMajorVersionRepairEnabled(enabled); } + + /** + * Triggers compression dictionary training for the specified table. + * Samples chunks from existing SSTables and trains a dictionary. 
+ * + * @param keyspace the keyspace name + * @param table the table name + * @throws IOException if there's an error accessing the MBean + * @throws IllegalArgumentException if table doesn't support dictionary compression + */ + public void trainCompressionDictionary(String keyspace, String table) throws IOException + { + CompressionDictionaryManagerMBean proxy = getDictionaryManagerProxy(keyspace, table); + try + { + proxy.train(); + } + catch (Exception e) + { + if (e.getCause() instanceof InstanceNotFoundException) + { + String message = String.format("Table %s.%s does not exist or does not support dictionary compression", + keyspace, table); + throw new IOException(message); + } + else + { + throw new IOException(e.getMessage()); + } + } + } + + /** + * Gets the compression dictionary training state for the specified table. + * Returns an atomic snapshot of training status, progress, and failure details. + * + * @param keyspace the keyspace name + * @param table the table name + * @return the current training state + * @throws IOException if there's an error accessing the MBean + */ + public TrainingState getCompressionDictionaryTrainingState(String keyspace, String table) throws IOException + { + CompositeData compositeData = getDictionaryManagerProxy(keyspace, table).getTrainingState(); + return TrainingState.fromCompositeData(compositeData); + } + + private CompressionDictionaryManagerMBean getDictionaryManagerProxy(String keyspace, String table) throws IOException + { + // Construct table-specific MBean name + String mbeanName = CompressionDictionaryManagerMBean.MBEAN_NAME + ",keyspace=" + keyspace + ",table=" + table; + try + { + ObjectName objectName = new ObjectName(mbeanName); + return JMX.newMBeanProxy(mbeanServerConn, objectName, CompressionDictionaryManagerMBean.class); + } + catch (MalformedObjectNameException e) + { + throw new IOException("Invalid keyspace or table name", e); + } + } } class ColumnFamilyStoreMBeanIterator implements Iterator> diff --git a/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java b/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java index 256c80d26903..11df06158d7d 100644 --- a/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java +++ b/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java @@ -322,7 +322,7 @@ private void printSStableMetadata(File file, boolean scan) throws IOException CompactionMetadata compaction = statsComponent.compactionMetadata(); SerializationHeader.Component header = statsComponent.serializationHeader(); Class compressorClass = null; - try (CompressionMetadata compression = CompressionInfoComponent.loadIfExists(descriptor)) + try (CompressionMetadata compression = CompressionInfoComponent.loadIfExists(descriptor, null)) { compressorClass = compression != null ? 
compression.compressor().getClass() : null; } diff --git a/src/java/org/apache/cassandra/tools/nodetool/NodetoolCommand.java b/src/java/org/apache/cassandra/tools/nodetool/NodetoolCommand.java index c20c6936fc4d..3415c52311a7 100644 --- a/src/java/org/apache/cassandra/tools/nodetool/NodetoolCommand.java +++ b/src/java/org/apache/cassandra/tools/nodetool/NodetoolCommand.java @@ -207,6 +207,7 @@ TableStats.class, TopPartitions.class, TpStats.class, + TrainCompressionDictionary.class, TruncateHints.class, UpdateCIDRGroup.class, UpgradeSSTable.class, diff --git a/src/java/org/apache/cassandra/tools/nodetool/TrainCompressionDictionary.java b/src/java/org/apache/cassandra/tools/nodetool/TrainCompressionDictionary.java new file mode 100644 index 000000000000..e8d358677a85 --- /dev/null +++ b/src/java/org/apache/cassandra/tools/nodetool/TrainCompressionDictionary.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.tools.nodetool; + +import java.io.PrintStream; +import java.util.concurrent.TimeUnit; + +import com.google.common.util.concurrent.Uninterruptibles; + +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.db.compression.TrainingState; +import org.apache.cassandra.tools.NodeProbe; +import org.apache.cassandra.utils.Clock; +import picocli.CommandLine.Command; +import picocli.CommandLine.Parameters; + +@Command(name = "traincompressiondictionary", +description = "Manually trigger compression dictionary training for a table. 
If no SSTables are available, the memtable will be flushed first.") +public class TrainCompressionDictionary extends AbstractCommand +{ + @Parameters(index = "0", description = "The keyspace name", arity = "1") + private String keyspace; + + @Parameters(index = "1", description = "The table name", arity = "1") + private String table; + + @Override + public void execute(NodeProbe probe) + { + PrintStream out = probe.output().out; + PrintStream err = probe.output().err; + + try + { + out.printf("Starting compression dictionary training for %s.%s...%n", keyspace, table); + out.printf("Training from existing SSTables (flushing first if needed)%n"); + + probe.trainCompressionDictionary(keyspace, table); + + // Wait for training completion (10 minutes timeout for SSTable-based training) + out.println("Sampling from existing SSTables and training."); + long maxWaitMillis = TimeUnit.MINUTES.toMillis(10); + long startTime = Clock.Global.currentTimeMillis(); + + while (Clock.Global.currentTimeMillis() - startTime < maxWaitMillis) + { + TrainingState trainingState = probe.getCompressionDictionaryTrainingState(keyspace, table); + TrainingStatus status = trainingState.getStatus(); + displayProgress(trainingState, startTime, out, status); + if (TrainingStatus.COMPLETED == status) + { + out.printf("%nTraining completed successfully for %s.%s%n", keyspace, table); + return; + } + else if (TrainingStatus.FAILED == status) + { + err.printf("%nTraining failed for %s.%s%n", keyspace, table); + try + { + String failureMessage = trainingState.getFailureMessage(); + if (failureMessage != null && !failureMessage.isEmpty()) + { + err.printf("Reason: %s%n", failureMessage); + } + } + catch (Exception e) + { + // If we can't get the failure message, just continue without it + } + System.exit(1); + } + + Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS); + } + + err.printf("%nTraining did not complete within expected timeframe (10 minutes).%n"); + System.exit(1); + } + catch (Exception e) + { + err.printf("Failed to trigger training: %s%n", e.getMessage()); + System.exit(1); + } + } + + private static void displayProgress(TrainingState trainingState, long startTime, PrintStream out, TrainingStatus status) + { + // Display meaningful statistics + long sampleCount = trainingState.getSampleCount(); + long totalSampleSize = trainingState.getTotalSampleSize(); + long elapsedSeconds = (Clock.Global.currentTimeMillis() - startTime) / 1000; + double sampleSizeMB = totalSampleSize / (1024.0 * 1024.0); + + out.printf("\rStatus: %s | Samples: %d | Size: %.2f MiB | Elapsed: %ds", + status, sampleCount, sampleSizeMB, elapsedSeconds); + } +} diff --git a/src/java/org/apache/cassandra/utils/StorageCompatibilityMode.java b/src/java/org/apache/cassandra/utils/StorageCompatibilityMode.java index 2969597c2381..ac0e7121bfdd 100644 --- a/src/java/org/apache/cassandra/utils/StorageCompatibilityMode.java +++ b/src/java/org/apache/cassandra/utils/StorageCompatibilityMode.java @@ -35,6 +35,13 @@ public enum StorageCompatibilityMode */ CASSANDRA_4(4), + /** + * Similar to CASSANDRA_4. + * The new features in 6.0 are + * - ZSTD dictionary-based compression. Once SSTables are compressed with dictionary, they cannot be rolled back. + */ + CASSANDRA_5(5), + /** * Use the storage formats of the current version, but disabling features that are not compatible with any * not-upgraded nodes in the cluster. Use this during rolling upgrades to a new major Cassandra version. 
Once all diff --git a/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorBenchBase.java b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorBenchBase.java new file mode 100644 index 000000000000..12036d6d77dc --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorBenchBase.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Random; +import java.util.UUID; + +import com.github.luben.zstd.ZstdDictTrainer; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.db.compression.ZstdCompressionDictionary; +import org.apache.cassandra.io.compress.ZstdDictionaryCompressor; + +// The bench takes over 20 minutes to finish +@State(Scope.Benchmark) +public abstract class ZstdDictionaryCompressorBenchBase +{ + @Param({"4096", "16384", "65536"}) + protected int dataSize; + + @Param({"CASSANDRA_LIKE", "COMPRESSIBLE", "MIXED"}) + protected DataType dataType; + + @Param({"0", "65536"}) + protected int dictionarySize; + + @Param({"3", "5", "7"}) + protected int compressionLevel; + + protected byte[] inputData; + protected ByteBuffer inputBuffer; + protected ByteBuffer compressedBuffer; + protected ByteBuffer decompressedBuffer; + protected ZstdDictionaryCompressor compressor; + protected ZstdDictionaryCompressor noDictCompressor; + protected ZstdCompressionDictionary dictionary; + + public enum DataType + { + CASSANDRA_LIKE, COMPRESSIBLE, MIXED + } + + @Setup(Level.Trial) + public void setupTrial() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Setup(Level.Iteration) + public void setupIteration() throws IOException + { + Random random = new Random(42); + + // Generate test data based on type + inputData = generateTestData(dataType, dataSize, random); + + // Create direct ByteBuffers (required by ZSTD) + inputBuffer = ByteBuffer.allocateDirect(dataSize); + inputBuffer.put(inputData); + inputBuffer.flip(); + + // Allocate buffers with extra space for compression overhead + int maxCompressedSize = dataSize + 1024; + compressedBuffer = ByteBuffer.allocateDirect(maxCompressedSize); + decompressedBuffer = ByteBuffer.allocateDirect(dataSize); + + // Create 
dictionary if needed + if (dictionarySize != 0) + { + dictionary = createDictionary(dataType, dictionarySize, random); + Map options = Map.of("compression_level", String.valueOf(compressionLevel)); + compressor = ZstdDictionaryCompressor.create(options).getOrCopyWithDictionary(dictionary); + } + else + { + Map options = Map.of("compression_level", String.valueOf(compressionLevel)); + compressor = ZstdDictionaryCompressor.create(options); + } + + // Always create a no-dictionary compressor for comparison + Map options = Map.of("compression_level", String.valueOf(compressionLevel)); + noDictCompressor = ZstdDictionaryCompressor.create(options); + } + + @TearDown(Level.Iteration) + public void tearDown() + { + if (dictionary != null) + { + dictionary.close(); + dictionary = null; + } + ZstdDictionaryCompressor.invalidateCache(); + } + + protected byte[] generateTestData(DataType type, int size, Random random) + { + byte[] data = new byte[size]; + + switch (type) + { + case CASSANDRA_LIKE: + generateCassandraLikeData(data, random); + break; + + case COMPRESSIBLE: + generateCompressibleData(data, random); + break; + + case MIXED: + generateMixedData(data, random); + break; + } + + return data; + } + + private void generateCassandraLikeData(byte[] data, Random random) + { + StringBuilder sb = new StringBuilder(); + String[] patterns = { + "user_id_", "timestamp_", "session_", "event_type_", + "metadata_", "value_", "status_", "location_" + }; + + while (sb.length() < data.length) + { + String pattern = patterns[random.nextInt(patterns.length)]; + sb.append(pattern).append(UUID.randomUUID().toString()).append("|"); + sb.append("timestamp:").append(System.currentTimeMillis() + random.nextInt(86400000)).append("|"); + sb.append("value:").append(random.nextDouble()).append("|"); + sb.append("count:").append(random.nextInt(1000)).append("\n"); + } + + byte[] generated = sb.substring(0, Math.min(data.length, sb.length())).getBytes(); + System.arraycopy(generated, 0, data, 0, generated.length); + + // Fill remaining space with random data if needed + if (generated.length < data.length) + { + byte[] remaining = new byte[data.length - generated.length]; + random.nextBytes(remaining); + System.arraycopy(remaining, 0, data, generated.length, remaining.length); + } + } + + private void generateCompressibleData(byte[] data, Random random) + { + String pattern = "The quick brown fox jumps over the lazy dog. This is a highly compressible pattern that repeats. 
"; + byte[] patternBytes = pattern.getBytes(); + + for (int i = 0; i < data.length; i++) + { + data[i] = patternBytes[i % patternBytes.length]; + } + + // Add some randomness (10%) + for (int i = 0; i < data.length / 10; i++) + { + data[random.nextInt(data.length)] = (byte) random.nextInt(256); + } + } + + private void generateMixedData(byte[] data, Random random) + { + int quarter = data.length / 4; + + // 25% random + random.nextBytes(data); + + // 25% compressible + byte[] compressible = new byte[quarter]; + generateCompressibleData(compressible, random); + System.arraycopy(compressible, 0, data, quarter, quarter); + + // 50% Cassandra-like + byte[] cassandraLike = new byte[data.length - 2 * quarter]; + generateCassandraLikeData(cassandraLike, random); + System.arraycopy(cassandraLike, 0, data, 2 * quarter, cassandraLike.length); + } + + private ZstdCompressionDictionary createDictionary(DataType dataType, int dictSize, Random random) + { + // Generate training samples + byte[][] samples = new byte[100][]; + int totalSampleSize = 0; + for (int i = 0; i < samples.length; i++) + { + samples[i] = generateTestData(dataType, Math.min(1024, dataSize), random); + totalSampleSize += samples[i].length; + } + + // Train dictionary + ZstdDictTrainer trainer = new ZstdDictTrainer(totalSampleSize, dictSize); + for (byte[] sample : samples) + { + trainer.addSample(sample); + } + + byte[] dictData = trainer.trainSamples(); + DictId dictId = new DictId(Kind.ZSTD, 0); + return new ZstdCompressionDictionary(dictId, dictData); + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorRatioBench.java b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorRatioBench.java new file mode 100644 index 000000000000..95e208369292 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorRatioBench.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.cassandra.config.DatabaseDescriptor; + +// This is not really a bench, but share common utilties from the base class. 
+// Running the class output compression ratio and dictionary effectiveness of different configurations +public class ZstdDictionaryCompressorRatioBench extends ZstdDictionaryCompressorBenchBase +{ + private static class CompressionResult + { + final String configuration; + final double compressionRatio; + final double dictionaryEffectiveness; + + CompressionResult(String configuration, double compressionRatio, double dictionaryEffectiveness) + { + this.configuration = configuration; + this.compressionRatio = compressionRatio; + this.dictionaryEffectiveness = dictionaryEffectiveness; + } + } + + private CompressionResult measureCompressionRatio() throws IOException + { + // Compress with current compressor + inputBuffer.rewind(); + compressedBuffer.clear(); + compressor.compress(inputBuffer, compressedBuffer); + int compressedSize = compressedBuffer.position(); + + // Compress without dictionary for comparison + inputBuffer.rewind(); + compressedBuffer.clear(); + noDictCompressor.compress(inputBuffer, compressedBuffer); + int noDictCompressedSize = compressedBuffer.position(); + + // Calculate ratios + double compressionRatio = (double) inputBuffer.limit() / compressedSize; + double dictionaryEffectiveness = (double) noDictCompressedSize / compressedSize; + + // Create configuration string + String config = String.format("%s_%s_L%d_Chunk%dKiB", + dataType, + dictionarySize == 0 ? "NoDict" : "WithDict", + compressionLevel, + dataSize / 1024); + + return new CompressionResult(config, compressionRatio, dictionaryEffectiveness); + } + + public static void main(String[] args) throws Exception + { + DatabaseDescriptor.daemonInitialization(); + + List allResults = new ArrayList<>(); + + // Define test parameters + DataType[] dataTypes = {DataType.CASSANDRA_LIKE, DataType.COMPRESSIBLE, DataType.MIXED}; + int[] dictionarySizes = {0, 65536}; + int[] compressionLevels = {3, 5, 7}; + int[] dataSizes = {4096, 16384, 65536}; + + System.out.println("Running ZSTD Dictionary Compressor Ratio Measurements..."); + System.out.println("Total configurations: " + (dataTypes.length * dictionarySizes.length * compressionLevels.length * dataSizes.length)); + + int configCount = 0; + for (DataType dataType : dataTypes) + { + for (int dictionarySize : dictionarySizes) + { + for (int compressionLevel : compressionLevels) + { + for (int dataSize : dataSizes) + { + configCount++; + ZstdDictionaryCompressorRatioBench bench = new ZstdDictionaryCompressorRatioBench(); + bench.dataType = dataType; + bench.dictionarySize = dictionarySize; + bench.compressionLevel = compressionLevel; + bench.dataSize = dataSize; + + try + { + bench.setupIteration(); + CompressionResult result = bench.measureCompressionRatio(); + allResults.add(result); + bench.tearDown(); + } + catch (Exception e) + { + System.err.println("Failed to process configuration: " + e.getMessage()); + e.printStackTrace(); + } + } + } + } + } + + // Print consolidated results + printConsolidatedResults(allResults); + } + + private static void printConsolidatedResults(List results) + { + StringBuilder report = new StringBuilder(); + report.append("\n").append("=".repeat(100)).append("\n"); + report.append("ZSTD DICTIONARY COMPRESSOR RATIO RESULTS").append("\n"); + report.append("=".repeat(100)).append("\n"); + report.append(String.format("%-50s %-20s %-20s%n", "Configuration", "Compression Ratio", "Dictionary Effectiveness")); + report.append("-".repeat(100)).append("\n"); + + for (CompressionResult entry : results) + { + report.append(String.format("%-50s %-20.3f 
%-20.3f%n", + entry.configuration, entry.compressionRatio, entry.dictionaryEffectiveness)); + } + + report.append("=".repeat(100)).append("\n"); + report.append("Compression Ratio: Original Size / Compressed Size (higher is better)").append("\n"); + report.append("Dictionary Effectiveness: Non-Dict Size / Dict Size (higher is better)").append("\n"); + report.append("=".repeat(100)).append("\n"); + + System.out.print(report.toString()); + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorThroughputBench.java b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorThroughputBench.java new file mode 100644 index 000000000000..601ad53ebd56 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorThroughputBench.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode({Mode.Throughput, Mode.AverageTime}) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 1, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 2, time = 2, timeUnit = TimeUnit.SECONDS) +@Fork(value = 1, jvmArgsAppend = "-Xmx1G") +public class ZstdDictionaryCompressorThroughputBench extends ZstdDictionaryCompressorBenchBase +{ + @Benchmark + public void compressionThroughput(Blackhole bh) throws IOException + { + inputBuffer.rewind(); + compressedBuffer.clear(); + + compressor.compress(inputBuffer, compressedBuffer); + bh.consume(compressedBuffer.position()); + } + + @Benchmark + public void decompressionThroughput(Blackhole bh) throws IOException + { + // First compress the data + inputBuffer.rewind(); + compressedBuffer.clear(); + compressor.compress(inputBuffer, compressedBuffer); + + // Then decompress it + compressedBuffer.flip(); + decompressedBuffer.clear(); + compressor.uncompress(compressedBuffer, decompressedBuffer); + + bh.consume(decompressedBuffer.position()); + } +} diff --git a/test/resources/nodetool/help/nodetool b/test/resources/nodetool/help/nodetool index 043008305617..e636f15e2d1d 100644 --- a/test/resources/nodetool/help/nodetool +++ b/test/resources/nodetool/help/nodetool @@ -155,6 +155,7 @@ The most commonly used nodetool commands are: tablestats Print statistics on tables toppartitions Sample and print the most 
active partitions tpstats Print usage statistics of thread pools + traincompressiondictionary Manually trigger compression dictionary training for a table. If no SSTables are available, the memtable will be flushed first. truncatehints Truncate all hints on the local node, or truncate hints for the endpoint(s) specified. updatecidrgroup Insert/Update a cidr group upgradesstables Rewrite sstables (for the requested tables) that are not on the current version (thus upgrading them to said current version) diff --git a/test/resources/nodetool/help/traincompressiondictionary b/test/resources/nodetool/help/traincompressiondictionary new file mode 100644 index 000000000000..5c715f81b2f3 --- /dev/null +++ b/test/resources/nodetool/help/traincompressiondictionary @@ -0,0 +1,41 @@ +NAME + nodetool traincompressiondictionary - Manually trigger compression + dictionary training for a table. If no SSTables are available, the + memtable will be flushed first. + +SYNOPSIS + nodetool [(-h | --host )] [(-p | --port )] + [(-pp | --print-port)] [(-pw | --password )] + [(-pwf | --password-file )] + [(-u | --username )] traincompressiondictionary + [--]

+ <keyspace> <table> + +OPTIONS + -h <host>, --host <host> + Node hostname or ip address + + -p <port>, --port <port> + Remote jmx agent port number + + -pp, --print-port + Operate in 4.0 mode with hosts disambiguated by port number + + -pw <password>, --password <password> + Remote jmx agent password + + -pwf <passwordFilePath>, --password-file <passwordFilePath> + Path to the JMX password file + + -u <username>, --username <username> + Remote jmx agent username + + -- + This option can be used to separate command-line options from the + list of argument, (useful when arguments might be mistaken for + command-line options + + <keyspace> + The keyspace name + + <table>
+ The table name diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryCacheTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryCacheTest.java new file mode 100644 index 000000000000..878780982d62 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryCacheTest.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.github.luben.zstd.ZstdDictTrainer; +import org.apache.cassandra.config.DatabaseDescriptor; + +import static org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import static org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionDictionaryCacheTest +{ + private static final String TEST_PATTERN = "The quick brown fox jumps over the lazy dog. 
"; + + private CompressionDictionaryCache cache; + private ZstdCompressionDictionary testDict1; + private ZstdCompressionDictionary testDict2; + private ZstdCompressionDictionary testDict3; + + @BeforeClass + public static void setUpClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setUp() + { + cache = new CompressionDictionaryCache(); + testDict1 = createTestDictionary(1); + testDict2 = createTestDictionary(2); + testDict3 = createTestDictionary(3); + } + + @After + public void tearDown() + { + if (cache != null) + { + cache.close(); + } + + // Close dictionaries if not already closed + closeQuietly(testDict1); + closeQuietly(testDict2); + closeQuietly(testDict3); + } + + // Basic cache operations tests + + @Test + public void testGetCurrentInitiallyNull() + { + assertThat(cache.getCurrent()) + .as("Current dictionary should be null initially") + .isNull(); + } + + @Test + public void testAddAndGet() + { + cache.add(testDict1); + + CompressionDictionary retrieved = cache.get(testDict1.dictId()); + assertThat(retrieved) + .as("Should retrieve the same dictionary instance") + .isSameAs(testDict1); + } + + @Test + public void testGetNonExistentDictionary() + { + DictId nonExistentId = new DictId(Kind.ZSTD, 999); + assertThat(cache.get(nonExistentId)) + .as("Should return null for non-existent dictionary") + .isNull(); + } + + @Test + public void testAddMultipleDictionaries() + { + cache.add(testDict1); + cache.add(testDict2); + cache.add(testDict3); + + assertThat(cache.get(testDict1.dictId())).isSameAs(testDict1); + assertThat(cache.get(testDict2.dictId())).isSameAs(testDict2); + assertThat(cache.get(testDict3.dictId())).isSameAs(testDict3); + } + + @Test + public void testSetCurrentWithNewerDictionary() + { + cache.add(testDict1); + cache.add(testDict2); + + assertThat(cache.getCurrent()) + .as("Should update to newer dictionary") + .isSameAs(testDict2); + + // Both should be in cache + assertThat(cache.get(testDict1.dictId())).isSameAs(testDict1); + assertThat(cache.get(testDict2.dictId())).isSameAs(testDict2); + } + + @Test + public void testSetCurrentWithOlderDictionary() + { + cache.add(testDict2); + cache.add(testDict1); // older dictionary + + assertThat(cache.getCurrent()) + .as("Should keep newer dictionary as current") + .isSameAs(testDict2); + + // Both should be in cache + assertThat(cache.get(testDict1.dictId())).isSameAs(testDict1); + assertThat(cache.get(testDict2.dictId())).isSameAs(testDict2); + } + + @Test + public void testSetCurrentWithSameIdDictionary() + { + ZstdCompressionDictionary sameDictCopy = createTestDictionary(2); + + cache.add(testDict2); + cache.add(sameDictCopy); + + // Should not update since ID is the same (not newer) + assertThat(cache.getCurrent()) + .as("Should keep original dictionary as current") + .isSameAs(testDict2); + + sameDictCopy.close(); + } + + @Test + public void testSetCurrentWithNull() + { + cache.add(testDict1); + cache.add(null); + + // Should not change current dictionary + assertThat(cache.getCurrent()) + .as("Should keep existing dictionary as current") + .isSameAs(testDict1); + } + + @Test + public void testCacheClose() + { + cache.add(testDict1); + cache.add(testDict2); + + assertThat(cache.getCurrent()) + .as("Current should not be null before close") + .isNotNull(); + assertThat(cache.get(testDict1.dictId())) + .as("Cache should contain dict1 before close") + .isNotNull(); + + cache.close(); + + assertThat(cache.getCurrent()) + .as("Current should be null after close") + .isNull(); + 
assertThat(cache.get(testDict1.dictId())) + .as("Cache should not contain dict1 after close") + .isNull(); + assertThat(cache.get(testDict2.dictId())) + .as("Cache should not contain dict2 after close") + .isNull(); + } + + @Test + public void testCloseIdempotent() + { + cache.add(testDict1); + + // Close multiple times should not cause issues + cache.close(); + cache.close(); + cache.close(); + + assertThat(cache.getCurrent()) + .as("Current should remain null") + .isNull(); + assertThat(cache.get(testDict1.dictId())) + .as("Cache should remain empty") + .isNull(); + } + + @Test + public void testConcurrentAccess() throws InterruptedException + { + int threadCount = 10; + int operationsPerThread = 100; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + AtomicReference errorRef = new AtomicReference<>(); + + // Pre-populate cache + cache.add(testDict1); + cache.add(testDict2); + + for (int i = 0; i < threadCount; i++) + { + executor.submit(() -> { + try + { + startLatch.await(); + + for (int j = 0; j < operationsPerThread; j++) + { + // Mix of read operations + CompressionDictionary current = cache.getCurrent(); + cache.get(testDict1.dictId()); + cache.get(testDict2.dictId()); + + // Verify consistency + if (current != null && current.dictId().equals(testDict2.dictId())) + { + successCount.incrementAndGet(); + } + } + } + catch (Exception e) + { + errorRef.set(e); + } + finally + { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); // Start all threads + assertThat(doneLatch.await(10, TimeUnit.SECONDS)) + .as("Threads should complete within timeout") + .isTrue(); + + executor.shutdown(); + + assertThat(errorRef.get()) + .as("No errors should occur during concurrent access") + .isNull(); + assertThat(successCount.get()) + .as("Should have successful read operations") + .isGreaterThan(0); + } + + @Test + public void testConcurrentSetCurrent() throws InterruptedException + { + int threadCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(threadCount); + + // Create multiple dictionaries with different IDs + ZstdCompressionDictionary[] dicts = new ZstdCompressionDictionary[threadCount]; + for (int i = 0; i < threadCount; i++) + { + dicts[i] = createTestDictionary(100 + i); // High IDs to ensure newer + } + + for (int i = 0; i < threadCount; i++) + { + ZstdCompressionDictionary dict = dicts[i]; + executor.submit(() -> { + try + { + startLatch.await(); + cache.add(dict); + } + catch (Exception e) + { + // Ignore - testing thread safety + } + finally + { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); + assertThat(doneLatch.await(5, TimeUnit.SECONDS)) + .as("Threads should complete within timeout") + .isTrue(); + + executor.shutdown(); + + // Verify that a current dictionary was set and it's one of our test dictionaries + CompressionDictionary current = cache.getCurrent(); + assertThat(current) + .as("A current dictionary should be set") + .isNotNull(); + assertThat(current.dictId().id) + .as("Current dictionary should be one of the test dictionaries") + .isBetween(100L, 100L + threadCount); + + // Clean up + for (ZstdCompressionDictionary dict : dicts) + { + closeQuietly(dict); + } + } + + @Test + public void testGetCurrentRefreshesCacheEntry() 
throws InterruptedException + { + int expireSeconds = 1; + try (CompressionDictionaryCache shortLivedCache = new CompressionDictionaryCache(10, expireSeconds)) + { + shortLivedCache.add(testDict1); + assertThat(shortLivedCache.getCurrent()) + .as("Current dictionary should be set") + .isSameAs(testDict1); + assertThat(shortLivedCache.get(testDict1.dictId())) + .as("Dictionary should be in cache") + .isNotNull(); + + // Access getCurrent() repeatedly for slightly longer than the expiration time + int iterations = expireSeconds * 2 + 1; + for (int i = 0; i < iterations; i++) + { + Thread.sleep(900); // Sleep a bit less than expireSeconds between accesses + + CompressionDictionary current = shortLivedCache.getCurrent(); + assertThat(current) + .as("Current dictionary should remain accessible after %d seconds", i + 1) + .isSameAs(testDict1); + + assertThat(shortLivedCache.get(testDict1.dictId())) + .as("Dictionary should still be in cache after %d seconds", i + 1) + .isNotNull(); + } + + assertThat(shortLivedCache.getCurrent()) + .as("Current dictionary should remain accessible due to getCurrent() refreshing the cache") + .isSameAs(testDict1); + } + } + + @Test + public void testGetCurrentReturnsNullAfterExpiration() throws InterruptedException + { + int expireSeconds = 1; + try (CompressionDictionaryCache shortLivedCache = new CompressionDictionaryCache(10, expireSeconds)) + { + shortLivedCache.add(testDict1); + assertThat(shortLivedCache.getCurrent()) + .as("Current dictionary should be set") + .isSameAs(testDict1); + + // Wait for expiration without accessing getCurrent() + Thread.sleep((expireSeconds + 1) * 1000); + + // Current should now return null since the entry expired and we only store the DictId + assertThat(shortLivedCache.getCurrent()) + .as("Current dictionary should be null after expiration when not accessed") + .isNull(); + } + } + + private static ZstdCompressionDictionary createTestDictionary(long id) + { + try + { + // Create simple dictionary + ZstdDictTrainer trainer = new ZstdDictTrainer(10 * 1024, 1024, 3); + + // Add samples + byte[] sample = TEST_PATTERN.getBytes(); + for (int i = 0; i < 100; i++) + { + trainer.addSample(sample); + } + + byte[] dictBytes = trainer.trainSamples(); + DictId dictId = new DictId(Kind.ZSTD, id); + + return new ZstdCompressionDictionary(dictId, dictBytes); + } + catch (Exception e) + { + throw new RuntimeException("Failed to create test dictionary", e); + } + } + + private static void closeQuietly(AutoCloseable resource) + { + if (resource != null) + { + try + { + resource.close(); + } + catch (Exception e) + { + // Ignore + } + } + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryEventHandlerTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryEventHandlerTest.java new file mode 100644 index 000000000000..143f18cab6bb --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryEventHandlerTest.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.ServerTestUtils; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.KeyspaceParams; +import org.apache.cassandra.schema.Schema; +import org.apache.cassandra.schema.TableId; +import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.tcm.membership.NodeAddresses; +import org.apache.cassandra.tcm.membership.NodeId; +import org.apache.cassandra.tcm.transformations.Register; +import org.apache.cassandra.tcm.transformations.UnsafeJoin; +import org.apache.cassandra.utils.ByteBufferUtil; +import org.apache.cassandra.utils.FBUtilities; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; + +public class CompressionDictionaryEventHandlerTest +{ + private static final String TEST_NAME = "compression_dict_event_handler_test_"; + private static final String KEYSPACE = TEST_NAME + "keyspace"; + private static final String TABLE = "test_table"; + private static final DictId TEST_DICTIONARY_ID = new DictId(Kind.ZSTD, 12345L); + + private static TableMetadata tableMetadata; + private static ColumnFamilyStore cfs; + + private CompressionDictionaryEventHandler eventHandler; + private ZstdCompressionDictionary testDictionary; + + @BeforeClass + public static void setUpClass() throws Exception + { + ServerTestUtils.prepareServerNoRegister(); + + // Create a table with dictionary compression enabled + CompressionParams compressionParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + TableMetadata.Builder tableBuilder = TableMetadata.builder(KEYSPACE, TABLE) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParams); + + SchemaLoader.createKeyspace(KEYSPACE, + KeyspaceParams.simple(1), + tableBuilder); + + tableMetadata = Schema.instance.getTableMetadata(KEYSPACE, TABLE); + cfs = Keyspace.open(KEYSPACE).getColumnFamilyStore(TABLE); + + // Register some nodes for cluster testing + InetAddressAndPort ep1 = 
InetAddressAndPort.getByName("127.0.0.2:9042"); + InetAddressAndPort ep2 = InetAddressAndPort.getByName("127.0.0.3:9042"); + InetAddressAndPort ep3 = FBUtilities.getBroadcastAddressAndPort(); + + NodeId node1 = Register.register(new NodeAddresses(UUID.randomUUID(), ep1, ep1, ep1)); + NodeId node2 = Register.register(new NodeAddresses(UUID.randomUUID(), ep2, ep2, ep2)); + NodeId node3 = Register.register(new NodeAddresses(UUID.randomUUID(), ep3, ep3, ep3)); + + // Simple token distribution for testing + UnsafeJoin.unsafeJoin(node1, Collections.singleton(key(tableMetadata, 1).getToken())); + UnsafeJoin.unsafeJoin(node2, Collections.singleton(key(tableMetadata, 2).getToken())); + UnsafeJoin.unsafeJoin(node3, Collections.singleton(key(tableMetadata, 3).getToken())); + } + + @Before + public void setUp() + { + MessagingService.instance().inboundSink.clear(); + MessagingService.instance().outboundSink.clear(); + testDictionary = createTestDictionary(); + eventHandler = new CompressionDictionaryEventHandler(cfs, new CompressionDictionaryCache()); + } + + @After + public void tearDown() + { + if (testDictionary != null) + { + testDictionary.close(); + } + MessagingService.instance().inboundSink.clear(); + MessagingService.instance().outboundSink.clear(); + } + + @Test + public void testOnNewDictionaryTrained() throws InterruptedException + { + // Expect messages to 2 other nodes (excluding self) + CountDownLatch messageSentLatch = new CountDownLatch(2); + Set<InetAddressAndPort> receivers = ConcurrentHashMap.newKeySet(2); + AtomicReference<CompressionDictionaryUpdateMessage> capturedMessage = new AtomicReference<>(); + + // Capture outbound messages + MessagingService.instance().outboundSink.add((message, to) -> { + if (message.verb() == Verb.DICTIONARY_UPDATE_REQ) + { + capturedMessage.set((CompressionDictionaryUpdateMessage) message.payload); + receivers.add(to); + messageSentLatch.countDown(); + } + return false; // Don't actually send + }); + + eventHandler.onNewDictionaryTrained(TEST_DICTIONARY_ID); + + // Wait for message to be processed + assertThat(messageSentLatch.await(5, TimeUnit.SECONDS)) + .as("Dictionary update notification should be sent") + .isTrue(); + + assertThat(receivers) + .as("Should not send notification to self") + .hasSize(2) + .doesNotContain(FBUtilities.getBroadcastAddressAndPort()); + + CompressionDictionaryUpdateMessage message = capturedMessage.get(); + assertThat(message) + .as("Message should be captured") + .isNotNull(); + assertThat(message.tableId) + .as("Message should contain correct table ID") + .isEqualTo(tableMetadata.id); + assertThat(message.dictionaryId) + .as("Message should contain correct dictionary ID") + .isEqualTo(TEST_DICTIONARY_ID); + } + + @Test + public void testMessageSerialization() + { + TableId testTableId = tableMetadata.id; + CompressionDictionaryUpdateMessage message = new CompressionDictionaryUpdateMessage(testTableId, TEST_DICTIONARY_ID); + + assertThat(message.tableId) + .as("Message should contain correct table ID") + .isEqualTo(testTableId); + assertThat(message.dictionaryId) + .as("Message should contain correct dictionary ID") + .isEqualTo(TEST_DICTIONARY_ID); + assertThat(CompressionDictionaryUpdateMessage.serializer) + .as("Message should have serializer") + .isNotNull(); + } + + @Test + public void testMessageSerializationRoundTrip() throws Exception + { + TableId testTableId = tableMetadata.id; + CompressionDictionaryUpdateMessage originalMessage = new CompressionDictionaryUpdateMessage(testTableId, TEST_DICTIONARY_ID); + + // Serialize + 
org.apache.cassandra.io.util.DataOutputBuffer out = new org.apache.cassandra.io.util.DataOutputBuffer(); + CompressionDictionaryUpdateMessage.serializer.serialize(originalMessage, out, MessagingService.current_version); + + // Deserialize + org.apache.cassandra.io.util.DataInputBuffer in = new org.apache.cassandra.io.util.DataInputBuffer(out.getData()); + CompressionDictionaryUpdateMessage deserializedMessage = + CompressionDictionaryUpdateMessage.serializer.deserialize(in, MessagingService.current_version); + + assertThat(deserializedMessage.tableId) + .as("Deserialized table ID should match") + .isEqualTo(originalMessage.tableId); + assertThat(deserializedMessage.dictionaryId) + .as("Deserialized dictionary ID should match") + .isEqualTo(originalMessage.dictionaryId); + } + + @Test + public void testSendNotificationRobustness() + { + // Test that sending notifications doesn't throw even if messaging fails + MessagingService.instance().outboundSink.add((message, to) -> { + if (message.verb() == Verb.DICTIONARY_UPDATE_REQ) + { + throw new RuntimeException("Simulated messaging failure"); + } + return false; + }); + + assertThatNoException().isThrownBy(() -> eventHandler.onNewDictionaryTrained(TEST_DICTIONARY_ID)); + } + + private static ZstdCompressionDictionary createTestDictionary() + { + byte[] dictBytes = "test dictionary data for event handler testing".getBytes(); + return new ZstdCompressionDictionary(TEST_DICTIONARY_ID, dictBytes); + } + + private static DecoratedKey key(TableMetadata metadata, int key) + { + return metadata.partitioner.decorateKey(ByteBufferUtil.bytes(key)); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryIntegrationTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryIntegrationTest.java new file mode 100644 index 000000000000..4828e46f11d1 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryIntegrationTest.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.config.Config; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.config.DataStorageSpec; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.utils.Clock; + +import static org.apache.cassandra.Util.spinUntilTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class CompressionDictionaryIntegrationTest extends CQLTester +{ + private static final String REPEATED_DATA = "The quick brown fox jumps over the lazy dog. This text repeats for better compression. "; + + @Before + public void configureDatabaseDescriptor() + { + Config config = DatabaseDescriptor.getRawConfig(); + config.compression_dictionary_training_sampling_rate = 1.0f; + config.compression_dictionary_training_max_total_sample_size = new DataStorageSpec.IntKibibytesBound("128KiB"); + config.compression_dictionary_training_max_dictionary_size = new DataStorageSpec.IntKibibytesBound("10KiB"); + config.flush_compression = Config.FlushCompression.table; + DatabaseDescriptor.setConfig(config); + } + + @Test + public void testEnableDisableDictionaryCompression() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + // Insert data and flush to create SSTables + for (int i = 0; i < 100; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, REPEATED_DATA + " " + i); + } + flush(); + + assertThatNoException() + .as("Should allow manual training") + .isThrownBy(() -> manager.train()); + + // Disable dictionary compression + CompressionParams nonDictParams = CompressionParams.lz4(); + manager.maybeReloadFromSchema(nonDictParams); + + assertThatThrownBy(() -> manager.train()) + .as("Should disallow manual training when using lz4") + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("does not support dictionary compression"); + + // Re-enable dictionary compression + CompressionParams dictParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Collections.singletonMap("compression_level", "3")); + manager.maybeReloadFromSchema(dictParams); + + // Insert more data for the re-enabled compression + for (int i = 100; i < 200; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, REPEATED_DATA + " " + i); + } + flush(); + + assertThatNoException() + .as("Should allow manual training after switching back to dictionary compression") + .isThrownBy(() -> manager.train()); + } + + @Test + public void testCompressionParameterChanges() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + 
ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + ICompressionDictionaryTrainer trainer = manager.trainer(); + assertThat(trainer).isNotNull(); + assertThat(trainer.kind()).isEqualTo(Kind.ZSTD); + + // Change compression level - should create new trainer + CompressionParams newParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Collections.singletonMap("compression_level", "5")); + manager.maybeReloadFromSchema(newParams); + ICompressionDictionaryTrainer newTrainer = manager.trainer(); + assertThat(newTrainer.kind()).isEqualTo(Kind.ZSTD); + assertThat(newTrainer) + .as("Should create a different trainer instance when compression level is changed") + .isNotSameAs(trainer); + } + + @Test + public void testResourceCleanupOnClose() throws Exception + { + createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ColumnFamilyStore cfs = getCurrentColumnFamilyStore(); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + // Add test dictionary + ZstdCompressionDictionary testDict = createTestDictionary(); + manager.add(testDict); + + assertThat(testDict.selfRef().globalCount()) + .as("Dictionary's reference count should be 1 after adding to cache") + .isOne(); + + assertThat(manager.getCurrent()) + .as("Should have current dictionary before close") + .isNotNull(); + + manager.close(); + + assertThat(manager.trainer()).isNull(); + assertThat(testDict.selfRef().globalCount()) + .as("Dictionary's reference count should be 0 after closing manager") + .isZero(); + assertThat(testDict.rawDictionary()) + .as("The raw dictionary bytes should still be accessible") + .isNotNull(); + } + + private static ZstdCompressionDictionary createTestDictionary() + { + byte[] dictBytes = (REPEATED_DATA + " dictionary training data").getBytes(); + DictId dictId = new DictId(Kind.ZSTD, Clock.Global.currentTimeMillis()); + return new ZstdCompressionDictionary(dictId, dictBytes); + } + + @Test + public void testSSTableBasedTraining() + { + DatabaseDescriptor.setFlushCompression(Config.FlushCompression.table); + String table = createTable("CREATE TABLE %s (pk text PRIMARY KEY, data text) " + + "WITH compression = {'class': 'ZstdDictionaryCompressor', 'chunk_length_in_kb' : 4}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + // Insert compressible data and flush to create SSTables + for (int i = 0; i < 1000; i++) + { + execute("INSERT INTO %s (pk, data) VALUES (?, ?)", + "key" + i, + REPEATED_DATA + " row " + i); + if (i % 200 == 0) + flush(); + } + flush(); + + // Verify we have SSTables + assertThat(cfs.getLiveSSTables()) + .as("Should have created SSTables") + .hasSizeGreaterThan(0); + + // Train from existing SSTables + manager.train(); + + // Training should complete quickly since we're reading from existing SSTables + spinUntilTrue(() -> TrainingState.fromCompositeData(manager.getTrainingState()).status == TrainingStatus.COMPLETED, 10); + + // Verify dictionary was trained and is available + spinUntilTrue(() -> manager.getCurrent() != null, 2); + + CompressionDictionary currentDict = manager.getCurrent(); + assertThat(currentDict).isNotNull(); + assertThat(currentDict.kind()) + .as("Dictionary should be ZSTD type") + .isEqualTo(Kind.ZSTD); + + 
assertThat(currentDict.rawDictionary().length) + .as("Dictionary should have content") + .isGreaterThan(0); + + // Verify we can still read the data + assertRows(execute("SELECT pk, data FROM %s WHERE pk = ?", "key0"), + row("key0", REPEATED_DATA + " row 0")); + } + + @Test + public void testSSTableBasedTrainingWithoutSSTables() + { + String table = createTable("CREATE TABLE %s (pk text PRIMARY KEY, data text) " + + "WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + // Try to train without any SSTables + assertThatThrownBy(() -> manager.train()) + .as("Should fail when no SSTables are available") + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("No SSTables available for training"); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBeanTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBeanTest.java new file mode 100644 index 000000000000..3e007b3f4fe6 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBeanTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.Map; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.ServerTestUtils; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.KeyspaceParams; +import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.utils.MBeanWrapper; +import org.apache.cassandra.utils.MBeanWrapper.OnException; + +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionDictionaryManagerMBeanTest +{ + private static final String KEYSPACE_WITH_DICT = "keyspace_mbean_test"; + private static final String TABLE = "test_table"; + + private static ColumnFamilyStore cfsWithDict; + + @BeforeClass + public static void setUpClass() throws Exception + { + ServerTestUtils.prepareServer(); + CompressionParams compressionParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + TableMetadata.Builder tableBuilder = TableMetadata.builder(KEYSPACE_WITH_DICT, TABLE) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParams); + SchemaLoader.createKeyspace(KEYSPACE_WITH_DICT, + KeyspaceParams.simple(1), + tableBuilder); + cfsWithDict = Keyspace.open(KEYSPACE_WITH_DICT).getColumnFamilyStore(TABLE); + } + + // Ensure no MBean is registered at the beginning of the test + @Before + public void cleanup() + { + String mbeanName = CompressionDictionaryManager.mbeanName(KEYSPACE_WITH_DICT, TABLE); + MBeanWrapper.instance.unregisterMBean(mbeanName, OnException.IGNORE); + } + + @Test + public void testMBeanRegisteredWhenBookkeepingEnabled() + { + String mbeanName = CompressionDictionaryManager.mbeanName(KEYSPACE_WITH_DICT, TABLE); + // Create manager with bookkeeping enabled + try (CompressionDictionaryManager manager = new CompressionDictionaryManager(cfsWithDict, true)) + { + // Verify MBean is registered + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should be registered when bookkeeping is enabled") + .isTrue(); + } + // Closing the manager should unregister the MBean; verify it is unregistered + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should be unregistered after the manager is closed") + .isFalse(); + } + + @Test + public void testMBeanNotRegisteredWhenBookkeepingDisabled() + { + // Create manager with bookkeeping disabled + try (CompressionDictionaryManager manager = new CompressionDictionaryManager(cfsWithDict, false)) + { + // Verify MBean is NOT registered + String mbeanName = CompressionDictionaryManager.mbeanName(KEYSPACE_WITH_DICT, TABLE); + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should not be registered when bookkeeping is disabled") + .isFalse(); + } + // Closing the manager should not throw even though no MBean is registered + } + + @Test + public void testMBeanUnregisteredOnCFSInvalidation() + { + String testKeyspace = "test_invalidation_mbean_ks"; + String testTable = "test_invalidation_mbean_table"; + + CompressionParams compressionParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + TableMetadata.Builder tableBuilder = 
TableMetadata.builder(testKeyspace, testTable) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParams); + + SchemaLoader.createKeyspace(testKeyspace, + KeyspaceParams.simple(1), + tableBuilder); + + ColumnFamilyStore cfs = Keyspace.open(testKeyspace).getColumnFamilyStore(testTable); + + String mbeanName = CompressionDictionaryManager.mbeanName(testKeyspace, testTable); + + // Verify MBean is registered (CFS registers it during creation) + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should be registered after CFS creation") + .isTrue(); + + // Invalidate the CFS (which should unregister the MBean) + cfs.invalidate(true, true); + + // Verify MBean is unregistered + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should be unregistered after CFS invalidation") + .isFalse(); + } + + @Test + public void testMBeanStatisticsMethods() + { + // Create manager with bookkeeping enabled + try (CompressionDictionaryManager manager = new CompressionDictionaryManager(cfsWithDict, true)) + { + TrainingState state = TrainingState.fromCompositeData(manager.getTrainingState()); + // Test statistics methods directly on the manager (which implements the MBean interface) + assertThat(state.getSampleCount()) + .as("Sample count should be non-negative") + .isGreaterThanOrEqualTo(0); + + assertThat(state.getTotalSampleSize()) + .as("Total sample size should be non-negative") + .isGreaterThanOrEqualTo(0); + } + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerTest.java new file mode 100644 index 000000000000..18245cf51b21 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerTest.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.ServerTestUtils; +import org.apache.cassandra.config.CassandraRelevantProperties; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.KeyspaceParams; +import org.apache.cassandra.schema.TableMetadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class CompressionDictionaryManagerTest +{ + private static final String KEYSPACE_WITH_DICT = "keyspace_with_dict"; + private static final String KEYSPACE_WITHOUT_DICT = "keyspace_without_dict"; + private static final String TABLE = "test_table"; + + private static ColumnFamilyStore cfsWithDict; + private static ColumnFamilyStore cfsWithoutDict; + + private CompressionDictionaryManager managerWithDict; + private CompressionDictionaryManager managerWithoutDict; + + @BeforeClass + public static void setUpClass() throws Exception + { + CassandraRelevantProperties.ORG_APACHE_CASSANDRA_DISABLE_MBEAN_REGISTRATION.setBoolean(true); + ServerTestUtils.prepareServerNoRegister(); + + // Create table with dictionary compression enabled + CompressionParams compressionParamsWithDict = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + TableMetadata.Builder tableBuilderWithDict = TableMetadata.builder(KEYSPACE_WITH_DICT, TABLE) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParamsWithDict); + + // Create table without dictionary compression + CompressionParams compressionParamsWithoutDict = CompressionParams.lz4(); + + TableMetadata.Builder tableBuilderWithoutDict = TableMetadata.builder(KEYSPACE_WITHOUT_DICT, TABLE) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParamsWithoutDict); + + SchemaLoader.createKeyspace(KEYSPACE_WITH_DICT, + KeyspaceParams.simple(1), + tableBuilderWithDict); + + SchemaLoader.createKeyspace(KEYSPACE_WITHOUT_DICT, + KeyspaceParams.simple(1), + tableBuilderWithoutDict); + + cfsWithDict = Keyspace.open(KEYSPACE_WITH_DICT).getColumnFamilyStore(TABLE); + cfsWithoutDict = Keyspace.open(KEYSPACE_WITHOUT_DICT).getColumnFamilyStore(TABLE); + } + + @Before + public void setUp() + { + managerWithDict = new CompressionDictionaryManager(cfsWithDict, true); + managerWithoutDict = new CompressionDictionaryManager(cfsWithoutDict, true); + } + + @After + public void tearDown() throws Exception + { + if (managerWithDict != null) + { + managerWithDict.close(); + } + if (managerWithoutDict != null) + { + managerWithoutDict.close(); + } + } + + @Test + public void testManagerInitializationWithDictionaryCompression() + { + assertThat(managerWithDict) + .as("Manager should be created successfully for dictionary-enabled table") + .isNotNull(); + + // Manager should 
start in a valid state + TrainingState trainingState = TrainingState.fromCompositeData(managerWithDict.getTrainingState()); + assertThat(trainingState.getStatus()) + .as("Training status should be valid") + .isEqualTo(TrainingStatus.NOT_STARTED); + } + + @Test + public void testManagerInitializationWithoutDictionaryCompression() + { + assertThat(managerWithoutDict) + .as("Manager should be created successfully for non-dictionary table") + .isNotNull(); + + // Should report NOT_STARTED since no trainer is created + TrainingState trainingState = TrainingState.fromCompositeData(managerWithoutDict.getTrainingState()); + assertThat(trainingState.getStatus()) + .as("Should report NOT_STARTED for non-dictionary tables") + .isEqualTo(TrainingStatus.NOT_STARTED); + } + + @Test + public void testMaybeReloadFromSchemaEnableDictionaryCompression() + { + // Start with manager for non-dictionary table + TrainingState initialTrainingState = TrainingState.fromCompositeData(managerWithoutDict.getTrainingState()); + assertThat(initialTrainingState.getStatus()) + .as("Initially should not be training") + .isEqualTo(TrainingStatus.NOT_STARTED); + + // Enable dictionary compression by switching to dict params + CompressionParams dictParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + managerWithoutDict.maybeReloadFromSchema(dictParams); + + // Should now have a trainer + assertThat(managerWithoutDict.trainer()) + .as("Should have a trainer after enabling dictionary compression") + .isNotNull(); + } + + @Test + public void testMaybeReloadFromSchemaDisableDictionaryCompression() + { + // Verify we have a trainer initially + assertThat(managerWithDict.trainer()).isNotNull(); + + // Disable dictionary compression + CompressionParams nonDictParams = CompressionParams.lz4(); + managerWithDict.maybeReloadFromSchema(nonDictParams); + + // Should disable training + assertThat(managerWithDict.trainer()) + .as("Should not have trainer when dictionary compression is disabled") + .isNull(); + } + + @Test + public void testTrainerCompatibilityCheck() + { + ICompressionDictionaryTrainer initialTrainer = managerWithDict.trainer(); + assertThat(initialTrainer).isNotNull(); + + // Change compression level - should create new trainer + CompressionParams differentLevelParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "5")); + managerWithDict.maybeReloadFromSchema(differentLevelParams); + ICompressionDictionaryTrainer newTrainer = managerWithDict.trainer(); + + // Should have a different trainer instance + assertThat(newTrainer) + .as("Should create new trainer when compression level changes") + .isNotSameAs(initialTrainer); + } + + @Test + public void testAddSample() + { + ByteBuffer sample = ByteBuffer.wrap("test sample data".getBytes()); + ByteBuffer emptyBuffer = ByteBuffer.allocate(0); + + // Should not throw for dictionary-enabled table + assertThatNoException().isThrownBy(() -> managerWithDict.addSample(sample)); + assertThatNoException().isThrownBy(() -> managerWithDict.addSample(null)); + assertThatNoException().isThrownBy(() -> managerWithDict.addSample(emptyBuffer)); + + // Should not throw for non-dictionary table (graceful handling) + assertThatNoException().isThrownBy(() -> managerWithoutDict.addSample(sample)); + assertThatNoException().isThrownBy(() -> managerWithoutDict.addSample(null)); + assertThatNoException().isThrownBy(() -> managerWithoutDict.addSample(emptyBuffer)); + } + + @Test + 
public void testTrainManualWithNonDictionaryTable() + { + assertThatThrownBy(() -> managerWithoutDict.train()) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("does not support dictionary compression"); + } + + @Test + public void testTrainManualWithDictionaryTable() + { + // Should throw because no SSTables exist + assertThatThrownBy(() -> managerWithDict.train()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("No SSTables available for training"); + } + + @Test + public void testSchemaChangeWorkflow() + { + // Start with non-dictionary table + TrainingState initialTrainingState = TrainingState.fromCompositeData(managerWithoutDict.getTrainingState()); + assertThat(initialTrainingState.getStatus()).isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(managerWithoutDict.trainer()).isNull(); + + // Enable dictionary compression + CompressionParams dictParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + managerWithoutDict.maybeReloadFromSchema(dictParams); + + // Should now support training + assertThat(managerWithoutDict.trainer()).isNotNull(); + + // Change compression level + CompressionParams newDictParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "5")); + managerWithoutDict.maybeReloadFromSchema(newDictParams); + + // Should still support training with new parameters + assertThat(managerWithoutDict.trainer()).isNotNull(); + + // Disable dictionary compression + CompressionParams nonDictParams = CompressionParams.lz4(); + managerWithoutDict.maybeReloadFromSchema(nonDictParams); + + // Should disable training + assertThat(managerWithoutDict.trainer()).isNull(); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionarySchedulerTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionarySchedulerTest.java new file mode 100644 index 000000000000..4ada039ece7e --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionarySchedulerTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.HashSet; +import java.util.Set; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.io.sstable.format.SSTableReader; + +import static org.apache.cassandra.Util.spinUntilTrue; +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionDictionarySchedulerTest extends CQLTester +{ + private static final String KEYSPACE = "scheduler_test_ks"; + private static final String TABLE = "scheduler_test_table"; + + private CompressionDictionaryScheduler scheduler; + private ICompressionDictionaryCache cache; + + @Before + public void setUp() + { + cache = new CompressionDictionaryCache(); + scheduler = new CompressionDictionaryScheduler(KEYSPACE, TABLE, cache, true); + } + + @After + public void tearDown() + { + if (scheduler != null) + { + scheduler.close(); + } + } + + @Test + public void testScheduleSSTableBasedTrainingWithNoSSTables() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) " + + "WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + Set sstables = new HashSet<>(); + CompressionDictionaryTrainingConfig config = createSampleAllTrainingConfig(cfs); + + // Should not throw, but task will complete quickly with no SSTables + scheduler.scheduleSSTableBasedTraining(manager.trainer(), sstables, config); + spinUntilTrue(() -> scheduler.scheduledManualTrainingTask() == null); + assertThat(manager.getCurrent()).isNull(); + } + + @Test + public void testScheduleSSTableBasedTrainingWithSSTables() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) " + + "WITH compression = {'class': 'ZstdDictionaryCompressor', 'chunk_length_in_kb': '4'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + createSSTables(); + + Set sstables = cfs.getLiveSSTables(); + assertThat(sstables).isNotEmpty(); + + CompressionDictionaryTrainingConfig config = createSampleAllTrainingConfig(cfs); + manager.trainer().start(true); + + assertThat(manager.getCurrent()).as("There should be no dictionary at this step").isNull(); + scheduler.scheduleSSTableBasedTraining(manager.trainer(), sstables, config); + + // Task should be scheduled + assertThat((Object) scheduler.scheduledManualTrainingTask()).isNotNull(); + // A dictionary should be trained + spinUntilTrue(() -> manager.getCurrent() != null); + } + + private void createSSTables() + { + for (int file = 0; file < 10; file++) + { + int batchSize = 1000; + for (int i = 0; i < batchSize; i++) + { + int index = i + file * batchSize; + execute("INSERT INTO %s (id, data) VALUES (?, ?)", index, "test data " + index); + } + flush(); + } + } + + private static CompressionDictionaryTrainingConfig createSampleAllTrainingConfig(ColumnFamilyStore cfs) { + return CompressionDictionaryTrainingConfig + .builder() + .maxDictionarySize(DatabaseDescriptor.getCompressionDictionaryTrainingMaxDictionarySize()) + .maxTotalSampleSize(DatabaseDescriptor.getCompressionDictionaryTrainingMaxTotalSampleSize()) + 
.samplingRate(1.0f) + .chunkSize(cfs.metadata().params.compression.chunkLength()) + .build(); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfigTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfigTest.java new file mode 100644 index 000000000000..8f3d746d035b --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfigTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionDictionaryTrainingConfigTest +{ + @Test + public void testBuilderDefaults() + { + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder().build(); + + assertThat(config.maxDictionarySize) + .as("Default max dictionary size should be 64KB") + .isEqualTo(65536); + assertThat(config.maxTotalSampleSize) + .as("Default max total sample size should be 10MB") + .isEqualTo(10 * 1024 * 1024); + assertThat(config.samplingRate) + .as("Default sampling rate should be 100 (1%)") + .isEqualTo(100); + } + + @Test + public void testCalculatedThresholds() + { + int dictSize = 16 * 1024; // 16KB + int sampleSize = 2 * 1024 * 1024; // 2MB + float samplingRate = 0.005f; // 0.5% + + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder() + .maxDictionarySize(dictSize) + .maxTotalSampleSize(sampleSize) + .samplingRate(samplingRate) + .build(); + + // Verify all calculated values are consistent + assertThat(config.maxDictionarySize).isEqualTo(dictSize); + assertThat(config.maxTotalSampleSize).isEqualTo(sampleSize); + assertThat(config.acceptableTotalSampleSize).isEqualTo(sampleSize / 10 * 8); + assertThat(config.samplingRate).isEqualTo(Math.round(1 / samplingRate)); + + // Verify relationship between max and acceptable sample sizes + assertThat(config.acceptableTotalSampleSize) + .as("Acceptable sample size should be less than or equal to max") + .isLessThanOrEqualTo(config.maxTotalSampleSize); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/SSTableChunkSamplerTest.java b/test/unit/org/apache/cassandra/db/compression/SSTableChunkSamplerTest.java new file mode 100644 index 000000000000..22a4b081d588 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/SSTableChunkSamplerTest.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.List; +import java.util.Set; + +import org.junit.Test; + +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.compression.SSTableChunkSampler.SSTableChunkInfo; +import org.apache.cassandra.io.sstable.format.SSTableReader; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SSTableChunkSamplerTest extends CQLTester +{ + @Test + public void testSSTableChunkInfoForCompressedSSTable() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'LZ4Compressor', 'chunk_length_in_kb': '64'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + + // Insert data and flush to create an SSTable + for (int i = 0; i < 100; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, "test data " + i); + } + flush(); + + Set sstables = cfs.getLiveSSTables(); + assertThat(sstables).isNotEmpty(); + + SSTableReader sstable = sstables.iterator().next(); + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder() + .chunkSize(64 * 1024) + .build(); + + SSTableChunkInfo info = new SSTableChunkInfo(sstable, config); + + assertThat(info.isCompressed).isTrue(); + assertThat(info.chunkCount).isGreaterThan(0); + assertThat(info.dataLength).isGreaterThan(0); + assertThat(info.chunkSize).isEqualTo(64 * 1024); + assertThat(info.metadata).isNotNull(); + } + + @Test + public void testSSTableChunkInfoForUncompressedSSTable() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'enabled': 'false'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + + // Insert data and flush to create an uncompressed SSTable + for (int i = 0; i < 100; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, "test data " + i); + } + flush(); + + Set sstables = cfs.getLiveSSTables(); + assertThat(sstables).isNotEmpty(); + + SSTableReader sstable = sstables.iterator().next(); + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder() + .chunkSize(64 * 1024) + .build(); + + SSTableChunkInfo info = new SSTableChunkInfo(sstable, config); + + assertThat(info.isCompressed).isFalse(); + assertThat(info.chunkCount).isGreaterThan(0); + assertThat(info.dataLength).isGreaterThan(0); + assertThat(info.chunkSize).isEqualTo(64 * 1024); + assertThat(info.metadata).isNull(); + } + + @Test + public void testCalculateTargetChunkCount() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH 
compression = {'enabled': 'false'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + + // Create multiple SSTables + for (int batch = 0; batch < 3; batch++) + { + for (int i = 0; i < 100; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", batch * 100 + i, "test data " + i); + } + flush(); + } + + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder() + .maxTotalSampleSize(10 * 1024 * 1024) // 10MB + .chunkSize(64 * 1024) + .build(); + + Set sstables = cfs.getLiveSSTables(); + assertThat(sstables).hasSizeGreaterThanOrEqualTo(3); + + List sstableInfos = SSTableChunkSampler.buildSSTableInfos(sstables, config); + long totalChunks = sstableInfos.stream().mapToLong(info -> info.chunkCount).sum(); + long targetChunkCount = SSTableChunkSampler.calculateTargetChunkCount(sstableInfos, totalChunks, config); + + // Target should be based on maxTotalSampleSize divided by average chunk size + assertThat(targetChunkCount).isGreaterThan(0); + long totalDataSize = sstableInfos.stream().mapToLong(info -> info.dataLength).sum(); + int averageChunkSize = (int) (totalDataSize / totalChunks); + long expectedTarget = config.maxTotalSampleSize / averageChunkSize; + assertThat(targetChunkCount).isEqualTo(expectedTarget); + } + + @Test + public void testSelectRandomChunkIndices() + { + // test scenarios: select small portion, large portion and all + for (int expectedChunkCount : List.of(10, 80, 100)) + { + Set selected = SSTableChunkSampler.selectRandomChunkIndices(100, expectedChunkCount); + + assertThat(selected).hasSize(expectedChunkCount); + assertThat(selected).allMatch(idx -> idx >= 0 && idx < 100); + } + } + + @Test + public void testSelectRandomChunkIndicesDistribution() + { + // Test that selection is reasonably distributed + int totalChunks = 100; + int runs = 1000; + int[] hitCount = new int[totalChunks]; + + // Run many selections and count how often each chunk is selected + for (int i = 0; i < runs; i++) + { + Set selected = SSTableChunkSampler.selectRandomChunkIndices(totalChunks, 10); + for (long idx : selected) + { + hitCount[(int) idx]++; + } + } + + // Each chunk should be selected approximately 10% of the time (10 out of 100) + // So in 1000 runs, expect ~100 hits per chunk + // Allow for variance - between 50 and 150 hits + for (int count : hitCount) + { + assertThat(count).isBetween(50, 150); + } + } + + @Test + public void testSampleFromSSTablesWithTrainerNotReady() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'LZ4Compressor'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + + // Insert data and flush to create an SSTable + for (int i = 0; i < 100; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, "test data " + i); + } + flush(); + + Set sstables = cfs.getLiveSSTables(); + assertThat(sstables).isNotEmpty(); + + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder() + .chunkSize(64 * 1024) + .build(); + + // Create a mock trainer that is not ready to sample + ICompressionDictionaryTrainer trainer = mock(ICompressionDictionaryTrainer.class, RETURNS_DEEP_STUBS); + when(trainer.shouldSample()).thenReturn(false); + when(trainer.getTrainingState().getStatus()).thenReturn(ICompressionDictionaryTrainer.TrainingStatus.NOT_STARTED); + + // Should throw IllegalStateException when trainer is not ready + assertThatThrownBy(() -> 
SSTableChunkSampler.sampleFromSSTables(sstables, trainer, config)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Trainer is not ready to accept samples"); + } + + @Test + public void testReadChunkThrowsOnInvalidPosition() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'enabled': 'false'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + + // Insert data and flush to create an uncompressed SSTable + for (int i = 0; i < 100; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, "test data " + i); + } + flush(); + + Set sstables = cfs.getLiveSSTables(); + assertThat(sstables).isNotEmpty(); + + SSTableReader sstable = sstables.iterator().next(); + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder() + .chunkSize(64 * 1024) + .build(); + + SSTableChunkInfo info = new SSTableChunkInfo(sstable, config); + + // Try to read at a position beyond the data length - should throw IOException + long invalidPosition = info.dataLength + 1000; + assertThatThrownBy(() -> SSTableChunkSampler.readUncompressedChunk(info, invalidPosition)) + .isInstanceOf(java.io.IOException.class) + .hasMessageContaining("Invalid read size"); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/ZstdCompressionDictionaryTest.java b/test/unit/org/apache/cassandra/db/compression/ZstdCompressionDictionaryTest.java new file mode 100644 index 000000000000..e054776f96b2 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/ZstdCompressionDictionaryTest.java @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.github.luben.zstd.ZstdDictCompress; +import com.github.luben.zstd.ZstdDictDecompress; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.io.compress.ZstdCompressorBase; +import org.apache.cassandra.utils.concurrent.Ref; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class ZstdCompressionDictionaryTest +{ + private static final byte[] SAMPLE_DICT_DATA = createSampleDictionaryData(); + private static final DictId SAMPLE_DICT_ID = new DictId(Kind.ZSTD, 123456789L); + + private ZstdCompressionDictionary dictionary; + + @BeforeClass + public static void setUpClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setUp() + { + dictionary = new ZstdCompressionDictionary(SAMPLE_DICT_ID, SAMPLE_DICT_DATA); + } + + @Test + public void testEqualsAndHashCode() + { + ZstdCompressionDictionary dictionary2 = new ZstdCompressionDictionary(SAMPLE_DICT_ID, SAMPLE_DICT_DATA); + ZstdCompressionDictionary differentIdDict = new ZstdCompressionDictionary( + new DictId(Kind.ZSTD, 987654321L), SAMPLE_DICT_DATA); + + assertThat(dictionary) + .as("Dictionaries with same ID should be equal") + .isEqualTo(dictionary2); + + assertThat(dictionary.hashCode()) + .as("Hash codes should be equal for same ID") + .isEqualTo(dictionary2.hashCode()); + + assertThat(dictionary) + .as("Dictionaries with different IDs should not be equal") + .isNotEqualTo(differentIdDict); + + dictionary2.close(); + differentIdDict.close(); + } + + @Test + public void testDictionaryForCompression() + { + int compressionLevel = 3; + ZstdDictCompress compressDict = dictionary.dictionaryForCompression(compressionLevel); + + assertThat(compressDict) + .as("Compression dictionary should not be null") + .isNotNull(); + + // Calling again should return the same cached instance + ZstdDictCompress compressDict2 = dictionary.dictionaryForCompression(compressionLevel); + assertThat(compressDict2) + .as("Second call should return cached instance") + .isSameAs(compressDict); + } + + @Test + public void testDictionaryForCompressionMultipleLevels() + { + ZstdDictCompress level1 = dictionary.dictionaryForCompression(1); + ZstdDictCompress level3 = dictionary.dictionaryForCompression(3); + ZstdDictCompress level6 = dictionary.dictionaryForCompression(6); + + assertThat(level1) + .as("Level 1 compression dictionary should not be null") + .isNotNull(); + + assertThat(level3) + .as("Level 3 compression dictionary should not be null") + .isNotNull(); + + assertThat(level6) + .as("Level 6 compression dictionary should not be null") + .isNotNull(); + + assertThat(level1) + .as("Different compression levels should have different instances") + .isNotSameAs(level3); + + assertThat(level3) + .as("Different compression levels should have different 
instances") + .isNotSameAs(level6); + } + + @Test + public void testDictionaryForDecompression() + { + ZstdDictDecompress decompressDict = dictionary.dictionaryForDecompression(); + + assertThat(decompressDict) + .as("Decompression dictionary should not be null") + .isNotNull(); + + ZstdDictDecompress decompressDict2 = dictionary.dictionaryForDecompression(); + assertThat(decompressDict2) + .as("Second call should return cached instance") + .isSameAs(decompressDict); + } + + @Test + public void testInvalidCompressionLevel() + { + // Test with various invalid compression levels + assertThatThrownBy(() -> dictionary.dictionaryForCompression(ZstdCompressorBase.FAST_COMPRESSION_LEVEL - 1)) + .as("Negative compression level should throw exception") + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("is invalid"); + + assertThatThrownBy(() -> dictionary.dictionaryForCompression(100)) + .as("Too high compression level should throw exception") + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("is invalid"); + } + + @Test + public void testDictionaryClose() + { + // Access some dictionaries first + dictionary.dictionaryForCompression(3); + dictionary.dictionaryForDecompression(); + + dictionary.close(); + + assertThatThrownBy(() -> dictionary.dictionaryForCompression(3)) + .as("Should throw exception when accessing closed dictionary") + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Dictionary has been closed"); + + assertThatThrownBy(() -> dictionary.dictionaryForDecompression()) + .as("Should throw exception when accessing closed dictionary") + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Dictionary has been closed"); + } + + @Test + public void testTryRef() + { + Ref ref = dictionary.tryRef(); + + assertThat(ref) + .as("tryRef should return non-null reference") + .isNotNull(); + + assertThat(ref.get()) + .as("Reference should point to same dictionary") + .isSameAs(dictionary); + + ref.release(); + } + + @Test + public void testMultipleReferences() + { + Ref ref1 = dictionary.ref(); + Ref ref2 = dictionary.ref(); + Ref ref3 = dictionary.tryRef(); + + assertThat(ref1.get()) + .as("All references should point to same dictionary") + .isSameAs(dictionary); + + assertThat(ref2.get()) + .as("All references should point to same dictionary") + .isSameAs(dictionary); + + assertThat(ref3.get()) + .as("All references should point to same dictionary") + .isSameAs(dictionary); + + // Dictionary should still be accessible + assertThat(dictionary.dictionaryForCompression(3)) + .as("Dictionary should still be accessible with multiple refs") + .isNotNull(); + + ref1.release(); + ref2.release(); + ref3.release(); + } + + @Test + public void testReferenceAfterClose() + { + dictionary.close(); + + assertThatThrownBy(() -> dictionary.ref()) + .as("Should not be able to get reference after close") + .isInstanceOf(AssertionError.class); + + Ref tryRef = dictionary.tryRef(); + assertThat(tryRef) + .as("tryRef should return null after close") + .isNull(); + } + + @Test + public void testConcurrentAccess() throws Exception + { + ExecutorService executor = Executors.newFixedThreadPool(4); + AtomicInteger successCount = new AtomicInteger(0); + int numTasks = 100; + + try + { + Future[] futures = new Future[numTasks]; + + for (int i = 0; i < numTasks; i++) + { + final int level = (i % 6) + 1; // Compression levels 1-6 + futures[i] = executor.submit(() -> { + try + { + Ref ref = dictionary.ref(); + ZstdDictCompress compressDict = 
ref.get().dictionaryForCompression(level); + ZstdDictDecompress decompressDict = ref.get().dictionaryForDecompression(); + + assertThat(compressDict).isNotNull(); + assertThat(decompressDict).isNotNull(); + + successCount.incrementAndGet(); + ref.release(); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + }); + } + + // Wait for all tasks to complete + for (Future future : futures) + { + future.get(5, TimeUnit.SECONDS); + } + + assertThat(successCount.get()) + .as("All concurrent accesses should succeed") + .isEqualTo(numTasks); + } + finally + { + executor.shutdown(); + executor.awaitTermination(5, TimeUnit.SECONDS); + } + } + + @Test + public void testSerializeDeserialize() throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + + dictionary.serialize(dos); + dos.flush(); + + byte[] serializedData = baos.toByteArray(); + assertThat(serializedData.length) + .as("Serialized data should not be empty") + .isGreaterThan(0); + + // Deserialize + ByteArrayInputStream bais = new ByteArrayInputStream(serializedData); + DataInputStream dis = new DataInputStream(bais); + + CompressionDictionary deserializedDict = CompressionDictionary.deserialize(dis, null); + + assertThat(deserializedDict) + .as("Deserialized dictionary should not be null") + .isNotNull(); + + assertThat(deserializedDict.dictId()) + .as("Deserialized dictionary ID should match") + .isEqualTo(dictionary.dictId()); + + assertThat(deserializedDict.kind()) + .as("Deserialized dictionary kind should match") + .isEqualTo(dictionary.kind()); + + assertThat(deserializedDict.rawDictionary()) + .as("Deserialized dictionary data should match") + .isEqualTo(dictionary.rawDictionary()); + } + + @Test + public void testSerializeDeserializeWithManager() throws Exception + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + + dictionary.serialize(dos); + dos.flush(); + + byte[] serializedData = baos.toByteArray(); + + // First deserialization should create and cache the dictionary + ByteArrayInputStream bais1 = new ByteArrayInputStream(serializedData); + DataInputStream dis1 = new DataInputStream(bais1); + CompressionDictionary dict1 = CompressionDictionary.deserialize(dis1, null); + + // Second deserialization should return cached instance + ByteArrayInputStream bais2 = new ByteArrayInputStream(serializedData); + DataInputStream dis2 = new DataInputStream(bais2); + CompressionDictionary dict2 = CompressionDictionary.deserialize(dis2, null); + + assertThat(dict1) + .as("Both deserializations should return identical dictionary") + .isNotNull() + .isEqualTo(dict2); + + dict1.close(); + dict2.close(); + } + + @Test + public void testDeserializeCorruptedData() throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + + // Write corrupted data (wrong checksum) + dos.writeByte(Kind.ZSTD.ordinal()); + dos.writeLong(SAMPLE_DICT_ID.id); + dos.writeInt(SAMPLE_DICT_DATA.length); + dos.write(SAMPLE_DICT_DATA); + dos.writeInt(0xDEADBEEF); // Wrong checksum + dos.flush(); + + byte[] corruptedData = baos.toByteArray(); + ByteArrayInputStream bais = new ByteArrayInputStream(corruptedData); + DataInputStream dis = new DataInputStream(bais); + + assertThatThrownBy(() -> CompressionDictionary.deserialize(dis, null)) + .as("Should throw exception for corrupted data") + .isInstanceOf(IOException.class) + 
.hasMessageContaining("checksum does not match"); + } + + private static byte[] createSampleDictionaryData() + { + // Create sample dictionary data that could be used for compression + String sampleText = "The quick brown fox jumps over the lazy dog. "; + return sampleText.repeat(100).getBytes(); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/ZstdDictionaryTrainerTest.java b/test/unit/org/apache/cassandra/db/compression/ZstdDictionaryTrainerTest.java new file mode 100644 index 000000000000..6c43e7e95695 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/ZstdDictionaryTrainerTest.java @@ -0,0 +1,685 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.apache.cassandra.utils.concurrent.Future; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.utils.Clock; + +import static org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class ZstdDictionaryTrainerTest +{ + private static final String TEST_KEYSPACE = "test_ks"; + private static final String TEST_TABLE = "test_table"; + private static final String SAMPLE_DATA = "The quick brown fox jumps over the lazy dog. 
"; + private static final int COMPRESSION_LEVEL = 3; + + private CompressionDictionaryTrainingConfig testConfig; + private ZstdDictionaryTrainer trainer; + private Consumer mockCallback; + private AtomicReference callbackResult; + + @BeforeClass + public static void setUpClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setUp() + { + testConfig = CompressionDictionaryTrainingConfig.builder() + .maxDictionarySize(1024) // Small for testing + .maxTotalSampleSize(10 * 1024) // 10KB total + .samplingRate(1) // 100% sampling for predictable tests + .build(); + + callbackResult = new AtomicReference<>(); + mockCallback = callbackResult::set; + + trainer = new ZstdDictionaryTrainer(TEST_KEYSPACE, TEST_TABLE, testConfig, COMPRESSION_LEVEL); + trainer.setDictionaryTrainedListener(mockCallback); + } + + @After + public void tearDown() throws Exception + { + if (trainer != null) + { + trainer.close(); + } + + // Clean up any dictionary created in callback + CompressionDictionary dict = callbackResult.get(); + if (dict != null) + { + dict.close(); + callbackResult.set(null); + } + } + + @Test + public void testTrainerInitialState() + { + assertThat(trainer.getTrainingState().getStatus()) + .as("Initial status should be NOT_STARTED") + .isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(trainer.isReady()) + .as("Should not be ready initially") + .isFalse(); + assertThat(trainer.kind()) + .as("Should return ZSTD kind") + .isEqualTo(Kind.ZSTD); + } + + @Test + public void testTrainerStart() + { + // Auto start depends on configuration - test both scenarios + boolean started = trainer.start(false); + if (started) + { + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should be SAMPLING if auto-start enabled") + .isEqualTo(TrainingStatus.SAMPLING); + } + else + { + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should remain NOT_STARTED if auto-start disabled") + .isEqualTo(TrainingStatus.NOT_STARTED); + } + } + + @Test + public void testTrainerStartManual() + { + assertThat(trainer.start(true)) + .as("Manual training should start successfully") + .isTrue(); + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should be SAMPLING after start") + .isEqualTo(TrainingStatus.SAMPLING); + assertThat(trainer.isReady()) + .as("Should not be ready immediately after start") + .isFalse(); + } + + @Test + public void testTrainerStartMultipleTimes() + { + assertThat(trainer.start(true)) + .as("First start (manual training) should succeed") + .isTrue(); + Object firstTrainer = trainer.trainer(); + assertThat(firstTrainer).isNotNull(); + assertThat(trainer.start(true)) + .as("Second start (manual training) should suceed and reset") + .isTrue(); + Object secondTrainer = trainer.trainer(); + assertThat(secondTrainer).isNotNull().isNotSameAs(firstTrainer); + assertThat(trainer.start(false)) + .as("Third start (not manual training) should fail") + .isFalse(); + } + + @Test + public void testTrainerCloseIdempotent() + { + trainer.start(true); + trainer.close(); + trainer.close(); // Should not throw + trainer.close(); // Should not throw + + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should remain NOT_STARTED after multiple closes") + .isEqualTo(TrainingStatus.NOT_STARTED); + } + + @Test + public void testTrainerReset() + { + trainer.start(true); + addSampleData(1000); // Add some samples + + assertThat(trainer.getTrainingState().getSampleCount()) + .as("Should have samples before reset") + .isGreaterThan(0); + + 
trainer.reset(); + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should be NOT_STARTED after reset") + .isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(trainer.getTrainingState().getSampleCount()) + .as("Sample count should be 0 after reset") + .isEqualTo(0); + assertThat(trainer.isReady()) + .as("Should not be ready after reset") + .isFalse(); + } + + @Test + public void testStartAfterClose() + { + trainer.start(true); + trainer.close(); + + assertThat(trainer.start(true)) + .as("Should not start after close") + .isFalse(); + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should remain NOT_STARTED") + .isEqualTo(TrainingStatus.NOT_STARTED); + } + + @Test + public void testShouldSample() + { + trainer.start(true); + // With sampling rate 1 (100%), should always return true + for (int i = 0; i < 10; i++) + { + assertThat(trainer.shouldSample()) + .as("Should sample with rate 1") + .isTrue(); + } + } + + @Test + public void testShouldSampleWithLowRate() + { + // Test with lower sampling rate + CompressionDictionaryTrainingConfig lowSamplingConfig = + CompressionDictionaryTrainingConfig.builder() + .maxDictionarySize(1024) + .maxTotalSampleSize(10 * 1024) + .samplingRate(0.001f) // 0.1% sampling + .build(); + + try (ZstdDictionaryTrainer lowSamplingTrainer = new ZstdDictionaryTrainer(TEST_KEYSPACE, TEST_TABLE, + lowSamplingConfig, COMPRESSION_LEVEL)) + { + lowSamplingTrainer.setDictionaryTrainedListener(mockCallback); + // With very low sampling rate, should mostly return false + int sampleCount = 0; + int iterations = 1000; + for (int i = 0; i < iterations; i++) + { + if (lowSamplingTrainer.shouldSample()) + { + sampleCount++; + } + } + + // Should be roughly 0.1% (1 out of 1000), allow some variance + assertThat(sampleCount) + .as("Sample rate should be low") + .isLessThan(iterations / 10); + } + } + + @Test + public void testAddSample() + { + trainer.start(true); + + assertThat(trainer.getTrainingState().getSampleCount()) + .as("Initial sample count should be 0") + .isEqualTo(0); + + ByteBuffer sample = ByteBuffer.wrap(SAMPLE_DATA.getBytes()); + trainer.addSample(sample); + + assertThat(trainer.getTrainingState().getSampleCount()) + .as("Sample count should be 1 after adding one sample") + .isEqualTo(1); + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should be SAMPLING") + .isEqualTo(TrainingStatus.SAMPLING); + assertThat(trainer.isReady()) + .as("Should not be ready with single small sample") + .isFalse(); + } + + @Test + public void testAddSampleBeforeStart() + { + // Should not accept samples before start + ByteBuffer sample = ByteBuffer.wrap(SAMPLE_DATA.getBytes()); + trainer.addSample(sample); + + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should remain NOT_STARTED") + .isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(trainer.isReady()) + .as("Should not be ready") + .isFalse(); + } + + @Test + public void testAddSampleAfterClose() + { + trainer.start(true); + trainer.close(); + + ByteBuffer sample = ByteBuffer.wrap(SAMPLE_DATA.getBytes()); + trainer.addSample(sample); + + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should remain NOT_STARTED after close") + .isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(trainer.isReady()) + .as("Should not be ready after close") + .isFalse(); + } + + @Test + public void testAddNullSample() + { + trainer.start(true); + trainer.addSample(null); // Should not throw + + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should 
remain SAMPLING") + .isEqualTo(TrainingStatus.SAMPLING); + assertThat(trainer.isReady()) + .as("Should not be ready with null sample") + .isFalse(); + } + + @Test + public void testAddEmptySample() + { + trainer.start(true); + ByteBuffer empty = ByteBuffer.allocate(0); + trainer.addSample(empty); // Should not throw + + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should remain SAMPLING") + .isEqualTo(TrainingStatus.SAMPLING); + assertThat(trainer.isReady()) + .as("Should not be ready with empty sample") + .isFalse(); + } + + @Test + public void testIsReady() + { + trainer.start(true); + assertThat(trainer.isReady()) + .as("Should not be ready initially") + .isFalse(); + + addSampleData(testConfig.acceptableTotalSampleSize / 2); + assertThat(trainer.isReady()) + .as("Should not be ready with insufficient samples") + .isFalse(); + + addSampleData(testConfig.acceptableTotalSampleSize); + assertThat(trainer.isReady()) + .as("Should be ready after enough samples") + .isTrue(); + + trainer.close(); + + assertThat(trainer.isReady()) + .as("Should not be ready when closed") + .isFalse(); + } + + @Test + public void testTrainDictionaryWithInsufficientSampleCount() + { + trainer.start(true); + + // Add sufficient data size but only 5 samples (less than minimum 10) + for (int i = 0; i < 5; i++) + { + ByteBuffer largeSample = ByteBuffer.wrap(new byte[testConfig.acceptableTotalSampleSize / 5]); + trainer.addSample(largeSample); + } + + assertThat(trainer.getTrainingState().getSampleCount()) + .as("Should have 5 samples") + .isEqualTo(5); + assertThat(trainer.isReady()) + .as("Should not be ready with insufficient sample count") + .isFalse(); + + // Trying to train without force should fail + assertThatThrownBy(() -> trainer.trainDictionary(false)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Trainer is not ready"); + + // Force training should fail with insufficient samples + assertThatThrownBy(() -> trainer.trainDictionary(true)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Insufficient samples for training: 5 (minimum required: 10)"); + } + + @Test + public void testTrainDictionaryWithSufficientSampleCount() + { + trainer.start(true); + + // Add 15 samples with sufficient total size + for (int i = 0; i < 15; i++) + { + ByteBuffer sample = ByteBuffer.wrap(new byte[testConfig.acceptableTotalSampleSize / 15 + 1]); + trainer.addSample(sample); + } + + assertThat(trainer.getTrainingState().getSampleCount()).isEqualTo(15); + assertThat(trainer.isReady()).isTrue(); + + // Training should succeed + CompressionDictionary dictionary = trainer.trainDictionary(false); + assertThat(dictionary).as("Dictionary should be created").isNotNull(); + assertThat(trainer.getTrainingState().getStatus()).isEqualTo(TrainingStatus.COMPLETED); + } + + @Test + public void testTrainDictionaryAsync() throws Exception + { + Future future = startTraining(true, false, testConfig.acceptableTotalSampleSize); + CompressionDictionary dictionary = future.get(5, TimeUnit.SECONDS); + + assertThat(dictionary).as("Dictionary should not be null").isNotNull(); + assertThat(trainer.getTrainingState().getStatus()).as("Status should be COMPLETED").isEqualTo(TrainingStatus.COMPLETED); + + // Verify callback was called + assertThat(callbackResult.get()).as("Callback should have been called").isNotNull(); + assertThat(callbackResult.get().dictId()).as("Callback should receive same dictionary").isEqualTo(dictionary.dictId()); + } + + @Test + public void testTrainDictionaryAsyncForce() 
throws Exception + { + // Don't add enough samples + Future future = startTraining(true, true, 512); + CompressionDictionary dictionary = future.get(1, TimeUnit.SECONDS); + assertThat(dictionary) + .as("Forced async training should produce dictionary") + .isNotNull(); + } + + @Test + public void testTrainDictionaryAsyncForceFailsWithNoData() throws Exception + { + AtomicReference dictRef = new AtomicReference<>(); + Future result = startTraining(true, true, 0) + .addCallback((dict, t) -> dictRef.set(dict)); + + assertThat(result.isDone() && result.cause() != null) + .as("Result should be completed exceptionally") + .isTrue(); + assertThat(trainer.getTrainingState().getStatus()) + .as("Status should be FAILED") + .isEqualTo(TrainingStatus.FAILED); + assertThat(dictRef.get()) + .as("Dictionary reference should be null") + .isNull(); + } + + @Test + public void testDictionaryTrainedListener() + { + trainer.start(true); + addSampleData(testConfig.acceptableTotalSampleSize); + + // Train dictionary synchronously - callback should be called + CompressionDictionary dictionary = trainer.trainDictionary(false); + + // Verify callback was invoked with the dictionary + assertThat(callbackResult.get()).as("Callback should have been called").isNotNull(); + assertThat(callbackResult.get().dictId().id) + .as("Callback should receive correct dictionary ID") + .isEqualTo(dictionary.dictId().id); + assertThat(callbackResult.get().kind()) + .as("Callback should receive correct dictionary kind") + .isEqualTo(dictionary.kind()); + } + + @Test + public void testMonotonicDictionaryIds() + { + long now = Clock.Global.currentTimeMillis(); + long id1 = ZstdDictionaryTrainer.makeDictionaryId(now, 100L); + long hourLater= now + TimeUnit.HOURS.toMillis(1); + long id2 = ZstdDictionaryTrainer.makeDictionaryId(hourLater, 200L); + long id3 = ZstdDictionaryTrainer.makeDictionaryId(now, 200L); + + assertThat(id2) + .as("Dictionary IDs should be monotonic over time") + .isGreaterThan(id1) + .isGreaterThan(id3); + + assertThat(id3).isNotEqualTo(id1).isNotEqualTo(id2); + } + + @Test + public void testIsCompatibleWith() + { + CompressionParams compatibleParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + assertThat(trainer.isCompatibleWith(compatibleParams)) + .as("Should be compatible with same compression level") + .isTrue(); + + + CompressionParams incompatibleParams = CompressionParams.lz4(); + + assertThat(trainer.isCompatibleWith(incompatibleParams)) + .as("Should not be compatible with different compressor") + .isFalse(); + + CompressionParams differentLevelParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "4")); + + assertThat(trainer.isCompatibleWith(differentLevelParams)) + .as("Should not be compatible with different compression level") + .isFalse(); + + CompressionParams disabledParams = CompressionParams.noCompression(); + + assertThat(trainer.isCompatibleWith(disabledParams)) + .as("Should not be compatible with disabled compression") + .isFalse(); + } + + @Test + public void testUpdateSamplingRate() + { + trainer.start(true); + + // Test updating to different valid sampling rates + trainer.updateSamplingRate(10); + + // With sampling rate 10 (10%), should mostly return false + int sampleCount = 0; + int iterations = 1000; + for (int i = 0; i < iterations; i++) + { + if (trainer.shouldSample()) + { + sampleCount++; + } + } + + // Should be roughly 10% (1 out of 10), allow some variance + 
assertThat(sampleCount) + .as("Sample rate should be approximately 10%") + .isGreaterThan(iterations / 20) // at least 5% + .isLessThan(iterations / 5); // at most 20% + + // Test updating to 100% sampling + trainer.updateSamplingRate(1); + + // Should always sample now + for (int i = 0; i < 10; i++) + { + assertThat(trainer.shouldSample()) + .as("Should always sample with rate 1") + .isTrue(); + } + } + + @Test + public void testUpdateSamplingRateValidation() + { + trainer.start(true); + + // Test invalid sampling rates + assertThatThrownBy(() -> trainer.updateSamplingRate(0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sampling rate must be positive"); + + assertThatThrownBy(() -> trainer.updateSamplingRate(-1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sampling rate must be positive"); + + assertThatThrownBy(() -> trainer.updateSamplingRate(-100)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sampling rate must be positive"); + } + + @Test + public void testUpdateSamplingRateBeforeStart() + { + // Should be able to update sampling rate even before start + trainer.updateSamplingRate(5); + + trainer.start(true); + + // Verify the updated rate is used after start + int sampleCount = 0; + int iterations = 1000; + for (int i = 0; i < iterations; i++) + { + if (trainer.shouldSample()) + { + sampleCount++; + } + } + + // Should be roughly 20% (1 out of 5), allow some variance + assertThat(sampleCount) + .as("Sample rate should be approximately 20%") + .isGreaterThan(iterations / 10) // at least 10% + .isLessThan(iterations / 2); // at most 50% + } + + private Future startTraining(boolean manualTraining, boolean forceTrain, int sampleSize) throws Exception + { + trainer.start(manualTraining); + if (sampleSize > 0) + { + addSampleData(sampleSize); + } + + if (forceTrain) + { + assertThat(trainer.isReady()) + .as("Trainer should not be ready to train due to lack of samples") + .isFalse(); + } + + CountDownLatch latch = new CountDownLatch(1); + Future future = trainer.trainDictionaryAsync(forceTrain) + .addCallback((dict, throwable) -> latch.countDown()); + assertThat(latch.await(10, TimeUnit.SECONDS)) + .as("Training should complete within timeout") + .isTrue(); + return future; + } + + private void addSampleData(int totalSize) + { + byte[] sampleBytes = SAMPLE_DATA.getBytes(); + int samplesNeeded = (totalSize + sampleBytes.length - 1) / sampleBytes.length; // Round up + + for (int i = 0; i < samplesNeeded; i++) + { + ByteBuffer sample = ByteBuffer.wrap(sampleBytes); + trainer.addSample(sample); + } + } + + @Test + public void testStatisticsMethods() + { + assertThat(trainer.getTrainingState().getSampleCount()) + .as("Initial sample count should be 0") + .isEqualTo(0); + + assertThat(trainer.getTrainingState().getTotalSampleSize()) + .as("Initial total sample size should be 0") + .isEqualTo(0); + + // Start training + trainer.start(true); + + // Add some samples + byte[] sampleBytes = SAMPLE_DATA.getBytes(); + int sampleSize = sampleBytes.length; + int numSamples = 5; + + for (int i = 0; i < numSamples; i++) + { + trainer.addSample(ByteBuffer.wrap(sampleBytes)); + } + + assertThat(trainer.getTrainingState().getSampleCount()) + .as("Sample count should be updated after adding samples") + .isEqualTo(numSamples); + + assertThat(trainer.getTrainingState().getTotalSampleSize()) + .as("Total sample size should match number of samples times sample size") + .isEqualTo((long) numSamples * sampleSize); + + 
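The startTraining helper above drives the asynchronous variant; reduced to its essentials (same trainer, listener and Cassandra Future type as in this class, checked exceptions from get() elided), it amounts to:

    // force = false trains from the samples collected so far; the tests pass force = true
    // to train even before isReady(), which fails (status FAILED) when there is no data at all.
    Future<CompressionDictionary> future = trainer.trainDictionaryAsync(false)
        .addCallback((dict, failure) -> {
            // on success 'dict' is the newly trained dictionary; on failure it is null
        });
    CompressionDictionary trained = future.get(10, TimeUnit.SECONDS); // propagates a training failure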
trainer.reset(); + + assertThat(trainer.getTrainingState().getSampleCount()) + .as("Sample count should be 0 after reset") + .isEqualTo(0); + + assertThat(trainer.getTrainingState().getTotalSampleSize()) + .as("Total sample size should be 0 after reset") + .isEqualTo(0); + } +} diff --git a/test/unit/org/apache/cassandra/io/compress/CQLCompressionTest.java b/test/unit/org/apache/cassandra/io/compress/CQLCompressionTest.java index fbc17c272992..f2a56c33640c 100644 --- a/test/unit/org/apache/cassandra/io/compress/CQLCompressionTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CQLCompressionTest.java @@ -81,11 +81,11 @@ public void zstdParamsTest() { createTable("create table %s (id int primary key, uh text) with compression = {'class':'ZstdCompressor', 'compression_level':-22}"); assertTrue(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).getClass().equals(ZstdCompressor.class)); - assertEquals(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).getCompressionLevel(), -22); + assertEquals(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).compressionLevel(), -22); createTable("create table %s (id int primary key, uh text) with compression = {'class':'ZstdCompressor'}"); assertTrue(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).getClass().equals(ZstdCompressor.class)); - assertEquals(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).getCompressionLevel(), ZstdCompressor.DEFAULT_COMPRESSION_LEVEL); + assertEquals(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).compressionLevel(), ZstdCompressor.DEFAULT_COMPRESSION_LEVEL); } @Test(expected = ConfigurationException.class) diff --git a/test/unit/org/apache/cassandra/io/compress/CompressedRandomAccessReaderTest.java b/test/unit/org/apache/cassandra/io/compress/CompressedRandomAccessReaderTest.java index 2bf127f0790d..056761b5602c 100644 --- a/test/unit/org/apache/cassandra/io/compress/CompressedRandomAccessReaderTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CompressedRandomAccessReaderTest.java @@ -314,4 +314,4 @@ private static void updateChecksum(RandomAccessFile file, long checksumOffset, b file.write(checksum); SyncUtil.sync(file.getFD()); } -} \ No newline at end of file +} diff --git a/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterTest.java b/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterTest.java index afa469c48772..ada098e6d057 100644 --- a/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterTest.java @@ -395,4 +395,4 @@ void cleanup() } } -} \ No newline at end of file +} diff --git a/test/unit/org/apache/cassandra/io/compress/CompressionMetadataTest.java b/test/unit/org/apache/cassandra/io/compress/CompressionMetadataTest.java index 321fe5735606..560a0af63520 100644 --- a/test/unit/org/apache/cassandra/io/compress/CompressionMetadataTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CompressionMetadataTest.java @@ -42,7 +42,8 @@ private CompressionMetadata newCompressionMetadata(Memory memory) memory, memory.size(), dataLength, - compressedFileLength); + compressedFileLength, + null); } @Test @@ -75,4 +76,4 @@ public void testMemoryIsShared() 
assertThat(copy.isCleanedUp()).isTrue(); assertThatExceptionOfType(AssertionError.class).isThrownBy(memory::size); } -} \ No newline at end of file +} diff --git a/test/unit/org/apache/cassandra/io/compress/ZstdCompressorTest.java b/test/unit/org/apache/cassandra/io/compress/ZstdCompressorTest.java index 70e32ad4d252..b360280b9e32 100644 --- a/test/unit/org/apache/cassandra/io/compress/ZstdCompressorTest.java +++ b/test/unit/org/apache/cassandra/io/compress/ZstdCompressorTest.java @@ -36,7 +36,7 @@ public class ZstdCompressorTest public void emptyConfigurationUsesDefaultCompressionLevel() { ZstdCompressor compressor = ZstdCompressor.create(Collections.emptyMap()); - assertEquals(ZstdCompressor.DEFAULT_COMPRESSION_LEVEL, compressor.getCompressionLevel()); + assertEquals(ZstdCompressor.DEFAULT_COMPRESSION_LEVEL, compressor.compressionLevel()); } @Test(expected = IllegalArgumentException.class) diff --git a/test/unit/org/apache/cassandra/io/compress/ZstdDictionaryCompressorTest.java b/test/unit/org/apache/cassandra/io/compress/ZstdDictionaryCompressorTest.java new file mode 100644 index 000000000000..ef5e3d11d8d7 --- /dev/null +++ b/test/unit/org/apache/cassandra/io/compress/ZstdDictionaryCompressorTest.java @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.io.compress; + +import com.github.luben.zstd.Zstd; +import com.github.luben.zstd.ZstdDictTrainer; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.ZstdCompressionDictionary; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; +import java.util.Random; + +import static org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import static org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.Assert.fail; + +public class ZstdDictionaryCompressorTest +{ + private static final int TEST_DATA_SIZE = 1024; + private static final String REPEATED_PATTERN = "The quick brown fox jumps over the lazy dog. 
"; + + private static byte[] testData; + private static byte[] compressibleData; + private static ZstdCompressionDictionary testDictionary; + + @BeforeClass + public static void setup() + { + DatabaseDescriptor.daemonInitialization(); + testData = new byte[TEST_DATA_SIZE]; + new Random(42).nextBytes(testData); + + // Generate compressible data + StringBuilder sb = new StringBuilder(); + while (sb.length() < TEST_DATA_SIZE) + { + sb.append(REPEATED_PATTERN); + } + compressibleData = sb.substring(0, TEST_DATA_SIZE).getBytes(); + testDictionary = createTestDictionary(); + } + + @AfterClass + public static void tearDown() + { + if (testDictionary != null) + { + testDictionary.close(); + } + } + + @Test + public void testCreateWithOptions() + { + Map options = Map.of(ZstdCompressor.COMPRESSION_LEVEL_OPTION_NAME, "5"); + + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(options); + assertThat(compressor).isNotNull(); + assertThat(compressor.compressionLevel()).isEqualTo(5); + assertThat(compressor.dictionary()).isNull(); // No dictionary should be set + } + + @Test + public void testCreateWithEmptyOptions() + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + assertThat(compressor).isNotNull(); + assertThat(compressor.compressionLevel()).isEqualTo(ZstdCompressor.DEFAULT_COMPRESSION_LEVEL); + } + + @Test + public void testCreateWithDictionary() + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + assertThat(compressor).isNotNull(); + assertThat(compressor.compressionLevel()).isEqualTo(ZstdCompressor.DEFAULT_COMPRESSION_LEVEL); + assertThat(compressor.dictionary()).isSameAs(testDictionary); + } + + @Test + public void testCreateWithInvalidCompressionLevel() + { + String invalidLevel = String.valueOf(Zstd.maxCompressionLevel() + 1); + Map options = Map.of(ZstdCompressor.COMPRESSION_LEVEL_OPTION_NAME, invalidLevel); + + assertThatThrownBy(() -> ZstdDictionaryCompressor.create(options)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(ZstdCompressor.COMPRESSION_LEVEL_OPTION_NAME + '=' + invalidLevel + " is invalid"); + } + + @Test + public void testCompressDecompressWithDictionary() throws IOException + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + + ByteBuffer input = ByteBuffer.allocateDirect(compressibleData.length); + input.put(compressibleData); + input.flip(); + + ByteBuffer compressed = ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(compressibleData.length)); + + // Compress + compressor.compress(input, compressed); + compressed.flip(); + + assertThat(compressed.remaining()) + .as("Data should be compressed") + .isLessThan(compressibleData.length); + + // Decompress + ByteBuffer decompressed = ByteBuffer.allocateDirect(compressibleData.length); + compressed.rewind(); + compressor.uncompress(compressed, decompressed); + decompressed.flip(); + + // Verify roundtrip + byte[] result = new byte[decompressed.remaining()]; + decompressed.get(result); + assertThat(result).isEqualTo(compressibleData); + } + + @Test + public void testCompressDecompressWithoutDictionary() throws IOException + { + // Test fallback behavior when no dictionary is provided + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + + ByteBuffer input = ByteBuffer.allocateDirect(testData.length); + input.put(testData); + input.flip(); + + ByteBuffer compressed = 
ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(testData.length)); + + // Compress + compressor.compress(input, compressed); + compressed.flip(); + + // Decompress + ByteBuffer decompressed = ByteBuffer.allocateDirect(testData.length); + compressed.rewind(); + compressor.uncompress(compressed, decompressed); + decompressed.flip(); + + // Verify roundtrip + byte[] result = new byte[decompressed.remaining()]; + decompressed.get(result); + assertThat(result).isEqualTo(testData); + } + + @Test + public void testCompressDecompressByteArray() throws IOException + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + + // Test byte array compression/decompression using direct buffers + ByteBuffer input = ByteBuffer.allocateDirect(compressibleData.length); + input.put(compressibleData); + input.flip(); + + ByteBuffer output = ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(compressibleData.length)); + + compressor.compress(input, output); + int compressedLength = output.position(); + + // Extract compressed data to byte array for array-based decompression test + byte[] compressed = new byte[compressedLength]; + output.flip(); + output.get(compressed); + + // Decompress using byte array method + byte[] decompressed = new byte[compressibleData.length]; + int decompressedLength = compressor.uncompress(compressed, 0, compressedLength, decompressed, 0); + + assertThat(decompressedLength).isEqualTo(compressibleData.length); + assertThat(decompressed).isEqualTo(compressibleData); + } + + @Test + public void testDictionaryCompressionImprovement() + { + // Test that dictionary compression provides better compression ratio + ZstdDictionaryCompressor dictCompressor = ZstdDictionaryCompressor.create(testDictionary); + ZstdDictionaryCompressor noDictCompressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + + ByteBuffer input1 = ByteBuffer.allocateDirect(compressibleData.length); + input1.put(compressibleData); + input1.flip(); + + ByteBuffer input2 = ByteBuffer.allocateDirect(compressibleData.length); + input2.put(compressibleData); + input2.flip(); + + ByteBuffer dictCompressed = ByteBuffer.allocateDirect(dictCompressor.initialCompressedBufferLength(compressibleData.length)); + ByteBuffer noDictCompressed = ByteBuffer.allocateDirect(noDictCompressor.initialCompressedBufferLength(compressibleData.length)); + + try + { + dictCompressor.compress(input1, dictCompressed); + noDictCompressor.compress(input2, noDictCompressed); + + dictCompressed.flip(); + noDictCompressed.flip(); + + // Dictionary compression should achieve better compression ratio for repetitive data + assertThat(dictCompressed.remaining()) + .as("Dictionary compression should achieve better compression ratio") + .isLessThanOrEqualTo(noDictCompressed.remaining()); + } + catch (IOException e) + { + fail("Compression should not fail: " + e.getMessage()); + } + } + + @Test + public void testCompressorCaching() + { + // Test that same dictionary returns same compressor instance + ZstdDictionaryCompressor compressor1 = ZstdDictionaryCompressor.create(testDictionary); + ZstdDictionaryCompressor compressor2 = ZstdDictionaryCompressor.create(testDictionary); + + assertThat(compressor1) + .as("Same dictionary should return cached compressor instance") + .isSameAs(compressor2); + } + + @Test + public void testGetOrCopyWithDictionary() + { + ZstdDictionaryCompressor originalCompressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + ZstdDictionaryCompressor 
dictCompressor = originalCompressor.getOrCopyWithDictionary(testDictionary); + + assertThat(dictCompressor) + .as("Should return different compressor instance") + .isNotSameAs(originalCompressor); + assertThat(dictCompressor.dictionary()) + .as("Should have the provided dictionary") + .isSameAs(testDictionary); + assertThat(dictCompressor.compressionLevel()) + .as("Should preserve compression level") + .isEqualTo(originalCompressor.compressionLevel()); + } + + @Test + public void testGetOrCopyWithSameDictionary() + { + ZstdDictionaryCompressor originalCompressor = ZstdDictionaryCompressor.create(testDictionary); + ZstdDictionaryCompressor sameCompressor = originalCompressor.getOrCopyWithDictionary(testDictionary); + + assertThat(sameCompressor) + .as("Same dictionary should return same compressor") + .isSameAs(originalCompressor); + } + + @Test + public void testClosedDictionaryHandling() + { + ZstdDictionaryCompressor.invalidateCache(); + ZstdCompressionDictionary closedDict = createTestDictionary(); + closedDict.close(); + + // This should throw IllegalStateException + assertThatThrownBy(() -> ZstdDictionaryCompressor.create(closedDict)) + .isInstanceOf(IllegalStateException.class); + } + + @Test + public void testCompressionWithNullDictionary() throws IOException + { + // Test that null dictionary falls back to standard compression + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create((ZstdCompressionDictionary) null); + + ByteBuffer input = ByteBuffer.allocateDirect(testData.length); + input.put(testData); + input.flip(); + + ByteBuffer compressed = ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(testData.length)); + + // Should not throw exception, should fall back to standard Zstd + compressor.compress(input, compressed); + compressed.flip(); + + ByteBuffer decompressed = ByteBuffer.allocateDirect(testData.length); + compressed.rewind(); + compressor.uncompress(compressed, decompressed); + decompressed.flip(); + + byte[] result = new byte[decompressed.remaining()]; + decompressed.get(result); + assertThat(result) + .as("Null dictionary should fall back to standard compression") + .isEqualTo(testData); + } + + @Test + public void testDecompressionFailureHandling() + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + + // Create invalid compressed data + byte[] invalidData = new byte[10]; + new Random().nextBytes(invalidData); + + byte[] output = new byte[100]; + + assertThatThrownBy(() -> compressor.uncompress(invalidData, 0, invalidData.length, output, 0)) + .isInstanceOf(IOException.class) + .hasMessageContaining("Decompression failed"); + } + + @Test + public void testAcceptableDictionaryKind() + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + assertThat(compressor.acceptableDictionaryKind()) + .as("Should accept ZSTD dictionary kind") + .isEqualTo(Kind.ZSTD); + } + + @Test + public void testEmptyDataCompression() throws IOException + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + + byte[] emptyData = new byte[0]; + ByteBuffer input = ByteBuffer.allocateDirect(emptyData.length + 1); // Allocate at least 1 byte for direct buffer + input.put(emptyData); + input.flip(); + + ByteBuffer compressed = ByteBuffer.allocateDirect(Math.max(1, compressor.initialCompressedBufferLength(0))); + + compressor.compress(input, compressed); + compressed.flip(); + + ByteBuffer decompressed = ByteBuffer.allocateDirect(1); // 
Allocate at least 1 byte for direct buffer + compressed.rewind(); + compressor.uncompress(compressed, decompressed); + + assertThat(decompressed.position()) + .as("Should have written nothing for empty data") + .isEqualTo(0); + } + + private static ZstdCompressionDictionary createTestDictionary() + { + try + { + int sampleSize = 100 * 1024; + int dictSize = 6 * 1024; + // Create a simple dictionary from repetitive data + ZstdDictTrainer trainer = new ZstdDictTrainer(sampleSize, dictSize, 3); + + for (int i = 0; i < 1000; i++) + { + trainer.addSample(compressibleData); + } + + byte[] dictBytes = trainer.trainSamples(); + DictId dictId = new DictId(Kind.ZSTD, 1); + + return new ZstdCompressionDictionary(dictId, dictBytes); + } + catch (Exception e) + { + throw new RuntimeException("Failed to create test dictionary", e); + } + } +} diff --git a/test/unit/org/apache/cassandra/io/sstable/ScrubTest.java b/test/unit/org/apache/cassandra/io/sstable/ScrubTest.java index bd8f846572ff..7696d5d9d59e 100644 --- a/test/unit/org/apache/cassandra/io/sstable/ScrubTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/ScrubTest.java @@ -559,7 +559,7 @@ public static void overrideWithGarbage(SSTableReader sstable, ByteBuffer key1, B if (compression) { // overwrite with garbage the compression chunks from key1 to key2 - CompressionMetadata compData = CompressionInfoComponent.load(sstable.descriptor); + CompressionMetadata compData = CompressionInfoComponent.load(sstable.descriptor, null); CompressionMetadata.Chunk chunk1 = compData.chunkFor( sstable.getPosition(PartitionPosition.ForKey.get(key1, sstable.getPartitioner()), SSTableReader.Operator.EQ)); diff --git a/test/unit/org/apache/cassandra/schema/CompressionParamsTest.java b/test/unit/org/apache/cassandra/schema/CompressionParamsTest.java new file mode 100644 index 000000000000..70f1fbbc1938 --- /dev/null +++ b/test/unit/org/apache/cassandra/schema/CompressionParamsTest.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.schema; + +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; + +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionParamsTest +{ + @BeforeClass + public static void beforeClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Test + public void testIsDictionaryCompressionEnabled() + { + CompressionParams noCompression = CompressionParams.noCompression(); + assertThat(noCompression.isDictionaryCompressionEnabled()) + .as("No compression should not enable dictionary compression") + .isFalse(); + + CompressionParams regularZstd = CompressionParams.zstd(); + assertThat(regularZstd.isDictionaryCompressionEnabled()) + .as("Regular Zstd compression should not enable dictionary compression") + .isFalse(); + + CompressionParams zstdDictionary = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true); + assertThat(zstdDictionary.isDictionaryCompressionEnabled()) + .as("Zstd dictionary compression should enable dictionary compression") + .isTrue(); + + CompressionParams lz4 = CompressionParams.lz4(); + assertThat(lz4.isDictionaryCompressionEnabled()) + .as("LZ4 compression should not enable dictionary compression") + .isFalse(); + + CompressionParams snappy = CompressionParams.snappy(); + assertThat(snappy.isDictionaryCompressionEnabled()) + .as("Snappy compression should not enable dictionary compression") + .isFalse(); + + CompressionParams deflate = CompressionParams.deflate(); + assertThat(deflate.isDictionaryCompressionEnabled()) + .as("Deflate compression should not enable dictionary compression") + .isFalse(); + + CompressionParams noop = CompressionParams.noop(); + assertThat(noop.isDictionaryCompressionEnabled()) + .as("Noop compression should not enable dictionary compression") + .isFalse(); + } +} diff --git a/test/unit/org/apache/cassandra/schema/SystemDistributedKeyspaceCompressionDictionaryTest.java b/test/unit/org/apache/cassandra/schema/SystemDistributedKeyspaceCompressionDictionaryTest.java new file mode 100644 index 000000000000..94c0140149a1 --- /dev/null +++ b/test/unit/org/apache/cassandra/schema/SystemDistributedKeyspaceCompressionDictionaryTest.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.schema; + +import java.util.List; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.cql3.QueryProcessor; +import org.apache.cassandra.db.compression.CompressionDictionary; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.db.compression.ZstdCompressionDictionary; + +import static org.assertj.core.api.Assertions.assertThat; + +public class SystemDistributedKeyspaceCompressionDictionaryTest extends CQLTester +{ + private static final String TEST_KEYSPACE = "test_keyspace"; + private static final String TEST_TABLE = "test_table"; + private static final String OTHER_TABLE = "other_table"; + + private CompressionDictionary testDictionary1; + private CompressionDictionary testDictionary2; + + @Before + public void setUp() + { + DictId dictId1 = new DictId(Kind.ZSTD, 100L); + DictId dictId2 = new DictId(Kind.ZSTD, 200L); + + byte[] dictData1 = "test dictionary data 1".getBytes(); + byte[] dictData2 = "test dictionary data 2".getBytes(); + + testDictionary1 = new ZstdCompressionDictionary(dictId1, dictData1); + testDictionary2 = new ZstdCompressionDictionary(dictId2, dictData2); + + clearCompressionDictionaries(); + } + + @Test + public void testCompressionDictionariesTableExists() + { + Set tableNames = SystemDistributedKeyspace.TABLE_NAMES; + + assertThat(tableNames) + .as("TABLE_NAMES should contain compression_dictionaries") + .contains(SystemDistributedKeyspace.COMPRESSION_DICTIONARIES); + + // Verify the table exists in the schema + KeyspaceMetadata systemDistributedKs = SystemDistributedKeyspace.metadata(); + TableMetadata compressionDictTable = systemDistributedKs + .getTableOrViewNullable(SystemDistributedKeyspace.COMPRESSION_DICTIONARIES); + + assertThat(compressionDictTable) + .as("compression_dictionaries table should exist in schema") + .isNotNull(); + } + + @Test + public void testStoreCompressionDictionary() throws Exception + { + // Store a dictionary + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary1); + + // Verify it was stored + CompressionDictionary retrieved = SystemDistributedKeyspace.retrieveLatestCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE); + + assertThat(retrieved) + .as("Retrieved dictionary should not be null") + .isNotNull(); + + assertThat(retrieved.dictId()) + .as("Retrieved dictionary ID should match stored") + .isEqualTo(testDictionary1.dictId()); + + assertThat(retrieved.kind()) + .as("Retrieved dictionary kind should match stored") + .isEqualTo(testDictionary1.kind()); + + assertThat(retrieved.rawDictionary()) + .as("Retrieved dictionary data should match stored") + .isEqualTo(testDictionary1.rawDictionary()); + + retrieved.close(); + } + + @Test + public void testStoreMultipleDictionaries() throws Exception + { + // Store multiple dictionaries for the same table + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary1); + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary2); + + // Should retrieve the latest one (higher ID due to clustering order) + CompressionDictionary latest = SystemDistributedKeyspace.retrieveLatestCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE); + + assertThat(latest) + .as("Should retrieve the latest dictionary") + .isNotNull(); + + 
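The round trip verified by these assertions can be summarised with the SystemDistributedKeyspace calls and constructors used in this class (keyspace/table strings and dictionary bytes are illustrative, checked exceptions elided):

    DictId id = new DictId(Kind.ZSTD, 300L);
    CompressionDictionary dict = new ZstdCompressionDictionary(id, "raw zstd dictionary bytes".getBytes());
    SystemDistributedKeyspace.storeCompressionDictionary("ks", "tbl", dict);

    // Highest-id (latest) dictionary stored for the table, or null if nothing has been stored yet.
    CompressionDictionary latest = SystemDistributedKeyspace.retrieveLatestCompressionDictionary("ks", "tbl");

    // A specific version by id, or null if that id is unknown.
    CompressionDictionary byId = SystemDistributedKeyspace.retrieveCompressionDictionary("ks", "tbl", id);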
assertThat(latest.dictId()) + .as("Should retrieve dictionary with higher ID") + .isEqualTo(testDictionary2.dictId()); + + latest.close(); + } + + @Test + public void testRetrieveSpecificDictionary() throws Exception + { + // Store both dictionaries + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary1); + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary2); + + // Retrieve specific dictionary by ID + CompressionDictionary dict1 = SystemDistributedKeyspace.retrieveCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE, new DictId(Kind.ZSTD, 100L)); + CompressionDictionary dict2 = SystemDistributedKeyspace.retrieveCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE, new DictId(Kind.ZSTD, 200L)); + + assertThat(dict1) + .as("Should retrieve dictionary 1") + .isNotNull(); + + assertThat(dict1.dictId()) + .as("Should retrieve correct dictionary by ID") + .isEqualTo(testDictionary1.dictId()); + + assertThat(dict2) + .as("Should retrieve dictionary 2") + .isNotNull(); + + assertThat(dict2.dictId()) + .as("Should retrieve correct dictionary by ID") + .isEqualTo(testDictionary2.dictId()); + + dict1.close(); + dict2.close(); + } + + @Test + public void testRetrieveNonExistentDictionary() + { + // Try to retrieve dictionary that doesn't exist + CompressionDictionary nonExistent = SystemDistributedKeyspace.retrieveLatestCompressionDictionary( + "nonexistent_keyspace", "nonexistent_table"); + + assertThat(nonExistent) + .as("Should return null for non-existent dictionary") + .isNull(); + + // Try to retrieve specific dictionary that doesn't exist + CompressionDictionary nonExistentById = SystemDistributedKeyspace.retrieveCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE, new DictId(Kind.ZSTD, 999L)); + + assertThat(nonExistentById) + .as("Should return null for non-existent dictionary ID") + .isNull(); + } + + private void clearCompressionDictionaries() + { + for (String table : List.of(TEST_TABLE, OTHER_TABLE)) + { + String deleteQuery = String.format("DELETE FROM %s.%s WHERE keyspace_name = '%s' AND table_name = '%s'", + SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, + SystemDistributedKeyspace.COMPRESSION_DICTIONARIES, + TEST_KEYSPACE, + table); + QueryProcessor.executeInternal(deleteQuery); + } + } +} diff --git a/test/unit/org/apache/cassandra/streaming/compression/CompressedInputStreamTest.java b/test/unit/org/apache/cassandra/streaming/compression/CompressedInputStreamTest.java index 391d58972106..44e5bc2719e8 100644 --- a/test/unit/org/apache/cassandra/streaming/compression/CompressedInputStreamTest.java +++ b/test/unit/org/apache/cassandra/streaming/compression/CompressedInputStreamTest.java @@ -142,7 +142,7 @@ private void testCompressedReadWith(long[] valuesToCheck, boolean testTruncate, writer.finish(); } - CompressionMetadata comp = CompressionInfoComponent.load(desc); + CompressionMetadata comp = CompressionInfoComponent.load(desc, null); List sections = new ArrayList<>(); for (long l : valuesToCheck) { diff --git a/test/unit/org/apache/cassandra/tools/ToolRunner.java b/test/unit/org/apache/cassandra/tools/ToolRunner.java index 3d4d4588fdd7..143b7cfbd693 100644 --- a/test/unit/org/apache/cassandra/tools/ToolRunner.java +++ b/test/unit/org/apache/cassandra/tools/ToolRunner.java @@ -666,6 +666,13 @@ public AssertHelp errorContainsAny(String... 
messages) return this; } + public AssertHelp stdoutContains(String message) + { + assertThat(message).hasSizeGreaterThan(0); + assertThat(stdout).isNotNull().contains(message); + return this; + } + private void fail(String msg) { StringBuilder sb = new StringBuilder(); diff --git a/test/unit/org/apache/cassandra/tools/nodetool/TrainCompressionDictionaryTest.java b/test/unit/org/apache/cassandra/tools/nodetool/TrainCompressionDictionaryTest.java new file mode 100644 index 000000000000..b2afa3bc27b4 --- /dev/null +++ b/test/unit/org/apache/cassandra/tools/nodetool/TrainCompressionDictionaryTest.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.tools.nodetool; + +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.tools.ToolRunner; + +import static org.apache.cassandra.tools.ToolRunner.invokeNodetool; +import static org.assertj.core.api.Assertions.assertThat; + +public class TrainCompressionDictionaryTest extends CQLTester +{ + @BeforeClass + public static void setup() throws Throwable + { + requireNetwork(); + startJMXServer(); + } + + @Test + public void testTrainCommandSuccess() + { + // Create a table with dictionary compression enabled + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + createSSTables(true); + + // Test training command + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", keyspace(), table); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should indicate training completed") + .contains("Training completed successfully") + .contains(keyspace()) + .contains(table); + } + + @Test + public void testTrainCommandWithDataButNoSSTables() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Add test data but don't flush - memtable should be flushed automatically + createSSTables(false); + + // Test training, the command should run flush before sampling + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + keyspace(), + table); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should flush automatically when no SSTables available") + .contains("Training completed successfully"); + } + + @Test + public void testTrainCommandWithNoSSTables() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + keyspace(), + table); + assertThat(result.getStderr()) + 
.contains("Failed to trigger training: No SSTables available for training", "after flush"); + } + + @Test + public void testInvalidKeyspace() + { + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "nonexistent_keyspace", + "nonexistent_table"); + result.asserts() + .failure() + .errorContains("Failed to trigger training"); + } + + @Test + public void testInvalidTable() + { + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + keyspace(), + "nonexistent_table"); + result.asserts() + .failure() + .errorContains("Failed to trigger training") + .errorContains("does not exist or does not support dictionary compression"); + } + + @Test + public void testTrainingOnNonDictionaryTable() + { + // Create table without dictionary compression + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'LZ4Compressor'}"); + + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + keyspace(), + table); + result.asserts() + .failure() + .errorContains("does not support dictionary compression"); + } + + @Test + public void testTrainingWithoutDictionaryCompressionEnabled() + { + // Create table with Zstd but without dictionary compression + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdCompressor'}"); + + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + keyspace(), + table); + result.asserts() + .failure() + .errorContains("does not support dictionary compression"); + } + + + @Test + public void testAlterCompressionToZstdDictionary() + { + // Create table with LZ4 compression + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'LZ4Compressor'}"); + + // Training should fail on LZ4 table + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", keyspace(), table); + result.asserts() + .failure() + .errorContains("Failed to trigger training") + .errorContains("does not exist or does not support dictionary compression"); + + // Alter table to use ZstdDictionaryCompressor + execute("ALTER TABLE %s WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Training should fail with no sstables + result = invokeNodetool("traincompressiondictionary", keyspace(), table); + assertThat(result.getStderr()) + .contains("Failed to trigger training: No SSTables available for training", "after flush"); + + // Write sstables + createSSTables(true); + + // Training should now succeed + result = invokeNodetool("traincompressiondictionary", keyspace(), table); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should indicate training completed with new dictionary") + .contains("Training completed successfully") + .contains(keyspace()) + .contains(table); + } + + @Test + public void testHelpOutput() + { + ToolRunner.ToolResult result = invokeNodetool("help", "traincompressiondictionary"); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should show command help") + .contains("nodetool traincompressiondictionary - Manually trigger compression") + .contains("dictionary training for a table") + .contains("keyspace name") + .contains("table name"); + } + + @Test + public void testCommandLineArgumentParsing() + { + // Test missing required arguments + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary"); + result.asserts() + .failure() + .stdoutContains("Missing 
required parameter"); + + // Test missing table argument + result = invokeNodetool("traincompressiondictionary", keyspace()); + result.asserts() + .failure() + .stdoutContains("Missing required parameter"); + } + + private void createSSTables(boolean flush) + { + for (int file = 0; file < 10; file++) + { + int batchSize = 1000; + for (int i = 0; i < batchSize; i++) + { + int index = i + file * batchSize; + execute("INSERT INTO %s (id, data) VALUES (?, ?)", index, "test data " + index); + } + if (flush) + { + flush(); + } + } + } +} diff --git a/test/unit/org/apache/cassandra/utils/StorageCompatibilityModeTest.java b/test/unit/org/apache/cassandra/utils/StorageCompatibilityModeTest.java index f8684edd6250..0bc7186891ff 100644 --- a/test/unit/org/apache/cassandra/utils/StorageCompatibilityModeTest.java +++ b/test/unit/org/apache/cassandra/utils/StorageCompatibilityModeTest.java @@ -40,6 +40,7 @@ public void testBtiFormatAndStorageCompatibilityMode() { case UPGRADING: case NONE: + case CASSANDRA_5: mode.validateSstableFormat(big); mode.validateSstableFormat(trie); break;
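Taken together, the nodetool tests above encode the operator workflow for manual dictionary training. Under the behaviour they assert, it comes down to enabling the dictionary-capable compressor on the table and then invoking the command (keyspace and table names here are illustrative):

    ALTER TABLE ks.tbl WITH compression = {'class': 'ZstdDictionaryCompressor'};
    nodetool traincompressiondictionary ks tbl

As asserted in TrainCompressionDictionaryTest, the command flushes memtables when no SSTables are available for sampling, refuses tables whose compressor does not support dictionary compression, and on success prints "Training completed successfully" together with the keyspace and table names.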