@@ -3,6 +3,8 @@ package grpcclients
 import (
 	"context"
 	"io"
+	"slices"
+	"sync"
 
 	remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
 	"github.com/buildbarn/bb-storage/pkg/blobstore"
@@ -11,10 +13,13 @@ import (
1113 "github.com/buildbarn/bb-storage/pkg/digest"
1214 "github.com/buildbarn/bb-storage/pkg/util"
1315 "github.com/google/uuid"
16+ "github.com/klauspost/compress/zstd"
1417
1518 "google.golang.org/genproto/googleapis/bytestream"
1619 "google.golang.org/grpc"
20+ "google.golang.org/grpc/codes"
1721 "google.golang.org/grpc/metadata"
22+ "google.golang.org/grpc/status"
1823)
1924
2025type casBlobAccess struct {
@@ -23,20 +28,29 @@ type casBlobAccess struct {
 	capabilitiesClient              remoteexecution.CapabilitiesClient
 	uuidGenerator                   util.UUIDGenerator
 	readChunkSize                   int
+	compressionThresholdBytes       int64
+	supportedCompressors            []remoteexecution.Compressor_Value
+	supportedCompressorsMutex       sync.RWMutex
+	capabilitiesCalled              bool
 }
 
 // NewCASBlobAccess creates a BlobAccess handle that relays any requests
 // to a GRPC service that implements the bytestream.ByteStream and
 // remoteexecution.ContentAddressableStorage services. Those are the
 // services that Bazel uses to access blobs stored in the Content
 // Addressable Storage.
-func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int) blobstore.BlobAccess {
+//
+// If compressionThresholdBytes is greater than zero, the client attempts
+// to use ZSTD compression for blobs whose size is at least that threshold,
+// provided the server advertises ZSTD support via GetCapabilities().
+func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int, compressionThresholdBytes int64) blobstore.BlobAccess {
 	return &casBlobAccess{
 		byteStreamClient:                bytestream.NewByteStreamClient(client),
 		contentAddressableStorageClient: remoteexecution.NewContentAddressableStorageClient(client),
 		capabilitiesClient:              remoteexecution.NewCapabilitiesClient(client),
 		uuidGenerator:                   uuidGenerator,
 		readChunkSize:                   readChunkSize,
+		compressionThresholdBytes:       compressionThresholdBytes,
 	}
 }
 
@@ -62,11 +76,147 @@ func (r *byteStreamChunkReader) Close() {
 	}
 }
 
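+// zstdByteStreamChunkReader decompresses a ByteStream read stream that
+// carries ZSTD-compressed data, yielding uncompressed chunks of at most
+// readChunkSize bytes.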
+type zstdByteStreamChunkReader struct {
+	client        bytestream.ByteStream_ReadClient
+	cancel        context.CancelFunc
+	zstdReader    io.ReadCloser
+	readChunkSize int
+	wg            sync.WaitGroup
+}
+
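+// Read lazily sets up the decompression pipeline on its first call: a
+// goroutine pumps compressed chunks from the gRPC stream into an io.Pipe,
+// while a ZSTD decoder reads from the pipe's other end.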
+func (r *zstdByteStreamChunkReader) Read() ([]byte, error) {
+	if r.zstdReader == nil {
+		pr, pw := io.Pipe()
+
+		r.wg.Add(1)
+		go func() {
+			defer r.wg.Done()
+			defer pw.Close()
+			for {
+				chunk, err := r.client.Recv()
+				if err != nil {
+					if err != io.EOF {
+						pw.CloseWithError(err)
+					}
+					return
+				}
+				if _, writeErr := pw.Write(chunk.Data); writeErr != nil {
+					pw.CloseWithError(writeErr)
+					return
+				}
+			}
+		}()
+
+		var err error
+		r.zstdReader, err = util.NewZstdReadCloser(pr, zstd.WithDecoderConcurrency(1))
+		if err != nil {
+			pr.Close()
+			return nil, err
+		}
+	}
+
+	buf := make([]byte, r.readChunkSize)
+	n, err := r.zstdReader.Read(buf)
+	if n > 0 {
+		// Return the data first; any error (including io.EOF) is
+		// sticky and will be reported by the next call to Read().
+		return buf[:n], nil
+	}
+	return nil, err
+}
+
128+
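+// Close shuts down the decoder, cancels the RPC, drains the stream and
+// waits for the pump goroutine to exit, so that no goroutine outlives
+// the reader.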
+func (r *zstdByteStreamChunkReader) Close() {
+	if r.zstdReader != nil {
+		r.zstdReader.Close()
+	}
+	r.cancel()
+
+	// Drain the gRPC stream.
+	for {
+		if _, err := r.client.Recv(); err != nil {
+			break
+		}
+	}
+	r.wg.Wait()
+}
143+
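+// zstdByteStreamWriter adapts a ByteStream write stream to io.Writer, so
+// that a zstd.Encoder can stream compressed frames to the server.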
+type zstdByteStreamWriter struct {
+	client       bytestream.ByteStream_WriteClient
+	resourceName string
+	writeOffset  int64
+	cancel       context.CancelFunc
+}
+
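+// Write sends a single WriteRequest to the server. Note that writeOffset
+// counts compressed bytes on the wire, and that the resource name is only
+// attached to the first request, as the ByteStream protocol requires.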
+func (w *zstdByteStreamWriter) Write(p []byte) (int, error) {
+	if err := w.client.Send(&bytestream.WriteRequest{
+		ResourceName: w.resourceName,
+		WriteOffset:  w.writeOffset,
+		Data:         p,
+	}); err != nil {
+		return 0, err
+	}
+	w.writeOffset += int64(len(p))
+	w.resourceName = ""
+	return len(p), nil
+}
+
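+// Close sends a final WriteRequest with FinishWrite set and waits for the
+// server's WriteResponse.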
+func (w *zstdByteStreamWriter) Close() error {
+	if err := w.client.Send(&bytestream.WriteRequest{
+		ResourceName: w.resourceName,
+		WriteOffset:  w.writeOffset,
+		FinishWrite:  true,
+	}); err != nil {
+		w.cancel()
+		w.client.CloseAndRecv()
+		return err
+	}
+	_, err := w.client.CloseAndRecv()
+	w.cancel()
+	return err
+}
+
 const resourceNameHeader = "build.bazel.remote.execution.v2.resource-name"
 
+// shouldUseCompression checks whether compression should be used for a
+// blob of the given size. It also ensures GetCapabilities() has been
+// called to negotiate compression support.
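+//
+// For example, with compressionThresholdBytes set to 1 MiB, a 4 MiB blob
+// is transferred over the compressed-blobs/zstd resource path once the
+// server has advertised ZSTD in its cache capabilities, while smaller
+// blobs keep using the uncompressed path.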
+func (ba *casBlobAccess) shouldUseCompression(ctx context.Context, digest digest.Digest) (bool, error) {
+	if ba.compressionThresholdBytes <= 0 || digest.GetSizeBytes() < ba.compressionThresholdBytes {
+		return false, nil
+	}
+
+	ba.supportedCompressorsMutex.RLock()
+	capabilitiesCalled := ba.capabilitiesCalled
+	supportedCompressors := ba.supportedCompressors
+	ba.supportedCompressorsMutex.RUnlock()
+
+	if !capabilitiesCalled {
+		// Call GetCapabilities() to check for server support.
+		_, err := ba.GetCapabilities(ctx, digest.GetDigestFunction().GetInstanceName())
+		if err != nil {
+			return false, err
+		}
+		ba.supportedCompressorsMutex.RLock()
+		supportedCompressors = ba.supportedCompressors
+		ba.supportedCompressorsMutex.RUnlock()
+	}
+
+	return slices.Contains(supportedCompressors, remoteexecution.Compressor_ZSTD), nil
+}
+
 func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.Buffer {
+	useCompression, err := ba.shouldUseCompression(ctx, digest)
+	if err != nil {
+		return buffer.NewBufferFromError(err)
+	}
+
+	compressor := remoteexecution.Compressor_IDENTITY
+	if useCompression {
+		compressor = remoteexecution.Compressor_ZSTD
+	}
+
 	ctxWithCancel, cancel := context.WithCancel(ctx)
-	resourceName := digest.GetByteStreamReadPath(remoteexecution.Compressor_IDENTITY)
+	resourceName := digest.GetByteStreamReadPath(compressor)
 	client, err := ba.byteStreamClient.Read(
 		metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
 		&bytestream.ReadRequest{
@@ -77,6 +227,15 @@ func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.B
 		cancel()
 		return buffer.NewBufferFromError(err)
 	}
+
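+	// For compressed reads, wrap the stream in a decompressing chunk
+	// reader. Integrity is still enforced, as the resulting CAS buffer
+	// checksums the decompressed contents against the digest.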
+	if useCompression {
+		return buffer.NewCASBufferFromChunkReader(digest, &zstdByteStreamChunkReader{
+			client:        client,
+			cancel:        cancel,
+			readChunkSize: ba.readChunkSize,
+		}, buffer.BackendProvided(buffer.Irreparable(digest)))
+	}
+
 	return buffer.NewCASBufferFromChunkReader(digest, &byteStreamChunkReader{
 		client: client,
 		cancel: cancel,
@@ -89,19 +248,61 @@ func (ba *casBlobAccess) GetFromComposite(ctx context.Context, parentDigest, chi
 }
 
 func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer.Buffer) error {
-	r := b.ToChunkReader(0, ba.readChunkSize)
-	defer r.Close()
+	useCompression, err := ba.shouldUseCompression(ctx, digest)
+	if err != nil {
+		b.Discard()
+		return err
+	}
+
+	compressor := remoteexecution.Compressor_IDENTITY
+	if useCompression {
+		compressor = remoteexecution.Compressor_ZSTD
+	}
 
 	ctxWithCancel, cancel := context.WithCancel(ctx)
-	resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), remoteexecution.Compressor_IDENTITY)
+	resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), compressor)
 	client, err := ba.byteStreamClient.Write(
 		metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
 	)
 	if err != nil {
 		cancel()
+		b.Discard()
 		return err
 	}
 
+	if useCompression {
+		byteStreamWriter := &zstdByteStreamWriter{
+			client:       client,
+			resourceName: resourceName,
+			writeOffset:  0,
+			cancel:       cancel,
+		}
+
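+		// Stream the blob through a ZSTD encoder whose output goes
+		// straight to the ByteStream; an encoder concurrency of one
+		// keeps memory usage bounded.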
+		zstdWriter, err := zstd.NewWriter(byteStreamWriter, zstd.WithEncoderConcurrency(1))
+		if err != nil {
+			cancel()
+			client.CloseAndRecv()
+			return status.Errorf(codes.Internal, "Failed to create zstd writer: %v", err)
+		}
+
+		if err := b.IntoWriter(zstdWriter); err != nil {
+			zstdWriter.Close()
+			byteStreamWriter.Close()
+			return err
+		}
+
+		if err := zstdWriter.Close(); err != nil {
+			byteStreamWriter.Close()
+			return err
+		}
+
+		return byteStreamWriter.Close()
+	}
+
+	// Non-compressed path.
+	r := b.ToChunkReader(0, ba.readChunkSize)
+	defer r.Close()
+
 	writeOffset := int64(0)
 	for {
 		if data, err := r.Read(); err == nil {
@@ -140,6 +341,10 @@ func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer
 }
 
 func (ba *casBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (digest.Set, error) {
+	return findMissingBlobsInternal(ctx, digests, ba.contentAddressableStorageClient)
+}
+
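+// findMissingBlobsInternal implements FindMissing() on top of an
+// arbitrary ContentAddressableStorageClient, so that the logic can be
+// reused independently of casBlobAccess.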
+func findMissingBlobsInternal(ctx context.Context, digests digest.Set, cas remoteexecution.ContentAddressableStorageClient) (digest.Set, error) {
 	// Partition all digests by digest function, as the
 	// FindMissingBlobs() RPC can only process digests for a single
 	// instance name and digest function.
@@ -157,7 +362,7 @@ func (ba *casBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (d
 			BlobDigests:    blobDigests,
 			DigestFunction: digestFunction.GetEnumValue(),
 		}
-		response, err := ba.contentAddressableStorageClient.FindMissingBlobs(ctx, &request)
+		response, err := cas.FindMissingBlobs(ctx, &request)
 		if err != nil {
 			return digest.EmptySet, err
 		}
@@ -180,11 +385,17 @@ func (ba *casBlobAccess) GetCapabilities(ctx context.Context, instanceName diges
 		return nil, err
 	}
 
+	cacheCapabilities := serverCapabilities.CacheCapabilities
+
+	// Store the supported compressors for compression negotiation.
+	ba.supportedCompressorsMutex.Lock()
+	ba.supportedCompressors = cacheCapabilities.SupportedCompressors
+	ba.capabilitiesCalled = true
+	ba.supportedCompressorsMutex.Unlock()
+
 	// Only return fields that pertain to the Content Addressable
 	// Storage. Don't set 'max_batch_total_size_bytes', as we don't
-	// issue batch operations. The same holds for fields related to
-	// compression support.
-	cacheCapabilities := serverCapabilities.CacheCapabilities
+	// issue batch operations.
 	return &remoteexecution.ServerCapabilities{
 		CacheCapabilities: &remoteexecution.CacheCapabilities{
 			DigestFunctions: digest.RemoveUnsupportedDigestFunctions(cacheCapabilities.DigestFunctions),