@@ -9,20 +9,23 @@ import (
99 "github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
1010)
1111
12- type IncrementalClient interface {
12+ type WatchClient interface {
1313 Connect () error
1414 Disconnect () error
15- AwaitCycle ( ) iter.Seq2 [* CycleSourcesMessage , error ]
15+ Subscribe ( options WatchOptions ) iter.Seq2 [* CycleSourcesMessage , error ]
1616}
1717
1818type incClient struct {
1919 socketPath string
2020 socket socket.Socket [interface {}, map [string ]interface {}]
21+
22+ // The negotiated protocol version
23+ version ProtocolVersion
2124}
2225
23- var _ IncrementalClient = (* incClient )(nil )
26+ var _ WatchClient = (* incClient )(nil )
2427
25- func NewClient (host string ) IncrementalClient {
28+ func NewClient (host string ) WatchClient {
2629 return & incClient {
2730 socketPath : host ,
2831 }
@@ -56,22 +59,39 @@ func (c *incClient) negotiate() error {
5659 if negReq ["versions" ] == nil {
5760 return fmt .Errorf ("Received NEGOTIATE without versions: %v" , negReq )
5861 }
59- if ! slices .Contains (negReq ["versions" ].([]interface {}), (interface {})(float64 (PROTOCOL_VERSION ))) {
60- return fmt .Errorf ("Received NEGOTIATE with unsupported versions %v, expected %d" , negReq ["versions" ], PROTOCOL_VERSION )
62+ rawVersions , isArray := negReq ["versions" ].([]interface {})
63+ if ! isArray {
64+ return fmt .Errorf ("Invalid versions, expected []int, received type: %T" , negReq ["versions" ])
65+ }
66+
67+ negotiatedVersion , err := negotiateVersion (rawVersions )
68+ if err != nil {
69+ return err
6170 }
6271
6372 err = c .socket .Send (negotiateResponseMessage {
6473 Message : Message {
6574 Kind : "NEGOTIATE_RESPONSE" ,
6675 },
67- Version : PROTOCOL_VERSION ,
76+ Version : negotiatedVersion ,
6877 })
6978 if err != nil {
7079 return fmt .Errorf ("failed to negotiate protocol version: %w" , err )
7180 }
81+
82+ c .version = negotiatedVersion
7283 return nil
7384}
7485
86+ func negotiateVersion (acceptedVersions []interface {}) (ProtocolVersion , error ) {
87+ for _ , v := range slices .Backward (acceptedVersions ) {
88+ if slices .Contains (abazelSupportedProtocolVersions , ProtocolVersion (v .(float64 ))) {
89+ return ProtocolVersion (v .(float64 )), nil
90+ }
91+ }
92+ return - 1 , fmt .Errorf ("unsupported versions %v, expected one of %v" , acceptedVersions , abazelSupportedProtocolVersions )
93+ }
94+
7595func (c * incClient ) Disconnect () error {
7696 if c .socket == nil {
7797 return fmt .Errorf ("client not connected" )
@@ -85,8 +105,33 @@ func (c *incClient) Disconnect() error {
85105 return err
86106}
87107
88- func (c * incClient ) AwaitCycle ( ) iter.Seq2 [* CycleSourcesMessage , error ] {
108+ func (c * incClient ) Subscribe ( options WatchOptions ) iter.Seq2 [* CycleSourcesMessage , error ] {
89109 return func (yield func (* CycleSourcesMessage , error ) bool ) {
110+ // Version 1+ require the initial SUBSCRIBE to start the subscription
111+ if c .version >= VERSION_1 {
112+ err := c .socket .Send (SubscribeMessage {
113+ Message : Message {
114+ Kind : "SUBSCRIBE" ,
115+ },
116+ WatchType : options .Type ,
117+ })
118+ if err != nil {
119+ yield (nil , err )
120+ return
121+ }
122+
123+ msg , err := c .socket .Recv ()
124+ if err != nil {
125+ yield (nil , err )
126+ return
127+ }
128+
129+ if msg ["kind" ] != "SUBSCRIBE_RESPONSE" {
130+ yield (nil , fmt .Errorf ("expected SUBSCRIBE_RESPONSE, got %v" , msg ))
131+ return
132+ }
133+ }
134+
90135 for {
91136 msg , err := c .socket .Recv ()
92137 if err != nil {
0 commit comments