Skip to content

Commit 10cc314

Browse files
committed
feat: add explicit SUBSCRIBE message to initiate subscriptions
1 parent 7dea6df commit 10cc314

File tree

3 files changed

+91
-16
lines changed

3 files changed

+91
-16
lines changed

runner/pkg/ibp/client.go

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,23 @@ import (
99
"github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
1010
)
1111

12-
type IncrementalClient interface {
12+
type WatchClient interface {
1313
Connect() error
1414
Disconnect() error
15-
AwaitCycle() iter.Seq2[*CycleSourcesMessage, error]
15+
Subscribe(options WatchOptions) iter.Seq2[*CycleSourcesMessage, error]
1616
}
1717

1818
type incClient struct {
1919
socketPath string
2020
socket socket.Socket[interface{}, map[string]interface{}]
21+
22+
// The negotiated protocol version
23+
version ProtocolVersion
2124
}
2225

23-
var _ IncrementalClient = (*incClient)(nil)
26+
var _ WatchClient = (*incClient)(nil)
2427

25-
func NewClient(host string) IncrementalClient {
28+
func NewClient(host string) WatchClient {
2629
return &incClient{
2730
socketPath: host,
2831
}
@@ -56,22 +59,39 @@ func (c *incClient) negotiate() error {
5659
if negReq["versions"] == nil {
5760
return fmt.Errorf("Received NEGOTIATE without versions: %v", negReq)
5861
}
59-
if !slices.Contains(negReq["versions"].([]interface{}), (interface{})(float64(PROTOCOL_VERSION))) {
60-
return fmt.Errorf("Received NEGOTIATE with unsupported versions %v, expected %d", negReq["versions"], PROTOCOL_VERSION)
62+
rawVersions, isArray := negReq["versions"].([]interface{})
63+
if !isArray {
64+
return fmt.Errorf("Invalid versions, expected []int, received type: %T", negReq["versions"])
65+
}
66+
67+
negotiatedVersion, err := negoiateVersion(rawVersions)
68+
if err != nil {
69+
return err
6170
}
6271

6372
err = c.socket.Send(negotiateResponseMessage{
6473
Message: Message{
6574
Kind: "NEGOTIATE_RESPONSE",
6675
},
67-
Version: PROTOCOL_VERSION,
76+
Version: negotiatedVersion,
6877
})
6978
if err != nil {
7079
return fmt.Errorf("failed to negotiate protocol version: %w", err)
7180
}
81+
82+
c.version = negotiatedVersion
7283
return nil
7384
}
7485

86+
func negoiateVersion(acceptedVersions []interface{}) (ProtocolVersion, error) {
87+
for _, v := range slices.Backward(acceptedVersions) {
88+
if slices.Contains(abazelSupportedProtocolVersions, ProtocolVersion(v.(float64))) {
89+
return ProtocolVersion(v.(float64)), nil
90+
}
91+
}
92+
return -1, fmt.Errorf("unsupported versions %v, expected one of %v", acceptedVersions, abazelSupportedProtocolVersions)
93+
}
94+
7595
func (c *incClient) Disconnect() error {
7696
if c.socket == nil {
7797
return fmt.Errorf("client not connected")
@@ -85,8 +105,32 @@ func (c *incClient) Disconnect() error {
85105
return err
86106
}
87107

88-
func (c *incClient) AwaitCycle() iter.Seq2[*CycleSourcesMessage, error] {
108+
func (c *incClient) Subscribe(options WatchOptions) iter.Seq2[*CycleSourcesMessage, error] {
89109
return func(yield func(*CycleSourcesMessage, error) bool) {
110+
if c.version >= VERSION_1 {
111+
err := c.socket.Send(SubscribeMessage{
112+
Message: Message{
113+
Kind: "SUBSCRIBE",
114+
},
115+
WatchType: options.Type,
116+
})
117+
if err != nil {
118+
yield(nil, err)
119+
return
120+
}
121+
122+
msg, err := c.socket.Recv()
123+
if err != nil {
124+
yield(nil, err)
125+
return
126+
}
127+
128+
if msg["kind"] != "SUBSCRIBE_RESPONSE" {
129+
yield(nil, fmt.Errorf("expected SUBSCRIBE_RESPONSE, got %v", msg))
130+
return
131+
}
132+
}
133+
90134
for {
91135
msg, err := c.socket.Recv()
92136
if err != nil {

runner/pkg/ibp/protocol.go

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,25 @@ import (
55
"fmt"
66
"os"
77
"path"
8+
"slices"
89
"sync/atomic"
910

1011
"github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
1112
"github.com/fatih/color"
1213
)
1314

14-
const PROTOCOL_VERSION = 0
15+
type ProtocolVersion int
16+
17+
const (
18+
LEGACY_VERSION_0 ProtocolVersion = 0
19+
VERSION_1 = 1
20+
LATEST_VERSION = VERSION_1
21+
)
22+
23+
func (v ProtocolVersion) HasExplicitSubscribe() bool {
24+
return v >= 1
25+
}
26+
1527
const PROTOCOL_SOCKET_ENV = "ABAZEL_WATCH_SOCKET_FILE"
1628

1729
type IncrementalBazel interface {
@@ -40,18 +52,34 @@ type Message struct {
4052

4153
type negotiateMessage struct {
4254
Message
43-
Versions []int `json:"versions"`
55+
Versions []ProtocolVersion `json:"versions"`
4456
}
4557
type negotiateResponseMessage struct {
4658
Message
47-
Version int `json:"version"`
59+
Version ProtocolVersion `json:"version"`
4860
}
4961

5062
type capMessage struct {
5163
Message
5264
Caps map[string]bool `json:"caps"`
5365
}
5466

67+
type WatchType string
68+
69+
const (
70+
WatchTypeRunfiles WatchType = "runfiles"
71+
WatchTypeSources WatchType = "sources"
72+
)
73+
74+
type WatchOptions struct {
75+
Type WatchType
76+
}
77+
78+
type SubscribeMessage struct {
79+
Message
80+
WatchType WatchType `json:"watch_type"`
81+
}
82+
5583
type exitMessage struct {
5684
Message
5785
Description string `json:"description"`
@@ -77,7 +105,10 @@ type CycleSourcesMessage struct {
77105
}
78106

79107
// The versions supported by this host implementation of the protocol.
80-
var abazelSupportedProtocolVersions = []int{PROTOCOL_VERSION}
108+
var abazelSupportedProtocolVersions = []ProtocolVersion{
109+
LEGACY_VERSION_0,
110+
VERSION_1,
111+
}
81112

82113
type aspectBazelSocket = socket.Server[interface{}, map[string]any]
83114

@@ -172,8 +203,8 @@ func (p *aspectBazelProtocol) acceptNegotiation() error {
172203
if negResp["version"] == nil {
173204
return fmt.Errorf("Received NEGOTIATE_RESPONSE without version: %v", negResp)
174205
}
175-
if negResp["version"].(float64) != PROTOCOL_VERSION {
176-
return fmt.Errorf("Received NEGOTIATE_RESPONSE with unsupported version %v, expected %d", negResp["version"], PROTOCOL_VERSION)
206+
if !slices.Contains(abazelSupportedProtocolVersions, ProtocolVersion(negResp["version"].(float64))) {
207+
return fmt.Errorf("Received NEGOTIATE_RESPONSE with unsupported version %v, expected one of %v", negResp["version"], abazelSupportedProtocolVersions)
177208
}
178209

179210
p.connectedCh <- int(negResp["version"].(float64))

runner/runner.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ func (p *GazelleRunner) Watch(watchAddress string, cmd GazelleCommand, mode Gaze
238238
))
239239
defer t.End()
240240

241-
// Subscribe to further changes
242-
for cs, err := range watch.AwaitCycle() {
241+
// Subscribe to further changes to all sources
242+
for cs, err := range watch.Subscribe(ibp.WatchOptions{Type: ibp.WatchTypeSources}) {
243243
if err != nil {
244244
fmt.Printf("ERROR: watch cycle error: %v\n", err)
245245
return err

0 commit comments

Comments
 (0)