Skip to content

Commit c60fbce

Browse files
committed
feat: add explicit SUBSCRIBE message to initiate subscriptions
1 parent 8e32503 commit c60fbce

File tree

4 files changed

+105
-20
lines changed

4 files changed

+105
-20
lines changed

runner/pkg/ibp/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ go_library(
1010
visibility = ["//visibility:public"],
1111
deps = [
1212
"//pkg/socket",
13+
"@aspect_gazelle//common/logger",
1314
"@com_github_fatih_color//:color",
1415
],
1516
)

runner/pkg/ibp/client.go

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,27 @@ import (
55
"iter"
66
"slices"
77

8+
BazelLog "github.com/aspect-build/aspect-gazelle/common/logger"
89
"github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
910
)
1011

11-
type IncrementalClient interface {
12+
type WatchClient interface {
1213
Connect() error
1314
Disconnect() error
14-
AwaitCycle() iter.Seq[CycleSourcesMessage]
15+
Subscribe(options WatchOptions) iter.Seq2[*CycleSourcesMessage, error]
1516
}
1617

1718
type incClient struct {
1819
socketPath string
1920
socket socket.Socket[interface{}, map[string]interface{}]
21+
22+
// The negotiated protocol version
23+
version ProtocolVersion
2024
}
2125

22-
var _ IncrementalClient = (*incClient)(nil)
26+
var _ WatchClient = (*incClient)(nil)
2327

24-
func NewClient(host string) IncrementalClient {
28+
func NewClient(host string) WatchClient {
2529
return &incClient{
2630
socketPath: host,
2731
}
@@ -55,22 +59,39 @@ func (c *incClient) negotiate() error {
5559
if negReq["versions"] == nil {
5660
return fmt.Errorf("Received NEGOTIATE without versions: %v", negReq)
5761
}
58-
if !slices.Contains(negReq["versions"].([]interface{}), (interface{})(float64(PROTOCOL_VERSION))) {
59-
return fmt.Errorf("Received NEGOTIATE with unsupported versions %v, expected %d", negReq["versions"], PROTOCOL_VERSION)
62+
rawVersions, isArray := negReq["versions"].([]interface{})
63+
if !isArray {
64+
return fmt.Errorf("Invalid versions, expected []int, received type: %T", negReq["versions"])
65+
}
66+
67+
negotiatedVersion, err := negoiateVersion(rawVersions)
68+
if err != nil {
69+
return err
6070
}
6171

6272
err = c.socket.Send(negotiateResponseMessage{
6373
Message: Message{
6474
Kind: "NEGOTIATE_RESPONSE",
6575
},
66-
Version: PROTOCOL_VERSION,
76+
Version: negotiatedVersion,
6777
})
6878
if err != nil {
6979
return fmt.Errorf("failed to negotiate protocol version: %w", err)
7080
}
81+
82+
c.version = negotiatedVersion
7183
return nil
7284
}
7385

86+
func negoiateVersion(acceptedVersions []interface{}) (ProtocolVersion, error) {
87+
for _, v := range slices.Backward(acceptedVersions) {
88+
if slices.Contains(abazelSupportedProtocolVersions, ProtocolVersion(v.(float64))) {
89+
return ProtocolVersion(v.(float64)), nil
90+
}
91+
}
92+
return -1, fmt.Errorf("unsupported versions %v, expected one of %v", acceptedVersions, abazelSupportedProtocolVersions)
93+
}
94+
7495
func (c *incClient) Disconnect() error {
7596
if c.socket == nil {
7697
return fmt.Errorf("client not connected")
@@ -84,12 +105,37 @@ func (c *incClient) Disconnect() error {
84105
return err
85106
}
86107

87-
func (c *incClient) AwaitCycle() iter.Seq[CycleSourcesMessage] {
88-
return func(yield func(CycleSourcesMessage) bool) {
108+
func (c *incClient) Subscribe(options WatchOptions) iter.Seq2[*CycleSourcesMessage, error] {
109+
return func(yield func(*CycleSourcesMessage, error) bool) {
110+
if c.version >= VERSION_1 {
111+
err := c.socket.Send(SubscribeMessage{
112+
Message: Message{
113+
Kind: "SUBSCRIBE",
114+
},
115+
WatchType: options.Type,
116+
})
117+
if err != nil {
118+
yield(nil, err)
119+
return
120+
}
121+
122+
msg, err := c.socket.Recv()
123+
if err != nil {
124+
yield(nil, err)
125+
return
126+
}
127+
128+
if msg["kind"] != "SUBSCRIBE_RESPONSE" {
129+
yield(nil, fmt.Errorf("expected SUBSCRIBE_RESPONSE, got %v", msg))
130+
return
131+
}
132+
}
133+
89134
for {
90135
msg, err := c.socket.Recv()
91136
if err != nil {
92137
fmt.Printf("Error receiving message: %v\n", err)
138+
yield(nil, err)
93139
return
94140
}
95141

@@ -100,21 +146,28 @@ func (c *incClient) AwaitCycle() iter.Seq[CycleSourcesMessage] {
100146
continue
101147
}
102148

103-
c.socket.Send(CycleMessage{
149+
err = c.socket.Send(CycleMessage{
104150
Message: Message{
105151
Kind: "CYCLE_STARTED",
106152
},
107153
CycleId: cycleEvent.CycleId,
108154
})
155+
if err != nil {
156+
yield(nil, err)
157+
return
158+
}
109159

110-
r := yield(cycleEvent)
160+
r := yield(&cycleEvent, nil)
111161

112-
c.socket.Send(CycleMessage{
162+
err = c.socket.Send(CycleMessage{
113163
Message: Message{
114164
Kind: "CYCLE_COMPLETED",
115165
},
116166
CycleId: cycleEvent.CycleId,
117167
})
168+
if err != nil {
169+
BazelLog.Warnf("Failed to send CYCLE_COMPLETED for cycle_id=%d: %v\n", cycleEvent.CycleId, err)
170+
}
118171

119172
if !r {
120173
return

runner/pkg/ibp/protocol.go

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,25 @@ import (
55
"fmt"
66
"os"
77
"path"
8+
"slices"
89
"sync/atomic"
910

1011
"github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
1112
"github.com/fatih/color"
1213
)
1314

14-
const PROTOCOL_VERSION = 0
15+
type ProtocolVersion int
16+
17+
const (
18+
LEGACY_VERSION_0 ProtocolVersion = 0
19+
VERSION_1 = 1
20+
LATEST_VERSION = VERSION_1
21+
)
22+
23+
func (v ProtocolVersion) HasExplicitSubscribe() bool {
24+
return v >= 1
25+
}
26+
1527
const PROTOCOL_SOCKET_ENV = "ABAZEL_WATCH_SOCKET_FILE"
1628

1729
type IncrementalBazel interface {
@@ -40,18 +52,34 @@ type Message struct {
4052

4153
type negotiateMessage struct {
4254
Message
43-
Versions []int `json:"versions"`
55+
Versions []ProtocolVersion `json:"versions"`
4456
}
4557
type negotiateResponseMessage struct {
4658
Message
47-
Version int `json:"version"`
59+
Version ProtocolVersion `json:"version"`
4860
}
4961

5062
type capMessage struct {
5163
Message
5264
Caps map[string]bool `json:"caps"`
5365
}
5466

67+
type WatchType string
68+
69+
const (
70+
WatchTypeRunfiles WatchType = "runfiles"
71+
WatchTypeSources WatchType = "sources"
72+
)
73+
74+
type WatchOptions struct {
75+
Type WatchType
76+
}
77+
78+
type SubscribeMessage struct {
79+
Message
80+
WatchType WatchType `json:"watch_type"`
81+
}
82+
5583
type exitMessage struct {
5684
Message
5785
Description string `json:"description"`
@@ -77,7 +105,10 @@ type CycleSourcesMessage struct {
77105
}
78106

79107
// The versions supported by this host implementation of the protocol.
80-
var abazelSupportedProtocolVersions = []int{PROTOCOL_VERSION}
108+
var abazelSupportedProtocolVersions = []ProtocolVersion{
109+
LEGACY_VERSION_0,
110+
VERSION_1,
111+
}
81112

82113
type aspectBazelSocket = socket.Server[interface{}, map[string]any]
83114

@@ -172,8 +203,8 @@ func (p *aspectBazelProtocol) acceptNegotiation() error {
172203
if negResp["version"] == nil {
173204
return fmt.Errorf("Received NEGOTIATE_RESPONSE without version: %v", negResp)
174205
}
175-
if negResp["version"].(float64) != PROTOCOL_VERSION {
176-
return fmt.Errorf("Received NEGOTIATE_RESPONSE with unsupported version %v, expected %d", negResp["version"], PROTOCOL_VERSION)
206+
if !slices.Contains(abazelSupportedProtocolVersions, ProtocolVersion(negResp["version"].(float64))) {
207+
return fmt.Errorf("Received NEGOTIATE_RESPONSE with unsupported version %v, expected one of %v", negResp["version"], abazelSupportedProtocolVersions)
177208
}
178209

179210
p.connectedCh <- int(negResp["version"].(float64))

runner/runner.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ func (p *GazelleRunner) Watch(watchAddress string, cmd GazelleCommand, mode Gaze
238238
))
239239
defer t.End()
240240

241-
// Subscribe to further changes
242-
for cs := range watch.AwaitCycle() {
241+
// Subscribe to further changes to all sources
242+
for cs := range watch.Subscribe(ibp.WatchOptions{Type: ibp.WatchTypeSources}) {
243243
_, t := p.tracer.Start(ctx, "GazelleRunner.Watch.Trigger")
244244

245245
// The directories that have changed which gazelle should update.

0 commit comments

Comments
 (0)