@@ -3,6 +3,9 @@ package grpcclients
 import (
 	"context"
 	"io"
+	"slices"
+	"sync"
+	"sync/atomic"
 
 	remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
 	"github.com/buildbarn/bb-storage/pkg/blobstore"
@@ -11,10 +14,13 @@ import (
 	"github.com/buildbarn/bb-storage/pkg/digest"
 	"github.com/buildbarn/bb-storage/pkg/util"
 	"github.com/google/uuid"
+	"github.com/klauspost/compress/zstd"
 
 	"google.golang.org/genproto/googleapis/bytestream"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
 )
 
 type casBlobAccess struct {
@@ -23,20 +29,26 @@ type casBlobAccess struct {
 	capabilitiesClient              remoteexecution.CapabilitiesClient
 	uuidGenerator                   util.UUIDGenerator
 	readChunkSize                   int
+	enableZSTDCompression           bool
+	supportedCompressors            atomic.Pointer[[]remoteexecution.Compressor_Value]
 }
 
 // NewCASBlobAccess creates a BlobAccess handle that relays any requests
-// to a GRPC service that implements the bytestream.ByteStream and
+// to a gRPC service that implements the bytestream.ByteStream and
 // remoteexecution.ContentAddressableStorage services. Those are the
 // services that Bazel uses to access blobs stored in the Content
 // Addressable Storage.
-func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int) blobstore.BlobAccess {
+//
+// If enableZSTDCompression is true, the client uses ZSTD compression for
+// ByteStream reads and writes whenever the server advertises support for it.
+func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int, enableZSTDCompression bool) blobstore.BlobAccess {
 	return &casBlobAccess{
 		byteStreamClient:                bytestream.NewByteStreamClient(client),
 		contentAddressableStorageClient: remoteexecution.NewContentAddressableStorageClient(client),
 		capabilitiesClient:              remoteexecution.NewCapabilitiesClient(client),
 		uuidGenerator:                   uuidGenerator,
 		readChunkSize:                   readChunkSize,
+		enableZSTDCompression:           enableZSTDCompression,
 	}
 }
 
@@ -62,11 +74,140 @@ func (r *byteStreamChunkReader) Close() {
 	}
 }
 
+type zstdByteStreamChunkReader struct {
+	client        bytestream.ByteStream_ReadClient
+	cancel        context.CancelFunc
+	zstdReader    io.ReadCloser
+	readChunkSize int
+	wg            sync.WaitGroup
+}
+
+func (r *zstdByteStreamChunkReader) Read() ([]byte, error) {
+	if r.zstdReader == nil {
+		// First call: start a goroutine that pumps compressed
+		// chunks received from the server into a pipe, and let
+		// the ZSTD decoder read from that pipe.
+		pr, pw := io.Pipe()
+
+		r.wg.Add(1)
+		go func() {
+			defer r.wg.Done()
+			defer pw.Close()
+			for {
+				chunk, err := r.client.Recv()
+				if err != nil {
+					if err != io.EOF {
+						pw.CloseWithError(err)
+					}
+					return
+				}
+				if _, writeErr := pw.Write(chunk.Data); writeErr != nil {
+					pw.CloseWithError(writeErr)
+					return
+				}
+			}
+		}()
+
+		var err error
+		r.zstdReader, err = util.NewZstdReadCloser(pr, zstd.WithDecoderConcurrency(1))
+		if err != nil {
+			pr.Close()
+			return nil, err
+		}
+	}
+
+	buf := make([]byte, r.readChunkSize)
+	n, err := r.zstdReader.Read(buf)
+	if n > 0 {
+		// Return the decompressed chunk without an error. Any
+		// error (including io.EOF) is returned again by the next
+		// call to Read().
+		return buf[:n], nil
+	}
+	return nil, err
+}
+
+func (r *zstdByteStreamChunkReader) Close() {
+	if r.zstdReader != nil {
+		r.zstdReader.Close()
+	}
+	r.cancel()
+
+	// Wait for the pumping goroutine to stop calling Recv() before
+	// draining the gRPC stream ourselves, as concurrent Recv() calls
+	// on the same stream are not permitted.
+	r.wg.Wait()
+	for {
+		if _, err := r.client.Recv(); err != nil {
+			break
+		}
+	}
+}
+
+type zstdByteStreamWriter struct {
+	client       bytestream.ByteStream_WriteClient
+	resourceName string
+	writeOffset  int64
+	cancel       context.CancelFunc
+}
+
+func (w *zstdByteStreamWriter) Write(p []byte) (int, error) {
+	if err := w.client.Send(&bytestream.WriteRequest{
+		ResourceName: w.resourceName,
+		WriteOffset:  w.writeOffset,
+		Data:         p,
+	}); err != nil {
+		return 0, err
+	}
+	w.writeOffset += int64(len(p))
+	// Only the first request on the stream needs to carry the resource name.
+	w.resourceName = ""
+	return len(p), nil
+}
+
+func (w *zstdByteStreamWriter) Close() error {
+	err := w.client.Send(&bytestream.WriteRequest{
+		ResourceName: w.resourceName,
+		WriteOffset:  w.writeOffset,
+		FinishWrite:  true,
+	})
+	if err != nil && err != io.EOF {
+		w.cancel()
+		w.client.CloseAndRecv()
+		return err
+	}
+	// An io.EOF from Send() means the server has already closed the
+	// stream; the actual status is obtained through CloseAndRecv().
+	_, err = w.client.CloseAndRecv()
+	w.cancel()
+	return err
+}
+
 const resourceNameHeader = "build.bazel.remote.execution.v2.resource-name"
 
+// shouldUseZSTDCompression returns whether ZSTD compression should be used
+// for a given blob. It ensures that GetCapabilities() has been called, so
+// that the set of compressors supported by the server is known.
+func (ba *casBlobAccess) shouldUseZSTDCompression(ctx context.Context, digest digest.Digest) (bool, error) {
+	if !ba.enableZSTDCompression {
+		return false, nil
+	}
+
+	supportedCompressors := ba.supportedCompressors.Load()
+	if supportedCompressors == nil {
+		// Call GetCapabilities() to determine which compressors the server supports.
+		if _, err := ba.GetCapabilities(ctx, digest.GetDigestFunction().GetInstanceName()); err != nil {
+			return false, err
+		}
+		supportedCompressors = ba.supportedCompressors.Load()
+	}
+
+	return slices.Contains(*supportedCompressors, remoteexecution.Compressor_ZSTD), nil
+}
+
 func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.Buffer {
+	useCompression, err := ba.shouldUseZSTDCompression(ctx, digest)
+	if err != nil {
+		return buffer.NewBufferFromError(err)
+	}
+
+	compressor := remoteexecution.Compressor_IDENTITY
+	if useCompression {
+		compressor = remoteexecution.Compressor_ZSTD
+	}
+
 	ctxWithCancel, cancel := context.WithCancel(ctx)
-	resourceName := digest.GetByteStreamReadPath(remoteexecution.Compressor_IDENTITY)
+	resourceName := digest.GetByteStreamReadPath(compressor)
 	client, err := ba.byteStreamClient.Read(
 		metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
 		&bytestream.ReadRequest{
@@ -77,6 +218,15 @@ func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.B
 		cancel()
 		return buffer.NewBufferFromError(err)
 	}
+
+	if useCompression {
+		return buffer.NewCASBufferFromChunkReader(digest, &zstdByteStreamChunkReader{
+			client:        client,
+			cancel:        cancel,
+			readChunkSize: ba.readChunkSize,
+		}, buffer.BackendProvided(buffer.Irreparable(digest)))
+	}
+
 	return buffer.NewCASBufferFromChunkReader(digest, &byteStreamChunkReader{
 		client: client,
 		cancel: cancel,
@@ -89,19 +239,61 @@ func (ba *casBlobAccess) GetFromComposite(ctx context.Context, parentDigest, chi
 }
 
 func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer.Buffer) error {
-	r := b.ToChunkReader(0, ba.readChunkSize)
-	defer r.Close()
+	useCompression, err := ba.shouldUseZSTDCompression(ctx, digest)
+	if err != nil {
+		b.Discard()
+		return err
+	}
+
+	compressor := remoteexecution.Compressor_IDENTITY
+	if useCompression {
+		compressor = remoteexecution.Compressor_ZSTD
+	}
 
 	ctxWithCancel, cancel := context.WithCancel(ctx)
-	resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), remoteexecution.Compressor_IDENTITY)
+	resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), compressor)
 	client, err := ba.byteStreamClient.Write(
 		metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
 	)
 	if err != nil {
 		cancel()
+		b.Discard()
 		return err
 	}
 
+	if useCompression {
+		byteStreamWriter := &zstdByteStreamWriter{
+			client:       client,
+			resourceName: resourceName,
+			writeOffset:  0,
+			cancel:       cancel,
+		}
+
+		zstdWriter, err := zstd.NewWriter(byteStreamWriter, zstd.WithEncoderConcurrency(1))
+		if err != nil {
+			cancel()
+			client.CloseAndRecv()
+			b.Discard()
+			return status.Errorf(codes.Internal, "Failed to create zstd writer: %v", err)
+		}
+
+		if err := b.IntoWriter(zstdWriter); err != nil {
+			zstdWriter.Close()
+			byteStreamWriter.Close()
+			return err
+		}
+
+		if err := zstdWriter.Close(); err != nil {
+			byteStreamWriter.Close()
+			return err
+		}
+
+		return byteStreamWriter.Close()
+	}
+
+	// Non-compressed path.
+	r := b.ToChunkReader(0, ba.readChunkSize)
+	defer r.Close()
+
 	writeOffset := int64(0)
 	for {
 		if data, err := r.Read(); err == nil {
@@ -140,6 +332,10 @@ func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer
 }
 
 func (ba *casBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (digest.Set, error) {
+	return findMissingBlobsInternal(ctx, digests, ba.contentAddressableStorageClient)
+}
+
+func findMissingBlobsInternal(ctx context.Context, digests digest.Set, cas remoteexecution.ContentAddressableStorageClient) (digest.Set, error) {
 	// Partition all digests by digest function, as the
 	// FindMissingBlobs() RPC can only process digests for a single
 	// instance name and digest function.
@@ -157,7 +353,7 @@ func (ba *casBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (d
 			BlobDigests:    blobDigests,
 			DigestFunction: digestFunction.GetEnumValue(),
 		}
-		response, err := ba.contentAddressableStorageClient.FindMissingBlobs(ctx, &request)
+		response, err := cas.FindMissingBlobs(ctx, &request)
 		if err != nil {
 			return digest.EmptySet, err
 		}
 		return nil, err
 	}
 
+	cacheCapabilities := serverCapabilities.CacheCapabilities
+
+	// Store supported compressors for compression negotiation.
+	ba.supportedCompressors.Store(&cacheCapabilities.SupportedCompressors)
+
 	// Only return fields that pertain to the Content Addressable
 	// Storage. Don't set 'max_batch_total_size_bytes', as we don't
-	// issue batch operations. The same holds for fields related to
-	// compression support.
-	cacheCapabilities := serverCapabilities.CacheCapabilities
+	// issue batch operations.
 	return &remoteexecution.ServerCapabilities{
 		CacheCapabilities: &remoteexecution.CacheCapabilities{
 			DigestFunctions: digest.RemoveUnsupportedDigestFunctions(cacheCapabilities.DigestFunctions),
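
For context, here is a rough usage sketch of the new constructor (not part of this change). The endpoint address, instance name, and blob contents are placeholders, and it assumes current bb-storage and grpc-go APIs (digest.MustNewDigest taking a digest function value, grpc.NewClient). Compression is only applied when the server's GetCapabilities() response lists Compressor_ZSTD; otherwise the identity paths are used.

```go
package main

import (
	"context"
	"log"

	remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
	"github.com/buildbarn/bb-storage/pkg/blobstore/buffer"
	"github.com/buildbarn/bb-storage/pkg/blobstore/grpcclients"
	"github.com/buildbarn/bb-storage/pkg/digest"
	"github.com/google/uuid"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Placeholder CAS endpoint.
	conn, err := grpc.NewClient("localhost:8980", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal("Failed to create gRPC client: ", err)
	}
	defer conn.Close()

	// 64 KiB read chunks; ZSTD compression enabled. Whether it is
	// actually used depends on the server's advertised compressors.
	blobAccess := grpcclients.NewCASBlobAccess(conn, uuid.NewRandom, 64*1024, true)

	ctx := context.Background()
	blobDigest := digest.MustNewDigest(
		"example",
		remoteexecution.DigestFunction_SHA256,
		"2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824",
		5)

	// Put() compresses the ByteStream payload with ZSTD when negotiated.
	if err := blobAccess.Put(ctx, blobDigest, buffer.NewValidatedBufferFromByteSlice([]byte("hello"))); err != nil {
		log.Fatal("Put failed: ", err)
	}

	// Get() transparently decompresses the ByteStream response.
	data, err := blobAccess.Get(ctx, blobDigest).ToByteSlice(1024)
	if err != nil {
		log.Fatal("Get failed: ", err)
	}
	log.Printf("Fetched %d bytes", len(data))
}
```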