diff --git a/runner/pkg/ibp/client.go b/runner/pkg/ibp/client.go index b1587deb..80874c78 100644 --- a/runner/pkg/ibp/client.go +++ b/runner/pkg/ibp/client.go @@ -9,20 +9,23 @@ import ( "github.com/aspect-build/aspect-gazelle/runner/pkg/socket" ) -type IncrementalClient interface { +type WatchClient interface { Connect() error Disconnect() error - AwaitCycle() iter.Seq2[*CycleSourcesMessage, error] + Subscribe(options WatchOptions) iter.Seq2[*CycleSourcesMessage, error] } type incClient struct { socketPath string socket socket.Socket[interface{}, map[string]interface{}] + + // The negotiated protocol version + version ProtocolVersion } -var _ IncrementalClient = (*incClient)(nil) +var _ WatchClient = (*incClient)(nil) -func NewClient(host string) IncrementalClient { +func NewClient(host string) WatchClient { return &incClient{ socketPath: host, } @@ -56,22 +59,39 @@ func (c *incClient) negotiate() error { if negReq["versions"] == nil { return fmt.Errorf("Received NEGOTIATE without versions: %v", negReq) } - if !slices.Contains(negReq["versions"].([]interface{}), (interface{})(float64(PROTOCOL_VERSION))) { - return fmt.Errorf("Received NEGOTIATE with unsupported versions %v, expected %d", negReq["versions"], PROTOCOL_VERSION) + rawVersions, isArray := negReq["versions"].([]interface{}) + if !isArray { + return fmt.Errorf("Invalid versions, expected []int, received type: %T", negReq["versions"]) + } + + negotiatedVersion, err := negotiateVersion(rawVersions) + if err != nil { + return err } err = c.socket.Send(negotiateResponseMessage{ Message: Message{ Kind: "NEGOTIATE_RESPONSE", }, - Version: PROTOCOL_VERSION, + Version: negotiatedVersion, }) if err != nil { return fmt.Errorf("failed to negotiate protocol version: %w", err) } + + c.version = negotiatedVersion return nil } +func negotiateVersion(acceptedVersions []interface{}) (ProtocolVersion, error) { + for _, v := range slices.Backward(acceptedVersions) { + if slices.Contains(abazelSupportedProtocolVersions, ProtocolVersion(v.(float64))) { + return ProtocolVersion(v.(float64)), nil + } + } + return -1, fmt.Errorf("unsupported versions %v, expected one of %v", acceptedVersions, abazelSupportedProtocolVersions) +} + func (c *incClient) Disconnect() error { if c.socket == nil { return fmt.Errorf("client not connected") @@ -85,8 +105,33 @@ func (c *incClient) Disconnect() error { return err } -func (c *incClient) AwaitCycle() iter.Seq2[*CycleSourcesMessage, error] { +func (c *incClient) Subscribe(options WatchOptions) iter.Seq2[*CycleSourcesMessage, error] { return func(yield func(*CycleSourcesMessage, error) bool) { + // Version 1+ require the initial SUBSCRIBE to start the subscription + if c.version >= VERSION_1 { + err := c.socket.Send(SubscribeMessage{ + Message: Message{ + Kind: "SUBSCRIBE", + }, + WatchType: options.Type, + }) + if err != nil { + yield(nil, err) + return + } + + msg, err := c.socket.Recv() + if err != nil { + yield(nil, err) + return + } + + if msg["kind"] != "SUBSCRIBE_RESPONSE" { + yield(nil, fmt.Errorf("expected SUBSCRIBE_RESPONSE, got %v", msg)) + return + } + } + for { msg, err := c.socket.Recv() if err != nil { diff --git a/runner/pkg/ibp/protocol.go b/runner/pkg/ibp/protocol.go index 4a5c30fc..839ee34c 100644 --- a/runner/pkg/ibp/protocol.go +++ b/runner/pkg/ibp/protocol.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path" + "slices" "sync/atomic" "github.com/aspect-build/aspect-gazelle/runner/pkg/socket" @@ -14,9 +15,15 @@ import ( type ProtocolVersion int const ( - PROTOCOL_VERSION ProtocolVersion = 0 + LEGACY_VERSION_0 ProtocolVersion = 0 + VERSION_1 = 1 + LATEST_VERSION = VERSION_1 ) +func (v ProtocolVersion) HasExplicitSubscribe() bool { + return v >= 1 +} + const PROTOCOL_SOCKET_ENV = "ABAZEL_WATCH_SOCKET_FILE" type IncrementalBazel interface { @@ -57,6 +64,22 @@ type capMessage struct { Caps map[string]bool `json:"caps"` } +type WatchType string + +const ( + WatchTypeRunfiles WatchType = "runfiles" + WatchTypeSources WatchType = "sources" +) + +type WatchOptions struct { + Type WatchType +} + +type SubscribeMessage struct { + Message + WatchType WatchType `json:"watch_type"` +} + type exitMessage struct { Message Description string `json:"description"` @@ -81,7 +104,10 @@ type CycleSourcesMessage struct { } // The versions supported by this host implementation of the protocol. -var abazelSupportedProtocolVersions = []ProtocolVersion{PROTOCOL_VERSION} +var abazelSupportedProtocolVersions = []ProtocolVersion{ + LEGACY_VERSION_0, + VERSION_1, +} type aspectBazelSocket = socket.Server[interface{}, map[string]any] @@ -176,8 +202,8 @@ func (p *aspectBazelProtocol) acceptNegotiation() error { if negResp["version"] == nil { return fmt.Errorf("Received NEGOTIATE_RESPONSE without version: %v", negResp) } - if ProtocolVersion(negResp["version"].(float64)) != PROTOCOL_VERSION { - return fmt.Errorf("Received NEGOTIATE_RESPONSE with unsupported version %v, expected %v", negResp["version"], PROTOCOL_VERSION) + if !slices.Contains(abazelSupportedProtocolVersions, ProtocolVersion(negResp["version"].(float64))) { + return fmt.Errorf("Received NEGOTIATE_RESPONSE with unsupported version %v, expected one of %v", negResp["version"], abazelSupportedProtocolVersions) } p.connectedCh <- ProtocolVersion(negResp["version"].(float64)) diff --git a/runner/runner.go b/runner/runner.go index 30fd82d9..82578f62 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -238,8 +238,8 @@ func (p *GazelleRunner) Watch(watchAddress string, cmd GazelleCommand, mode Gaze )) defer t.End() - // Subscribe to further changes - for cs, err := range watch.AwaitCycle() { + // Subscribe to further changes to all sources + for cs, err := range watch.Subscribe(ibp.WatchOptions{Type: ibp.WatchTypeSources}) { if err != nil { fmt.Printf("ERROR: watch cycle error: %v\n", err) return err