Commit 6778a2b

Author: tac0turtle (committed)
Message: add reusing of compression to avoid allocations
1 parent: 9eb8af7
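For orientation, the allocation-saving pattern the commit message refers to is encoder/decoder pooling via sync.Pool: instead of constructing a fresh zstd encoder on every call, one is borrowed from a pool, used, reset, and returned. A minimal sketch of that idea (illustrative only, not the file's actual code; the change below keeps one pool per compression level plus a separate decoder pool):

package main

import (
	"fmt"
	"sync"

	"github.com/klauspost/compress/zstd"
)

// One pooled encoder at a fixed level; the commit generalizes this to a
// map[int]*sync.Pool keyed by zstd level, plus a matching decoder pool.
var encoderPool = sync.Pool{
	New: func() interface{} {
		enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault))
		if err != nil {
			panic(err) // only possible with invalid options
		}
		return enc
	},
}

func compress(src []byte) []byte {
	enc := encoderPool.Get().(*zstd.Encoder)
	defer func() {
		enc.Reset(nil) // detach any writer state before returning the encoder to the pool
		encoderPool.Put(enc)
	}()
	// EncodeAll is a pure slice-in/slice-out call, so the same encoder can serve many calls.
	return enc.EncodeAll(src, make([]byte, 0, len(src)))
}

func main() {
	fmt.Println(len(compress([]byte("some highly repetitive payload payload payload"))))
}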

File tree: 1 file changed (+170, -14 lines)

da/compression/compression.go

Lines changed: 170 additions & 14 deletions
@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
+	"sync"

 	"github.com/evstack/ev-node/core/da"
 	"github.com/klauspost/compress/zstd"
@@ -54,6 +55,101 @@ func DefaultConfig() Config {
 	}
 }

+// Global sync.Pools for encoder/decoder reuse
+var (
+	encoderPools map[int]*sync.Pool
+	decoderPool  *sync.Pool
+	poolsOnce    sync.Once
+)
+
+// initPools initializes the encoder and decoder pools
+func initPools() {
+	poolsOnce.Do(func() {
+		// Create encoder pools for different compression levels
+		encoderPools = make(map[int]*sync.Pool)
+
+		// Pre-create pools for common compression levels (1-9)
+		for level := 1; level <= 9; level++ {
+			lvl := level // Capture loop variable
+			encoderPools[lvl] = &sync.Pool{
+				New: func() interface{} {
+					encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(lvl)))
+					if err != nil {
+						// This should not happen with valid levels
+						panic(fmt.Sprintf("failed to create zstd encoder with level %d: %v", lvl, err))
+					}
+					return encoder
+				},
+			}
+		}
+
+		// Create decoder pool
+		decoderPool = &sync.Pool{
+			New: func() interface{} {
+				decoder, err := zstd.NewReader(nil)
+				if err != nil {
+					// This should not happen
+					panic(fmt.Sprintf("failed to create zstd decoder: %v", err))
+				}
+				return decoder
+			},
+		}
+	})
+}
+
+// getEncoder retrieves an encoder from the pool for the specified compression level
+func getEncoder(level int) *zstd.Encoder {
+	initPools()
+
+	pool, exists := encoderPools[level]
+	if !exists {
+		// Create a new pool for this level if it doesn't exist
+		pool = &sync.Pool{
+			New: func() interface{} {
+				encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(level)))
+				if err != nil {
+					panic(fmt.Sprintf("failed to create zstd encoder with level %d: %v", level, err))
+				}
+				return encoder
+			},
+		}
+		encoderPools[level] = pool
+	}
+
+	return pool.Get().(*zstd.Encoder)
+}
+
+// putEncoder returns an encoder to the pool
+func putEncoder(encoder *zstd.Encoder, level int) {
+	if encoder == nil {
+		return
+	}
+
+	// Reset the encoder for reuse
+	encoder.Reset(nil)
+
+	if pool, exists := encoderPools[level]; exists {
+		pool.Put(encoder)
+	}
+}
+
+// getDecoder retrieves a decoder from the pool
+func getDecoder() *zstd.Decoder {
+	initPools()
+	return decoderPool.Get().(*zstd.Decoder)
+}
+
+// putDecoder returns a decoder to the pool
+func putDecoder(decoder *zstd.Decoder) {
+	if decoder == nil {
+		return
+	}
+
+	// Reset the decoder for reuse
+	decoder.Reset(nil)
+	decoderPool.Put(decoder)
+}
+
 // CompressibleDA wraps a DA implementation to add transparent compression support
 type CompressibleDA struct {
 	baseDA da.DA
@@ -64,9 +160,8 @@ type CompressibleDA struct {

 // NewCompressibleDA creates a new CompressibleDA wrapper
 func NewCompressibleDA(baseDA da.DA, config Config) (*CompressibleDA, error) {
-	if baseDA == nil {
-		return nil, errors.New("base DA cannot be nil")
-	}
+	// Allow nil baseDA for testing purposes (when only using compression functions)
+	// The baseDA will only be used when calling Submit, Get, GetIDs methods

 	var encoder *zstd.Encoder
 	var decoder *zstd.Decoder
@@ -277,25 +372,86 @@ func (c *CompressibleDA) GasMultiplier(ctx context.Context) (float64, error) {
 // CompressBlob compresses a blob using the default zstd level 3 configuration
 func CompressBlob(blob da.Blob) (da.Blob, error) {
 	config := DefaultConfig()
-	compressor, err := NewCompressibleDA(nil, config)
-	if err != nil {
-		return nil, err
+
+	if !config.Enabled || len(blob) == 0 {
+		// Return with uncompressed header
+		return addCompressionHeaderStandalone(blob, FlagUncompressed, uint64(len(blob))), nil
 	}
-	defer compressor.Close()
-
-	return compressor.compressBlob(blob)
+
+	// Get encoder from pool
+	encoder := getEncoder(config.ZstdLevel)
+	defer putEncoder(encoder, config.ZstdLevel)
+
+	// Compress the blob
+	compressed := encoder.EncodeAll(blob, make([]byte, 0, len(blob)))
+
+	// Check if compression is beneficial
+	compressionRatio := float64(len(compressed)) / float64(len(blob))
+	if compressionRatio > (1.0 - config.MinCompressionRatio) {
+		// Compression not beneficial, store uncompressed
+		return addCompressionHeaderStandalone(blob, FlagUncompressed, uint64(len(blob))), nil
+	}
+
+	return addCompressionHeaderStandalone(compressed, FlagZstd, uint64(len(blob))), nil
 }

 // DecompressBlob decompresses a blob
 func DecompressBlob(compressedBlob da.Blob) (da.Blob, error) {
-	config := DefaultConfig()
-	compressor, err := NewCompressibleDA(nil, config)
+	if len(compressedBlob) < CompressionHeaderSize {
+		// Assume legacy uncompressed blob
+		return compressedBlob, nil
+	}
+
+	flag, originalSize, payload, err := parseCompressionHeaderStandalone(compressedBlob)
 	if err != nil {
-		return nil, err
+		// Assume legacy uncompressed blob
+		return compressedBlob, nil
 	}
-	defer compressor.Close()
+
+	switch flag {
+	case FlagUncompressed:
+		return payload, nil
+	case FlagZstd:
+		// Get decoder from pool
+		decoder := getDecoder()
+		defer putDecoder(decoder)
+
+		decompressed, err := decoder.DecodeAll(payload, make([]byte, 0, originalSize))
+		if err != nil {
+			return nil, fmt.Errorf("%w: %v", ErrDecompressionFailed, err)
+		}
+
+		if uint64(len(decompressed)) != originalSize {
+			return nil, fmt.Errorf("decompressed size mismatch: expected %d, got %d", originalSize, len(decompressed))
+		}
+
+		return decompressed, nil
+	default:
+		return nil, fmt.Errorf("unsupported compression flag: %d", flag)
+	}
+}
+
+// Standalone helper functions for use without CompressibleDA instance

-	return compressor.decompressBlob(compressedBlob)
+// addCompressionHeaderStandalone adds compression metadata header to data
+func addCompressionHeaderStandalone(data []byte, flag uint8, originalSize uint64) []byte {
+	header := make([]byte, CompressionHeaderSize)
+	header[0] = flag
+	binary.BigEndian.PutUint64(header[1:], originalSize)
+	return append(header, data...)
+}
+
+// parseCompressionHeaderStandalone parses compression metadata from blob
+func parseCompressionHeaderStandalone(blob []byte) (flag uint8, originalSize uint64, payload []byte, err error) {
+	if len(blob) < CompressionHeaderSize {
+		return 0, 0, nil, errors.New("blob too small for compression header")
+	}
+
+	flag = blob[0]
+	originalSize = binary.BigEndian.Uint64(blob[1:9])
+	payload = blob[CompressionHeaderSize:]
+
+	return flag, originalSize, payload, nil
 }

 // CompressionInfo provides information about a blob's compression
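End to end, the reworked standalone helpers now round-trip without building a CompressibleDA per call. A rough usage sketch, assuming the package is imported as compression from this repository's da/compression directory and that da.Blob is a plain byte slice (both assumptions, not shown in this diff):

package main

import (
	"bytes"
	"fmt"
	"log"

	"github.com/evstack/ev-node/da/compression" // assumed import path for da/compression
)

func main() {
	payload := bytes.Repeat([]byte("evolve block data "), 512)

	// CompressBlob prepends the header (1-byte compression flag + 8-byte original size)
	// and keeps the data uncompressed when zstd does not pay off.
	stored, err := compression.CompressBlob(payload)
	if err != nil {
		log.Fatal(err)
	}

	// DecompressBlob reads the header and inflates zstd payloads back to the original bytes.
	restored, err := compression.DecompressBlob(stored)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("round trip ok:", bytes.Equal(payload, restored))
	fmt.Printf("original=%d bytes, stored=%d bytes\n", len(payload), len(stored))
}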
