@@ -9,20 +9,23 @@ import (
99 "github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
1010)
1111
12- type IncrementalClient interface {
12+ type WatchClient interface {
1313 Connect () error
1414 Disconnect () error
15- AwaitCycle ( ) iter.Seq2 [* CycleSourcesMessage , error ]
15+ Subscribe ( options WatchOptions ) iter.Seq2 [* CycleSourcesMessage , error ]
1616}
1717
1818type incClient struct {
1919 socketPath string
2020 socket socket.Socket [interface {}, map [string ]interface {}]
21+
22+ // The negotiated protocol version
23+ version ProtocolVersion
2124}
2225
23- var _ IncrementalClient = (* incClient )(nil )
26+ var _ WatchClient = (* incClient )(nil )
2427
25- func NewClient (host string ) IncrementalClient {
28+ func NewClient (host string ) WatchClient {
2629 return & incClient {
2730 socketPath : host ,
2831 }
@@ -56,22 +59,39 @@ func (c *incClient) negotiate() error {
5659 if negReq ["versions" ] == nil {
5760 return fmt .Errorf ("Received NEGOTIATE without versions: %v" , negReq )
5861 }
59- if ! slices .Contains (negReq ["versions" ].([]interface {}), (interface {})(float64 (PROTOCOL_VERSION ))) {
60- return fmt .Errorf ("Received NEGOTIATE with unsupported versions %v, expected %d" , negReq ["versions" ], PROTOCOL_VERSION )
62+ rawVersions , isArray := negReq ["versions" ].([]interface {})
63+ if ! isArray {
64+ return fmt .Errorf ("Invalid versions, expected []int, received type: %T" , negReq ["versions" ])
65+ }
66+
67+ negotiatedVersion , err := negoiateVersion (rawVersions )
68+ if err != nil {
69+ return err
6170 }
6271
6372 err = c .socket .Send (negotiateResponseMessage {
6473 Message : Message {
6574 Kind : "NEGOTIATE_RESPONSE" ,
6675 },
67- Version : PROTOCOL_VERSION ,
76+ Version : negotiatedVersion ,
6877 })
6978 if err != nil {
7079 return fmt .Errorf ("failed to negotiate protocol version: %w" , err )
7180 }
81+
82+ c .version = negotiatedVersion
7283 return nil
7384}
7485
86+ func negoiateVersion (acceptedVersions []interface {}) (ProtocolVersion , error ) {
87+ for _ , v := range slices .Backward (acceptedVersions ) {
88+ if slices .Contains (abazelSupportedProtocolVersions , ProtocolVersion (v .(float64 ))) {
89+ return ProtocolVersion (v .(float64 )), nil
90+ }
91+ }
92+ return - 1 , fmt .Errorf ("unsupported versions %v, expected one of %v" , acceptedVersions , abazelSupportedProtocolVersions )
93+ }
94+
7595func (c * incClient ) Disconnect () error {
7696 if c .socket == nil {
7797 return fmt .Errorf ("client not connected" )
@@ -85,8 +105,32 @@ func (c *incClient) Disconnect() error {
85105 return err
86106}
87107
88- func (c * incClient ) AwaitCycle ( ) iter.Seq2 [* CycleSourcesMessage , error ] {
108+ func (c * incClient ) Subscribe ( options WatchOptions ) iter.Seq2 [* CycleSourcesMessage , error ] {
89109 return func (yield func (* CycleSourcesMessage , error ) bool ) {
110+ if c .version >= VERSION_1 {
111+ err := c .socket .Send (SubscribeMessage {
112+ Message : Message {
113+ Kind : "SUBSCRIBE" ,
114+ },
115+ WatchType : options .Type ,
116+ })
117+ if err != nil {
118+ yield (nil , err )
119+ return
120+ }
121+
122+ msg , err := c .socket .Recv ()
123+ if err != nil {
124+ yield (nil , err )
125+ return
126+ }
127+
128+ if msg ["kind" ] != "SUBSCRIBE_RESPONSE" {
129+ yield (nil , fmt .Errorf ("expected SUBSCRIBE_RESPONSE, got %v" , msg ))
130+ return
131+ }
132+ }
133+
90134 for {
91135 msg , err := c .socket .Recv ()
92136 if err != nil {
0 commit comments