Skip to content

Commit 5c186b3

Browse files
committed
WIP - writing
1 parent d621dcd commit 5c186b3

File tree

3 files changed

+89
-6
lines changed

3 files changed

+89
-6
lines changed

tfprotov6/internal/toproto/state_store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ func ConfigureStateStore_Response(in *tfprotov6.ConfigureStateStoreResponse) *tf
2828
}
2929
}
3030

31-
func ReadStateBytes_ResponseChunk(in *tfprotov6.StateByteChunk) *tfplugin6.ReadStateBytes_ResponseChunk {
31+
func ReadStateBytes_Response(in *tfprotov6.ReadStateByteChunk) *tfplugin6.ReadStateBytes_Response {
3232
if in == nil {
3333
return nil
3434
}
3535

36-
return &tfplugin6.ReadStateBytes_ResponseChunk{
36+
return &tfplugin6.ReadStateBytes_Response{
3737
Diagnostics: Diagnostics(in.Diagnostics),
3838
Bytes: in.Bytes,
3939
TotalLength: in.TotalLength,

tfprotov6/state_store.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ type StateStoreServer interface {
2020
// ReadStateBytes streams byte chunks of a given state file from a state store
2121
ReadStateBytes(context.Context, *ReadStateBytesRequest) (*ReadStateBytesStream, error)
2222

23+
WriteStateBytes(context.Context, *WriteStateBytesStream) (*WriteStateBytesResponse, error)
24+
2325
// GetStates returns a list of all states (i.e. CE workspaces) managed by a given state store
2426
GetStates(context.Context, *GetStatesRequest) (*GetStatesResponse, error)
2527

@@ -51,12 +53,28 @@ type ReadStateBytesRequest struct {
5153
}
5254

5355
type ReadStateBytesStream struct {
54-
Chunks iter.Seq[StateByteChunk]
56+
Chunks iter.Seq[ReadStateByteChunk]
57+
}
58+
59+
// type ChunkIterator func(StateByteChunkRequest) StateByteChunk
60+
61+
type WriteStateBytesStream struct {
62+
Chunks iter.Seq[WriteStateByteChunk]
63+
}
64+
65+
type WriteStateBytesResponse struct {
66+
Diagnostics []*Diagnostic
67+
}
68+
69+
type WriteStateByteChunk = StateByteChunk
70+
71+
type ReadStateByteChunk struct {
72+
StateByteChunk
73+
Diagnostics []*Diagnostic
5574
}
5675

5776
type StateByteChunk struct {
5877
Bytes []byte
59-
Diagnostics []*Diagnostic
6078
TotalLength int64
6179
Range StateByteRange
6280
}

tfprotov6/tf6server/server.go

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"encoding/json"
99
"errors"
1010
"fmt"
11+
"io"
1112
"os"
1213
"os/signal"
1314
"regexp"
@@ -1598,7 +1599,7 @@ func (s *server) ConfigureStateStore(ctx context.Context, protoReq *tfplugin6.Co
15981599
return protoResp, nil
15991600
}
16001601

1601-
func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, protoStream grpc.ServerStreamingServer[tfplugin6.ReadStateBytes_ResponseChunk]) error {
1602+
func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, protoStream grpc.ServerStreamingServer[tfplugin6.ReadStateBytes_Response]) error {
16021603
rpc := "ReadStateBytes"
16031604
ctx := protoStream.Context()
16041605
ctx = s.loggingContext(ctx)
@@ -1628,12 +1629,13 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot
16281629

16291630
for chunk := range stream.Chunks {
16301631
select {
1632+
// TODO: check how interruptions are handled
16311633
case <-ctx.Done():
16321634
logging.ProtocolTrace(ctx, "all chunks sent")
16331635
return nil
16341636

16351637
default:
1636-
protoChunk := toproto.ReadStateBytes_ResponseChunk(&chunk)
1638+
protoChunk := toproto.ReadStateBytes_Response(&chunk)
16371639
if err := protoStream.Send(protoChunk); err != nil {
16381640
logging.ProtocolError(ctx, "Error sending chunk", map[string]any{logging.KeyError: err})
16391641
return err
@@ -1644,6 +1646,69 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot
16441646
return nil
16451647
}
16461648

1649+
func (s *server) WriteStateBytes(srv grpc.ClientStreamingServer[tfplugin6.WriteStateBytes_RequestChunk, tfplugin6.WriteStateBytes_Response]) error {
1650+
rpc := "WriteStateBytes"
1651+
ctx := srv.Context()
1652+
ctx = s.loggingContext(ctx)
1653+
ctx = logging.RpcContext(ctx, rpc)
1654+
// ctx = logging.StateStoreContext(ctx, protoReq.TypeName)
1655+
ctx = s.stoppableContext(ctx)
1656+
// logging.ProtocolTrace(ctx, "Received request")
1657+
// defer logging.ProtocolTrace(ctx, "Served request")
1658+
1659+
ctx = tf6serverlogging.DownstreamRequest(ctx)
1660+
1661+
server, ok := s.downstream.(tfprotov6.StateStoreServer)
1662+
if !ok {
1663+
err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes")
1664+
logging.ProtocolError(ctx, err.Error())
1665+
return err
1666+
}
1667+
1668+
var iteratorErr error
1669+
1670+
// TODO: what about error handling per chunk and providers having the ability to do cleanup on interruption?
1671+
1672+
iterator := func(yield func(tfprotov6.WriteStateByteChunk) bool) {
1673+
for {
1674+
chunk, err := srv.Recv()
1675+
if err == io.EOF {
1676+
break
1677+
}
1678+
if err != nil {
1679+
iteratorErr = err
1680+
srv.SendMsg(&tfplugin6.WriteStateBytes_Response{
1681+
// Diagnostics: ,
1682+
})
1683+
return
1684+
}
1685+
1686+
yield(tfprotov6.WriteStateByteChunk{
1687+
Bytes: chunk.Bytes,
1688+
TotalLength: chunk.TotalLength,
1689+
Range: tfprotov6.StateByteRange{
1690+
Start: chunk.Range.Start,
1691+
End: chunk.Range.End,
1692+
},
1693+
})
1694+
1695+
}
1696+
}
1697+
1698+
resp, err := server.WriteStateBytes(ctx, &tfprotov6.WriteStateBytesStream{
1699+
Chunks: iterator,
1700+
})
1701+
if err != nil {
1702+
return err
1703+
}
1704+
1705+
err = srv.SendAndClose(&tfplugin6.WriteStateBytes_Response{
1706+
// Diagnostics: resp.Diagnostics,
1707+
})
1708+
1709+
return nil
1710+
}
1711+
16471712
func (s *server) GetStates(ctx context.Context, protoReq *tfplugin6.GetStates_Request) (*tfplugin6.GetStates_Response, error) {
16481713
rpc := "GetStates"
16491714
ctx = s.loggingContext(ctx)

0 commit comments

Comments
 (0)