Skip to content

Commit 9394014

Browse files
committed
implement ReadStateBytes
1 parent dc163e4 commit 9394014

File tree

4 files changed

+100
-1
lines changed

4 files changed

+100
-1
lines changed

tfprotov6/internal/fromproto/state_store.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ func ConfigureStateStoreRequest(in *tfplugin6.ConfigureStateStore_Request) *tfpr
3030
}
3131
}
3232

33+
func ReadStateBytesRequest(in *tfplugin6.ReadStateBytes_Request) *tfprotov6.ReadStateBytesRequest {
34+
if in == nil {
35+
return nil
36+
}
37+
38+
return &tfprotov6.ReadStateBytesRequest{
39+
TypeName: in.TypeName,
40+
StateId: in.StateId,
41+
}
42+
}
43+
3344
func GetStatesRequest(in *tfplugin6.GetStates_Request) *tfprotov6.GetStatesRequest {
3445
if in == nil {
3546
return nil

tfprotov6/internal/toproto/state_store.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@ func ConfigureStateStore_Response(in *tfprotov6.ConfigureStateStoreResponse) *tf
2828
}
2929
}
3030

31+
func ReadStateBytes_ResponseChunk(in *tfprotov6.StateByteChunk) *tfplugin6.ReadStateBytes_ResponseChunk {
32+
if in == nil {
33+
return nil
34+
}
35+
36+
return &tfplugin6.ReadStateBytes_ResponseChunk{
37+
Diagnostics: Diagnostics(in.Diagnostics),
38+
Bytes: in.Bytes,
39+
TotalLength: in.TotalLength,
40+
Range: &tfplugin6.StateRange{
41+
Start: in.Range.Start,
42+
End: in.Range.End,
43+
},
44+
}
45+
}
46+
3147
func GetStates_Response(in *tfprotov6.GetStatesResponse) *tfplugin6.GetStates_Response {
3248
if in == nil {
3349
return nil

tfprotov6/state_store.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
package tfprotov6
55

6-
import "context"
6+
import (
7+
"context"
8+
"iter"
9+
)
710

811
// StateStoreServer is an interface containing the methods an list resource
912
// implementation needs to fill.
@@ -14,6 +17,9 @@ type StateStoreServer interface {
1417
// ConfigureStateStore configures the state store, such as S3 connection in the context of already configured provider
1518
ConfigureStateStore(context.Context, *ConfigureStateStoreRequest) (*ConfigureStateStoreResponse, error)
1619

20+
// ReadStateBytes streams byte chunks of a given state file from a state store
21+
ReadStateBytes(context.Context, *ReadStateBytesRequest) (*ReadStateBytesStream, error)
22+
1723
// GetStates returns a list of all states (i.e. CE workspaces) managed by a given state store
1824
GetStates(context.Context, *GetStatesRequest) (*GetStatesResponse, error)
1925

@@ -39,6 +45,26 @@ type ConfigureStateStoreResponse struct {
3945
Diagnostics []*Diagnostic
4046
}
4147

48+
type ReadStateBytesRequest struct {
49+
TypeName string
50+
StateId string
51+
}
52+
53+
type ReadStateBytesStream struct {
54+
Chunks iter.Seq[StateByteChunk]
55+
}
56+
57+
type StateByteChunk struct {
58+
Bytes []byte
59+
Diagnostics []*Diagnostic
60+
TotalLength int64
61+
Range StateByteRange
62+
}
63+
64+
type StateByteRange struct {
65+
Start, End int64
66+
}
67+
4268
type GetStatesRequest struct {
4369
TypeName string
4470
}

tfprotov6/tf6server/server.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,6 +1598,52 @@ func (s *server) ConfigureStateStore(ctx context.Context, protoReq *tfplugin6.Co
15981598
return protoResp, nil
15991599
}
16001600

1601+
func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, protoStream grpc.ServerStreamingServer[tfplugin6.ReadStateBytes_ResponseChunk]) error {
1602+
rpc := "ReadStateBytes"
1603+
ctx := protoStream.Context()
1604+
ctx = s.loggingContext(ctx)
1605+
ctx = logging.RpcContext(ctx, rpc)
1606+
ctx = logging.StateStoreContext(ctx, protoReq.TypeName)
1607+
ctx = s.stoppableContext(ctx)
1608+
logging.ProtocolTrace(ctx, "Received request")
1609+
defer logging.ProtocolTrace(ctx, "Served request")
1610+
1611+
req := fromproto.ReadStateBytesRequest(protoReq)
1612+
logging.ProtocolData(ctx, s.protocolDataDir, rpc, "Request", "StateId", req.StateId)
1613+
1614+
ctx = tf6serverlogging.DownstreamRequest(ctx)
1615+
1616+
server, ok := s.downstream.(tfprotov6.StateStoreServer)
1617+
if !ok {
1618+
err := status.Error(codes.Unimplemented, "ProviderServer does not implement ReadStateBytes")
1619+
logging.ProtocolError(ctx, err.Error())
1620+
return err
1621+
}
1622+
1623+
stream, err := server.ReadStateBytes(ctx, req)
1624+
if err != nil {
1625+
logging.ProtocolError(ctx, "Error from downstream", map[string]interface{}{logging.KeyError: err})
1626+
return err
1627+
}
1628+
1629+
for chunk := range stream.Chunks {
1630+
select {
1631+
case <-ctx.Done():
1632+
logging.ProtocolTrace(ctx, "all chunks sent")
1633+
return nil
1634+
1635+
default:
1636+
protoChunk := toproto.ReadStateBytes_ResponseChunk(&chunk)
1637+
if err := protoStream.Send(protoChunk); err != nil {
1638+
logging.ProtocolError(ctx, "Error sending chunk", map[string]any{logging.KeyError: err})
1639+
return err
1640+
}
1641+
}
1642+
}
1643+
1644+
return nil
1645+
}
1646+
16011647
func (s *server) GetStates(ctx context.Context, protoReq *tfplugin6.GetStates_Request) (*tfplugin6.GetStates_Response, error) {
16021648
rpc := "GetStates"
16031649
ctx = s.loggingContext(ctx)

0 commit comments

Comments
 (0)