@@ -5,23 +5,27 @@ import (
55 "iter"
66 "slices"
77
8+ BazelLog "github.com/aspect-build/aspect-gazelle/common/logger"
89 "github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
910)
1011
11- type IncrementalClient interface {
12+ type WatchClient interface {
1213 Connect () error
1314 Disconnect () error
14- AwaitCycle ( ) iter.Seq [ CycleSourcesMessage ]
15+ Subscribe ( options WatchOptions ) iter.Seq2 [ * CycleSourcesMessage , error ]
1516}
1617
1718type incClient struct {
1819 socketPath string
1920 socket socket.Socket [interface {}, map [string ]interface {}]
21+
22+ // The negotiated protocol version
23+ version ProtocolVersion
2024}
2125
22- var _ IncrementalClient = (* incClient )(nil )
26+ var _ WatchClient = (* incClient )(nil )
2327
24- func NewClient (host string ) IncrementalClient {
28+ func NewClient (host string ) WatchClient {
2529 return & incClient {
2630 socketPath : host ,
2731 }
@@ -55,22 +59,39 @@ func (c *incClient) negotiate() error {
5559 if negReq ["versions" ] == nil {
5660 return fmt .Errorf ("Received NEGOTIATE without versions: %v" , negReq )
5761 }
58- if ! slices .Contains (negReq ["versions" ].([]interface {}), (interface {})(float64 (PROTOCOL_VERSION ))) {
59- return fmt .Errorf ("Received NEGOTIATE with unsupported versions %v, expected %d" , negReq ["versions" ], PROTOCOL_VERSION )
62+ rawVersions , isArray := negReq ["versions" ].([]interface {})
63+ if ! isArray {
64+ return fmt .Errorf ("Invalid versions, expected []int, received type: %T" , negReq ["versions" ])
65+ }
66+
67+ negotiatedVersion , err := negoiateVersion (rawVersions )
68+ if err != nil {
69+ return err
6070 }
6171
6272 err = c .socket .Send (negotiateResponseMessage {
6373 Message : Message {
6474 Kind : "NEGOTIATE_RESPONSE" ,
6575 },
66- Version : PROTOCOL_VERSION ,
76+ Version : negotiatedVersion ,
6777 })
6878 if err != nil {
6979 return fmt .Errorf ("failed to negotiate protocol version: %w" , err )
7080 }
81+
82+ c .version = negotiatedVersion
7183 return nil
7284}
7385
86+ func negoiateVersion (acceptedVersions []interface {}) (ProtocolVersion , error ) {
87+ for _ , v := range slices .Backward (acceptedVersions ) {
88+ if slices .Contains (abazelSupportedProtocolVersions , ProtocolVersion (v .(float64 ))) {
89+ return ProtocolVersion (v .(float64 )), nil
90+ }
91+ }
92+ return - 1 , fmt .Errorf ("unsupported versions %v, expected one of %v" , acceptedVersions , abazelSupportedProtocolVersions )
93+ }
94+
7495func (c * incClient ) Disconnect () error {
7596 if c .socket == nil {
7697 return fmt .Errorf ("client not connected" )
@@ -84,12 +105,37 @@ func (c *incClient) Disconnect() error {
84105 return err
85106}
86107
87- func (c * incClient ) AwaitCycle () iter.Seq [CycleSourcesMessage ] {
88- return func (yield func (CycleSourcesMessage ) bool ) {
108+ func (c * incClient ) Subscribe (options WatchOptions ) iter.Seq2 [* CycleSourcesMessage , error ] {
109+ return func (yield func (* CycleSourcesMessage , error ) bool ) {
110+ if c .version >= VERSION_1 {
111+ err := c .socket .Send (SubscribeMessage {
112+ Message : Message {
113+ Kind : "SUBSCRIBE" ,
114+ },
115+ WatchType : options .Type ,
116+ })
117+ if err != nil {
118+ yield (nil , err )
119+ return
120+ }
121+
122+ msg , err := c .socket .Recv ()
123+ if err != nil {
124+ yield (nil , err )
125+ return
126+ }
127+
128+ if msg ["kind" ] != "SUBSCRIBE_RESPONSE" {
129+ yield (nil , fmt .Errorf ("expected SUBSCRIBE_RESPONSE, got %v" , msg ))
130+ return
131+ }
132+ }
133+
89134 for {
90135 msg , err := c .socket .Recv ()
91136 if err != nil {
92137 fmt .Printf ("Error receiving message: %v\n " , err )
138+ yield (nil , err )
93139 return
94140 }
95141
@@ -100,21 +146,28 @@ func (c *incClient) AwaitCycle() iter.Seq[CycleSourcesMessage] {
100146 continue
101147 }
102148
103- c .socket .Send (CycleMessage {
149+ err = c .socket .Send (CycleMessage {
104150 Message : Message {
105151 Kind : "CYCLE_STARTED" ,
106152 },
107153 CycleId : cycleEvent .CycleId ,
108154 })
155+ if err != nil {
156+ yield (nil , err )
157+ return
158+ }
109159
110- r := yield (cycleEvent )
160+ r := yield (& cycleEvent , nil )
111161
112- c .socket .Send (CycleMessage {
162+ err = c .socket .Send (CycleMessage {
113163 Message : Message {
114164 Kind : "CYCLE_COMPLETED" ,
115165 },
116166 CycleId : cycleEvent .CycleId ,
117167 })
168+ if err != nil {
169+ BazelLog .Warnf ("Failed to send CYCLE_COMPLETED for cycle_id=%d: %v\n " , cycleEvent .CycleId , err )
170+ }
118171
119172 if ! r {
120173 return
0 commit comments