Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 60 additions & 20 deletions protocol/localstatequery/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,28 +114,58 @@ func (c *Client) Start() {

// Acquire starts the acquire process for the specified chain point
func (c *Client) Acquire(point *common.Point) error {
var msg string
if point != nil {
msg = fmt.Sprintf(
"calling Acquire(point: {Slot: %d, Hash: %x})",
point.Slot,
point.Hash,
// Use volatile tip if no point provided
if point == nil {
return c.AcquireVolatileTip()
}
c.Protocol.Logger().
Debug(
fmt.Sprintf(
"calling Acquire(point: {Slot: %d, Hash: %x})",
point.Slot,
point.Hash,
),
"component", "network",
"protocol", ProtocolName,
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
} else {
msg = "calling Acquire(point: latest)"
c.busyMutex.Lock()
defer c.busyMutex.Unlock()
acquireTarget := AcquireSpecificPoint{
Point: *point,
}
return c.acquire(acquireTarget)
}

func (c *Client) AcquireVolatileTip() error {
c.Protocol.Logger().
Debug(
"calling AcquireVolatileTip",
"component", "network",
"protocol", ProtocolName,
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
c.busyMutex.Lock()
defer c.busyMutex.Unlock()
acquireTarget := AcquireVolatileTip{}
return c.acquire(acquireTarget)
}

func (c *Client) AcquireImmutableTip() error {
c.Protocol.Logger().
Debug(
msg,
"calling AcquireImmutableTip",
"component", "network",
"protocol", ProtocolName,
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
c.busyMutex.Lock()
defer c.busyMutex.Unlock()
return c.acquire(point)
acquireTarget := AcquireImmutableTip{}
return c.acquire(acquireTarget)
}

// Release releases the previously acquired chain point
Expand Down Expand Up @@ -906,19 +936,29 @@ func (c *Client) handleResult(msg protocol.Message) error {
return nil
}

func (c *Client) acquire(point *common.Point) error {
func (c *Client) acquire(acquireTarget AcquireTarget) error {
var msg protocol.Message
if c.acquired {
if point != nil {
msg = NewMsgReAcquire(*point)
} else {
msg = NewMsgReAcquireNoPoint()
switch t := acquireTarget.(type) {
case AcquireSpecificPoint:
msg = NewMsgReAcquire(t.Point)
case AcquireVolatileTip:
msg = NewMsgReAcquireVolatileTip()
case AcquireImmutableTip:
msg = NewMsgReAcquireImmutableTip()
default:
return fmt.Errorf("invalid acquire point provided")
}
} else {
if point != nil {
msg = NewMsgAcquire(*point)
} else {
msg = NewMsgAcquireNoPoint()
switch t := acquireTarget.(type) {
case AcquireSpecificPoint:
msg = NewMsgAcquire(t.Point)
case AcquireVolatileTip:
msg = NewMsgAcquireVolatileTip()
case AcquireImmutableTip:
msg = NewMsgAcquireImmutableTip()
default:
return fmt.Errorf("invalid acquire point provided")
}
}
if err := c.SendMessage(msg); err != nil {
Expand All @@ -944,7 +984,7 @@ func (c *Client) release() error {
func (c *Client) runQuery(query interface{}, result interface{}) error {
msg := NewMsgQuery(query)
if !c.acquired {
if err := c.acquire(nil); err != nil {
if err := c.acquire(AcquireVolatileTip{}); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion protocol/localstatequery/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var conversationHandshakeAcquire = []ouroboros_mock.ConversationEntry{
ouroboros_mock.ConversationEntryHandshakeNtCResponse,
ouroboros_mock.ConversationEntryInput{
ProtocolId: localstatequery.ProtocolId,
MessageType: localstatequery.MessageTypeAcquireNoPoint,
MessageType: localstatequery.MessageTypeAcquireVolatileTip,
},
ouroboros_mock.ConversationEntryOutput{
ProtocolId: localstatequery.ProtocolId,
Expand Down
35 changes: 31 additions & 4 deletions protocol/localstatequery/localstatequery.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ var StateMap = protocol.StateMap{
NewState: stateAcquiring,
},
{
MsgType: MessageTypeAcquireNoPoint,
MsgType: MessageTypeAcquireVolatileTip,
NewState: stateAcquiring,
},
{
MsgType: MessageTypeAcquireImmutableTip,
NewState: stateAcquiring,
},
{
Expand Down Expand Up @@ -81,7 +85,11 @@ var StateMap = protocol.StateMap{
NewState: stateAcquiring,
},
{
MsgType: MessageTypeReacquireNoPoint,
MsgType: MessageTypeReacquireVolatileTip,
NewState: stateAcquiring,
},
{
MsgType: MessageTypeReacquireImmutableTip,
NewState: stateAcquiring,
},
{
Expand Down Expand Up @@ -121,6 +129,25 @@ type Config struct {
QueryTimeout time.Duration
}

// Acquire target types
type AcquireTarget interface {
isAcquireTarget()
}

type AcquireSpecificPoint struct {
Point common.Point
}

func (AcquireSpecificPoint) isAcquireTarget() {}

type AcquireVolatileTip struct{}

func (AcquireVolatileTip) isAcquireTarget() {}

type AcquireImmutableTip struct{}

func (AcquireImmutableTip) isAcquireTarget() {}

// Callback context
type CallbackContext struct {
ConnectionId connection.ConnectionId
Expand All @@ -129,10 +156,10 @@ type CallbackContext struct {
}

// Callback function types
type AcquireFunc func(CallbackContext, *common.Point) error
type AcquireFunc func(CallbackContext, AcquireTarget) error
type QueryFunc func(CallbackContext, any) error
type ReleaseFunc func(CallbackContext) error
type ReAcquireFunc func(CallbackContext, *common.Point) error
type ReAcquireFunc func(CallbackContext, AcquireTarget) error
type DoneFunc func(CallbackContext) error

// New returns a new LocalStateQuery object
Expand Down
76 changes: 54 additions & 22 deletions protocol/localstatequery/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ import (

// Message types
const (
MessageTypeAcquire = 0
MessageTypeAcquired = 1
MessageTypeFailure = 2
MessageTypeQuery = 3
MessageTypeResult = 4
MessageTypeRelease = 5
MessageTypeReacquire = 6
MessageTypeDone = 7
MessageTypeAcquireNoPoint = 8
MessageTypeReacquireNoPoint = 9
MessageTypeAcquire = 0
MessageTypeAcquired = 1
MessageTypeFailure = 2
MessageTypeQuery = 3
MessageTypeResult = 4
MessageTypeRelease = 5
MessageTypeReacquire = 6
MessageTypeDone = 7
MessageTypeAcquireVolatileTip = 8
MessageTypeReacquireVolatileTip = 9
MessageTypeAcquireImmutableTip = 10
MessageTypeReacquireImmutableTip = 11
)

// Acquire failure reasons
Expand All @@ -60,10 +62,14 @@ func NewMsgFromCbor(msgType uint, data []byte) (protocol.Message, error) {
ret = &MsgRelease{}
case MessageTypeReacquire:
ret = &MsgReAcquire{}
case MessageTypeAcquireNoPoint:
ret = &MsgAcquireNoPoint{}
case MessageTypeReacquireNoPoint:
ret = &MsgReAcquireNoPoint{}
case MessageTypeAcquireVolatileTip:
ret = &MsgAcquireVolatileTip{}
case MessageTypeReacquireVolatileTip:
ret = &MsgReAcquireVolatileTip{}
case MessageTypeAcquireImmutableTip:
ret = &MsgAcquireImmutableTip{}
case MessageTypeReacquireImmutableTip:
ret = &MsgReAcquireImmutableTip{}
case MessageTypeDone:
ret = &MsgDone{}
}
Expand Down Expand Up @@ -92,14 +98,27 @@ func NewMsgAcquire(point common.Point) *MsgAcquire {
return m
}

type MsgAcquireNoPoint struct {
type MsgAcquireVolatileTip struct {
protocol.MessageBase
}

func NewMsgAcquireNoPoint() *MsgAcquireNoPoint {
m := &MsgAcquireNoPoint{
func NewMsgAcquireVolatileTip() *MsgAcquireVolatileTip {
m := &MsgAcquireVolatileTip{
MessageBase: protocol.MessageBase{
MessageType: MessageTypeAcquireNoPoint,
MessageType: MessageTypeAcquireVolatileTip,
},
}
return m
}

type MsgAcquireImmutableTip struct {
protocol.MessageBase
}

func NewMsgAcquireImmutableTip() *MsgAcquireImmutableTip {
m := &MsgAcquireImmutableTip{
MessageBase: protocol.MessageBase{
MessageType: MessageTypeAcquireImmutableTip,
},
}
return m
Expand Down Expand Up @@ -191,14 +210,27 @@ func NewMsgReAcquire(point common.Point) *MsgReAcquire {
return m
}

type MsgReAcquireNoPoint struct {
type MsgReAcquireVolatileTip struct {
protocol.MessageBase
}

func NewMsgReAcquireVolatileTip() *MsgReAcquireVolatileTip {
m := &MsgReAcquireVolatileTip{
MessageBase: protocol.MessageBase{
MessageType: MessageTypeReacquireVolatileTip,
},
}
return m
}

type MsgReAcquireImmutableTip struct {
protocol.MessageBase
}

func NewMsgReAcquireNoPoint() *MsgReAcquireNoPoint {
m := &MsgReAcquireNoPoint{
func NewMsgReAcquireImmutableTip() *MsgReAcquireImmutableTip {
m := &MsgReAcquireImmutableTip{
MessageBase: protocol.MessageBase{
MessageType: MessageTypeReacquireNoPoint,
MessageType: MessageTypeReacquireImmutableTip,
},
}
return m
Expand Down
8 changes: 4 additions & 4 deletions protocol/localstatequery/messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ var tests = []testDefinition{
},
{
CborHex: "8108",
Message: NewMsgAcquireNoPoint(),
MessageType: MessageTypeAcquireNoPoint,
Message: NewMsgAcquireVolatileTip(),
MessageType: MessageTypeAcquireVolatileTip,
},
{
CborHex: "8109",
Message: NewMsgReAcquireNoPoint(),
MessageType: MessageTypeReacquireNoPoint,
Message: NewMsgReAcquireVolatileTip(),
MessageType: MessageTypeReacquireVolatileTip,
},
}

Expand Down
36 changes: 22 additions & 14 deletions protocol/localstatequery/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ func (s *Server) messageHandler(msg protocol.Message) error {
err = s.handleRelease()
case MessageTypeReacquire:
err = s.handleReAcquire(msg)
case MessageTypeAcquireNoPoint:
case MessageTypeAcquireVolatileTip:
err = s.handleAcquire(msg)
case MessageTypeReacquireNoPoint:
case MessageTypeReacquireVolatileTip:
err = s.handleReAcquire(msg)
case MessageTypeDone:
err = s.handleDone()
Expand All @@ -104,15 +104,19 @@ func (s *Server) handleAcquire(msg protocol.Message) error {
"received local-state-query Acquire message but no callback function is defined",
)
}
var acquireTarget AcquireTarget
switch msgAcquire := msg.(type) {
case *MsgAcquire:
// Call the user callback function
return s.config.AcquireFunc(s.callbackContext, &msgAcquire.Point)
case *MsgAcquireNoPoint:
// Call the user callback function
return s.config.AcquireFunc(s.callbackContext, nil)
acquireTarget = AcquireSpecificPoint{
Point: msgAcquire.Point,
}
case *MsgAcquireVolatileTip:
acquireTarget = AcquireVolatileTip{}
case *MsgAcquireImmutableTip:
acquireTarget = AcquireImmutableTip{}
}
return nil
// Call the user callback function
return s.config.AcquireFunc(s.callbackContext, acquireTarget)
}

func (s *Server) handleQuery(msg protocol.Message) error {
Expand Down Expand Up @@ -163,15 +167,19 @@ func (s *Server) handleReAcquire(msg protocol.Message) error {
"received local-state-query ReAcquire message but no callback function is defined",
)
}
var acquireTarget AcquireTarget
switch msgReAcquire := msg.(type) {
case *MsgReAcquire:
// Call the user callback function
return s.config.ReAcquireFunc(s.callbackContext, &msgReAcquire.Point)
case *MsgReAcquireNoPoint:
// Call the user callback function
return s.config.ReAcquireFunc(s.callbackContext, nil)
acquireTarget = AcquireSpecificPoint{
Point: msgReAcquire.Point,
}
case *MsgReAcquireVolatileTip:
acquireTarget = AcquireVolatileTip{}
case *MsgReAcquireImmutableTip:
acquireTarget = AcquireImmutableTip{}
}
return nil
// Call the user callback function
return s.config.ReAcquireFunc(s.callbackContext, acquireTarget)
}

func (s *Server) handleDone() error {
Expand Down
Loading