Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions runner/pkg/ibp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@ import (
"github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
)

type IncrementalClient interface {
type WatchClient interface {
Connect() error
Disconnect() error
AwaitCycle() iter.Seq2[*CycleSourcesMessage, error]
Subscribe(options WatchOptions) iter.Seq2[*CycleSourcesMessage, error]
}

type incClient struct {
socketPath string
socket socket.Socket[interface{}, map[string]interface{}]

// The negotiated protocol version
version ProtocolVersion
}

var _ IncrementalClient = (*incClient)(nil)
var _ WatchClient = (*incClient)(nil)

func NewClient(host string) IncrementalClient {
func NewClient(host string) WatchClient {
return &incClient{
socketPath: host,
}
Expand Down Expand Up @@ -56,22 +59,39 @@ func (c *incClient) negotiate() error {
if negReq["versions"] == nil {
return fmt.Errorf("Received NEGOTIATE without versions: %v", negReq)
}
if !slices.Contains(negReq["versions"].([]interface{}), (interface{})(float64(PROTOCOL_VERSION))) {
return fmt.Errorf("Received NEGOTIATE with unsupported versions %v, expected %d", negReq["versions"], PROTOCOL_VERSION)
rawVersions, isArray := negReq["versions"].([]interface{})
if !isArray {
return fmt.Errorf("Invalid versions, expected []int, received type: %T", negReq["versions"])
}

negotiatedVersion, err := negotiateVersion(rawVersions)
if err != nil {
return err
}

err = c.socket.Send(negotiateResponseMessage{
Message: Message{
Kind: "NEGOTIATE_RESPONSE",
},
Version: PROTOCOL_VERSION,
Version: negotiatedVersion,
})
if err != nil {
return fmt.Errorf("failed to negotiate protocol version: %w", err)
}

c.version = negotiatedVersion
return nil
}

func negotiateVersion(acceptedVersions []interface{}) (ProtocolVersion, error) {
for _, v := range slices.Backward(acceptedVersions) {
if slices.Contains(abazelSupportedProtocolVersions, ProtocolVersion(v.(float64))) {
return ProtocolVersion(v.(float64)), nil
}
}
return -1, fmt.Errorf("unsupported versions %v, expected one of %v", acceptedVersions, abazelSupportedProtocolVersions)
}

func (c *incClient) Disconnect() error {
if c.socket == nil {
return fmt.Errorf("client not connected")
Expand All @@ -85,8 +105,33 @@ func (c *incClient) Disconnect() error {
return err
}

func (c *incClient) AwaitCycle() iter.Seq2[*CycleSourcesMessage, error] {
func (c *incClient) Subscribe(options WatchOptions) iter.Seq2[*CycleSourcesMessage, error] {
return func(yield func(*CycleSourcesMessage, error) bool) {
// Version 1+ require the initial SUBSCRIBE to start the subscription
if c.version >= VERSION_1 {
err := c.socket.Send(SubscribeMessage{
Message: Message{
Kind: "SUBSCRIBE",
},
WatchType: options.Type,
})
if err != nil {
yield(nil, err)
return
}

msg, err := c.socket.Recv()
if err != nil {
yield(nil, err)
return
}

if msg["kind"] != "SUBSCRIBE_RESPONSE" {
yield(nil, fmt.Errorf("expected SUBSCRIBE_RESPONSE, got %v", msg))
return
}
}

for {
msg, err := c.socket.Recv()
if err != nil {
Expand Down
34 changes: 30 additions & 4 deletions runner/pkg/ibp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path"
"slices"
"sync/atomic"

"github.com/aspect-build/aspect-gazelle/runner/pkg/socket"
Expand All @@ -14,9 +15,15 @@ import (
type ProtocolVersion int

const (
PROTOCOL_VERSION ProtocolVersion = 0
LEGACY_VERSION_0 ProtocolVersion = 0
VERSION_1 = 1
LATEST_VERSION = VERSION_1
)

func (v ProtocolVersion) HasExplicitSubscribe() bool {
return v >= 1
}

const PROTOCOL_SOCKET_ENV = "ABAZEL_WATCH_SOCKET_FILE"

type IncrementalBazel interface {
Expand Down Expand Up @@ -57,6 +64,22 @@ type capMessage struct {
Caps map[string]bool `json:"caps"`
}

type WatchType string

const (
WatchTypeRunfiles WatchType = "runfiles"
WatchTypeSources WatchType = "sources"
)

type WatchOptions struct {
Type WatchType
}

type SubscribeMessage struct {
Message
WatchType WatchType `json:"watch_type"`
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Options:

  • single type
  • multiple bools (source + runfiles)
  • array of types

If a binary says it only wants to subscribe to sources should it be restarted if it's runfiles change? Or we should ignore runfiles? If watching both should should they be distinguished in the CYCLE message somewhere such as a sources[path]: {type: 'runfiles' | 'sources'} flag? :/

}

type exitMessage struct {
Message
Description string `json:"description"`
Expand All @@ -81,7 +104,10 @@ type CycleSourcesMessage struct {
}

// The versions supported by this host implementation of the protocol.
var abazelSupportedProtocolVersions = []ProtocolVersion{PROTOCOL_VERSION}
var abazelSupportedProtocolVersions = []ProtocolVersion{
LEGACY_VERSION_0,
VERSION_1,
}

type aspectBazelSocket = socket.Server[interface{}, map[string]any]

Expand Down Expand Up @@ -176,8 +202,8 @@ func (p *aspectBazelProtocol) acceptNegotiation() error {
if negResp["version"] == nil {
return fmt.Errorf("Received NEGOTIATE_RESPONSE without version: %v", negResp)
}
if ProtocolVersion(negResp["version"].(float64)) != PROTOCOL_VERSION {
return fmt.Errorf("Received NEGOTIATE_RESPONSE with unsupported version %v, expected %v", negResp["version"], PROTOCOL_VERSION)
if !slices.Contains(abazelSupportedProtocolVersions, ProtocolVersion(negResp["version"].(float64))) {
return fmt.Errorf("Received NEGOTIATE_RESPONSE with unsupported version %v, expected one of %v", negResp["version"], abazelSupportedProtocolVersions)
}

p.connectedCh <- ProtocolVersion(negResp["version"].(float64))
Expand Down
4 changes: 2 additions & 2 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ func (p *GazelleRunner) Watch(watchAddress string, cmd GazelleCommand, mode Gaze
))
defer t.End()

// Subscribe to further changes
for cs, err := range watch.AwaitCycle() {
// Subscribe to further changes to all sources
for cs, err := range watch.Subscribe(ibp.WatchOptions{Type: ibp.WatchTypeSources}) {
if err != nil {
fmt.Printf("ERROR: watch cycle error: %v\n", err)
return err
Expand Down
Loading