package grpcclients

import (
	"context"
	"io"
	"slices"
	"sync"

	remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
	"github.com/buildbarn/bb-storage/pkg/blobstore"
	"github.com/buildbarn/bb-storage/pkg/blobstore/buffer"
	"github.com/buildbarn/bb-storage/pkg/blobstore/slicing"
	"github.com/buildbarn/bb-storage/pkg/digest"
	"github.com/buildbarn/bb-storage/pkg/util"
	"github.com/google/uuid"
	"github.com/klauspost/compress/zstd"

	"google.golang.org/genproto/googleapis/bytestream"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

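// casWithZstdBlobAccess is a BlobAccess implementation that transfers
// Content Addressable Storage blobs over gRPC in ZSTD-compressed form.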
type casWithZstdBlobAccess struct {
	byteStreamClient                bytestream.ByteStreamClient
	contentAddressableStorageClient remoteexecution.ContentAddressableStorageClient
	capabilitiesClient              remoteexecution.CapabilitiesClient
	uuidGenerator                   util.UUIDGenerator
	readChunkSize                   int
}

// NewCASWithZstdBlobAccess creates a BlobAccess handle that relays requests
// to a gRPC server implementing the bytestream.ByteStream and
// remoteexecution.ContentAddressableStorage services, transferring blob
// contents in ZSTD-compressed form.
func NewCASWithZstdBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int) blobstore.BlobAccess {
	return &casWithZstdBlobAccess{
		byteStreamClient:                bytestream.NewByteStreamClient(client),
		contentAddressableStorageClient: remoteexecution.NewContentAddressableStorageClient(client),
		capabilitiesClient:              remoteexecution.NewCapabilitiesClient(client),
		uuidGenerator:                   uuidGenerator,
		readChunkSize:                   readChunkSize,
	}
}
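
// A minimal usage sketch (conn is assumed to be an existing
// grpc.ClientConnInterface; the UUID generator, chunk size, ctx, blobDigest
// and maximumSizeBytes are illustrative placeholders):
//
//	blobAccess := NewCASWithZstdBlobAccess(conn, uuid.NewRandom, 64*1024)
//	b := blobAccess.Get(ctx, blobDigest)
//	data, err := b.ToByteSlice(maximumSizeBytes)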

// zstdByteStreamChunkReader decompresses a ZSTD-compressed ByteStream read
// stream and hands it out as chunks of at most readChunkSize bytes.
type zstdByteStreamChunkReader struct {
	client        bytestream.ByteStream_ReadClient
	cancel        context.CancelFunc
	zstdReader    io.ReadCloser
	readChunkSize int
	wg            sync.WaitGroup
}

// Read lazily sets up a pipe that feeds the compressed stream into the ZSTD
// decoder, so that decompression happens while data is received and the blob
// is never buffered in memory in its entirety.
//
// Unlike the non-compressed version, which can return gRPC chunks directly,
// compression requires bridging two incompatible interfaces:
// - gRPC pushes chunks to us via client.Recv().
// - ZSTD expects to pull data from us via Read().
//
// We bridge them with a goroutine that receives gRPC chunks and writes them
// to a pipe, while the ZSTD decoder reads from the other end of the pipe.
// This creates a streaming pipeline: gRPC -> goroutine -> pipe -> ZSTD
// decoder -> caller.
func (r *zstdByteStreamChunkReader) Read() ([]byte, error) {
	if r.zstdReader == nil {
		pr, pw := io.Pipe()

		r.wg.Add(1)
		go func() {
			defer r.wg.Done()
			defer pw.Close()
			for {
				chunk, err := r.client.Recv()
				if err != nil {
					if err != io.EOF {
						pw.CloseWithError(err)
					}
					return
				}
				if _, writeErr := pw.Write(chunk.Data); writeErr != nil {
					pw.CloseWithError(writeErr)
					return
				}
			}
		}()

		var err error
		r.zstdReader, err = util.NewZstdReadCloser(pr, zstd.WithDecoderConcurrency(1))
		if err != nil {
			pr.Close()
			return nil, err
		}
	}

	buf := make([]byte, r.readChunkSize)
	n, err := r.zstdReader.Read(buf)
	if n > 0 {
		// Hand out the decompressed data first. Any error, including
		// io.EOF, is reported again by the next call to Read().
		return buf[:n], nil
	}
	return nil, err
}

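// Close tears down the streaming pipeline: it closes the ZSTD decoder,
// cancels the underlying gRPC call and waits for the forwarding goroutine to
// terminate.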
func (r *zstdByteStreamChunkReader) Close() {
	if r.zstdReader != nil {
		r.zstdReader.Close()
	}
	r.cancel()

	// Drain the gRPC stream.
	for {
		if _, err := r.client.Recv(); err != nil {
			break
		}
	}
	r.wg.Wait()
}

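// Get reads the blob through the ByteStream Read() call, requesting a
// ZSTD-compressed transfer and decompressing the received chunks on the fly.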
func (ba *casWithZstdBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.Buffer {
	ctxWithCancel, cancel := context.WithCancel(ctx)
	resourceName := digest.GetByteStreamReadPath(remoteexecution.Compressor_ZSTD)
	client, err := ba.byteStreamClient.Read(
		metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
		&bytestream.ReadRequest{
			ResourceName: resourceName,
		},
	)
	if err != nil {
		cancel()
		return buffer.NewBufferFromError(err)
	}
	return buffer.NewCASBufferFromChunkReader(digest, &zstdByteStreamChunkReader{
		client:        client,
		cancel:        cancel,
		readChunkSize: ba.readChunkSize,
	}, buffer.BackendProvided(buffer.Irreparable(digest)))
}

func (ba *casWithZstdBlobAccess) GetFromComposite(ctx context.Context, parentDigest, childDigest digest.Digest, slicer slicing.BlobSlicer) buffer.Buffer {
	return buffer.NewBufferFromError(status.Error(codes.Unimplemented, "GetFromComposite is not supported with ZSTD compression"))
}

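// zstdByteStreamWriter sends the data written to it to the ByteStream
// service as a sequence of WriteRequests. In Put() it acts as the
// destination of the ZSTD encoder, so the data it receives is already
// compressed.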
type zstdByteStreamWriter struct {
	client       bytestream.ByteStream_WriteClient
	resourceName string
	writeOffset  int64
	cancel       context.CancelFunc
}

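// Write forwards a single chunk of compressed data to the server as a
// WriteRequest on the ByteStream Write() call.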
func (w *zstdByteStreamWriter) Write(p []byte) (int, error) {
	if err := w.client.Send(&bytestream.WriteRequest{
		ResourceName: w.resourceName,
		WriteOffset:  w.writeOffset,
		Data:         p,
	}); err != nil {
		return 0, err
	}
	w.writeOffset += int64(len(p))
	// The resource name only has to be provided as part of the first
	// WriteRequest of the stream.
	w.resourceName = ""
	return len(p), nil
}

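// Close completes the upload by sending a final WriteRequest with
// FinishWrite set and waiting for the server's WriteResponse.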
func (w *zstdByteStreamWriter) Close() error {
	if err := w.client.Send(&bytestream.WriteRequest{
		ResourceName: w.resourceName,
		WriteOffset:  w.writeOffset,
		FinishWrite:  true,
	}); err != nil {
		w.cancel()
		w.client.CloseAndRecv()
		return err
	}
	_, err := w.client.CloseAndRecv()
	w.cancel()
	return err
}

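// Put compresses the blob with ZSTD on the fly while streaming it to the
// server through the ByteStream Write() call.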
func (ba *casWithZstdBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer.Buffer) error {
	ctxWithCancel, cancel := context.WithCancel(ctx)
	resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), remoteexecution.Compressor_ZSTD)
	client, err := ba.byteStreamClient.Write(
		metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
	)
	if err != nil {
		cancel()
		return err
	}

	byteStreamWriter := &zstdByteStreamWriter{
		client:       client,
		resourceName: resourceName,
		writeOffset:  0,
		cancel:       cancel,
	}

	zstdWriter, err := zstd.NewWriter(byteStreamWriter, zstd.WithEncoderConcurrency(1))
	if err != nil {
		cancel()
		client.CloseAndRecv()
		return status.Errorf(codes.Internal, "Failed to create zstd writer: %v", err)
	}

	if err := b.IntoWriter(zstdWriter); err != nil {
		zstdWriter.Close()
		byteStreamWriter.Close()
		return err
	}

	if err := zstdWriter.Close(); err != nil {
		byteStreamWriter.Close()
		return err
	}

	return byteStreamWriter.Close()
}

func (ba *casWithZstdBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (digest.Set, error) {
	return findMissingBlobsInternal(ctx, digests, ba.contentAddressableStorageClient)
}

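// GetCapabilities returns the Content Addressable Storage related
// capabilities of the server, after verifying that it advertises ZSTD as a
// supported compressor.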
func (ba *casWithZstdBlobAccess) GetCapabilities(ctx context.Context, instanceName digest.InstanceName) (*remoteexecution.ServerCapabilities, error) {
	cacheCapabilities, err := getCacheCapabilities(ctx, ba.capabilitiesClient, instanceName)
	if err != nil {
		return nil, err
	}

	if !slices.Contains(cacheCapabilities.SupportedCompressors, remoteexecution.Compressor_ZSTD) {
		return nil, status.Error(codes.FailedPrecondition, "Server does not support ZSTD compression")
	}

	// Only return fields that pertain to the Content Addressable
	// Storage. Include compression support information.
	return &remoteexecution.ServerCapabilities{
		CacheCapabilities: &remoteexecution.CacheCapabilities{
			DigestFunctions:      digest.RemoveUnsupportedDigestFunctions(cacheCapabilities.DigestFunctions),
			SupportedCompressors: cacheCapabilities.SupportedCompressors,
		},
	}, nil
}