88 "encoding/json"
99 "errors"
1010 "fmt"
11+ "io"
1112 "os"
1213 "os/signal"
1314 "regexp"
@@ -1598,7 +1599,7 @@ func (s *server) ConfigureStateStore(ctx context.Context, protoReq *tfplugin6.Co
15981599 return protoResp , nil
15991600}
16001601
1601- func (s * server ) ReadStateBytes (protoReq * tfplugin6.ReadStateBytes_Request , protoStream grpc.ServerStreamingServer [tfplugin6.ReadStateBytes_ResponseChunk ]) error {
1602+ func (s * server ) ReadStateBytes (protoReq * tfplugin6.ReadStateBytes_Request , protoStream grpc.ServerStreamingServer [tfplugin6.ReadStateBytes_Response ]) error {
16021603 rpc := "ReadStateBytes"
16031604 ctx := protoStream .Context ()
16041605 ctx = s .loggingContext (ctx )
@@ -1628,12 +1629,13 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot
16281629
16291630 for chunk := range stream .Chunks {
16301631 select {
1632+ // TODO: check how interruptions are handled
16311633 case <- ctx .Done ():
16321634 logging .ProtocolTrace (ctx , "all chunks sent" )
16331635 return nil
16341636
16351637 default :
1636- protoChunk := toproto .ReadStateBytes_ResponseChunk (& chunk )
1638+ protoChunk := toproto .ReadStateBytes_Response (& chunk )
16371639 if err := protoStream .Send (protoChunk ); err != nil {
16381640 logging .ProtocolError (ctx , "Error sending chunk" , map [string ]any {logging .KeyError : err })
16391641 return err
@@ -1644,6 +1646,69 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot
16441646 return nil
16451647}
16461648
1649+ func (s * server ) WriteStateBytes (srv grpc.ClientStreamingServer [tfplugin6.WriteStateBytes_RequestChunk , tfplugin6.WriteStateBytes_Response ]) error {
1650+ rpc := "WriteStateBytes"
1651+ ctx := srv .Context ()
1652+ ctx = s .loggingContext (ctx )
1653+ ctx = logging .RpcContext (ctx , rpc )
1654+ // ctx = logging.StateStoreContext(ctx, protoReq.TypeName)
1655+ ctx = s .stoppableContext (ctx )
1656+ // logging.ProtocolTrace(ctx, "Received request")
1657+ // defer logging.ProtocolTrace(ctx, "Served request")
1658+
1659+ ctx = tf6serverlogging .DownstreamRequest (ctx )
1660+
1661+ server , ok := s .downstream .(tfprotov6.StateStoreServer )
1662+ if ! ok {
1663+ err := status .Error (codes .Unimplemented , "ProviderServer does not implement WriteStateBytes" )
1664+ logging .ProtocolError (ctx , err .Error ())
1665+ return err
1666+ }
1667+
1668+ var iteratorErr error
1669+
1670+ // TODO: what about error handling per chunk and providers having the ability to do cleanup on interruption?
1671+
1672+ iterator := func (yield func (tfprotov6.WriteStateByteChunk ) bool ) {
1673+ for {
1674+ chunk , err := srv .Recv ()
1675+ if err == io .EOF {
1676+ break
1677+ }
1678+ if err != nil {
1679+ iteratorErr = err
1680+ srv .SendMsg (& tfplugin6.WriteStateBytes_Response {
1681+ // Diagnostics: ,
1682+ })
1683+ return
1684+ }
1685+
1686+ yield (tfprotov6.WriteStateByteChunk {
1687+ Bytes : chunk .Bytes ,
1688+ TotalLength : chunk .TotalLength ,
1689+ Range : tfprotov6.StateByteRange {
1690+ Start : chunk .Range .Start ,
1691+ End : chunk .Range .End ,
1692+ },
1693+ })
1694+
1695+ }
1696+ }
1697+
1698+ resp , err := server .WriteStateBytes (ctx , & tfprotov6.WriteStateBytesStream {
1699+ Chunks : iterator ,
1700+ })
1701+ if err != nil {
1702+ return err
1703+ }
1704+
1705+ err = srv .SendAndClose (& tfplugin6.WriteStateBytes_Response {
1706+ // Diagnostics: resp.Diagnostics,
1707+ })
1708+
1709+ return nil
1710+ }
1711+
16471712func (s * server ) GetStates (ctx context.Context , protoReq * tfplugin6.GetStates_Request ) (* tfplugin6.GetStates_Response , error ) {
16481713 rpc := "GetStates"
16491714 ctx = s .loggingContext (ctx )
0 commit comments