Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions protocol/localstatequery/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -908,9 +908,9 @@ func (c *Client) handleFailure(msg protocol.Message) error {
msgFailure := msg.(*MsgFailure)
switch msgFailure.Failure {
case AcquireFailurePointTooOld:
c.acquireResultChan <- AcquireFailurePointTooOldError{}
c.acquireResultChan <- ErrAcquireFailurePointTooOld
case AcquireFailurePointNotOnChain:
c.acquireResultChan <- AcquireFailurePointNotOnChainError{}
c.acquireResultChan <- ErrAcquireFailurePointNotOnChain
default:
return fmt.Errorf("unknown failure type: %d", msgFailure.Failure)
}
Expand Down
18 changes: 5 additions & 13 deletions protocol/localstatequery/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,10 @@

package localstatequery

// AcquireFailurePointTooOldError indicates a failure to acquire a point due to it being too old
type AcquireFailurePointTooOldError struct {
}
import "errors"

func (e AcquireFailurePointTooOldError) Error() string {
return "acquire failure: point too old"
}
// ErrAcquireFailurePointTooOld indicates a failure to acquire a point due to it being too old
var ErrAcquireFailurePointTooOld = errors.New("acquire failure: point too old")

// AcquireFailurePointNotOnChainError indicates a failure to acquire a point due to it not being present on the chain
type AcquireFailurePointNotOnChainError struct {
}

func (e AcquireFailurePointNotOnChainError) Error() string {
return "acquire failure: point not on chain"
}
// ErrAcquireFailurePointNotOnChain indicates a failure to acquire a point due to it not being present on the chain
var ErrAcquireFailurePointNotOnChain = errors.New("acquire failure: point not on chain")
22 changes: 2 additions & 20 deletions protocol/localstatequery/localstatequery.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ type Config struct {
AcquireFunc AcquireFunc
QueryFunc QueryFunc
ReleaseFunc ReleaseFunc
ReAcquireFunc ReAcquireFunc
DoneFunc DoneFunc
AcquireTimeout time.Duration
QueryTimeout time.Duration
}
Expand Down Expand Up @@ -156,11 +154,9 @@ type CallbackContext struct {
}

// Callback function types
type AcquireFunc func(CallbackContext, AcquireTarget) error
type QueryFunc func(CallbackContext, any) error
type AcquireFunc func(CallbackContext, AcquireTarget, bool) error
type QueryFunc func(CallbackContext, any) (any, error)
type ReleaseFunc func(CallbackContext) error
type ReAcquireFunc func(CallbackContext, AcquireTarget) error
type DoneFunc func(CallbackContext) error

// New returns a new LocalStateQuery object
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *LocalStateQuery {
Expand Down Expand Up @@ -208,20 +204,6 @@ func WithReleaseFunc(releaseFunc ReleaseFunc) LocalStateQueryOptionFunc {
}
}

// WithReAcquireFunc specifies the ReAcquire callback function when acting as a server
func WithReAcquireFunc(reAcquireFunc ReAcquireFunc) LocalStateQueryOptionFunc {
return func(c *Config) {
c.ReAcquireFunc = reAcquireFunc
}
}

// WithDoneFunc specifies the Done callback function when acting as a server
func WithDoneFunc(doneFunc DoneFunc) LocalStateQueryOptionFunc {
return func(c *Config) {
c.DoneFunc = doneFunc
}
}

// WithAcquireTimeout specifies the timeout for the Acquire operation when acting as a client
func WithAcquireTimeout(timeout time.Duration) LocalStateQueryOptionFunc {
return func(c *Config) {
Expand Down
67 changes: 56 additions & 11 deletions protocol/localstatequery/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
package localstatequery

import (
"errors"
"fmt"

"github.com/blinklabs-io/gouroboros/cbor"
"github.com/blinklabs-io/gouroboros/protocol"
)

Expand Down Expand Up @@ -116,7 +118,27 @@ func (s *Server) handleAcquire(msg protocol.Message) error {
acquireTarget = AcquireImmutableTip{}
}
// Call the user callback function
return s.config.AcquireFunc(s.callbackContext, acquireTarget)
err := s.config.AcquireFunc(s.callbackContext, acquireTarget, false)
if err != nil {
if errors.Is(err, ErrAcquireFailurePointTooOld) {
respMsg := NewMsgFailure(AcquireFailurePointTooOld)
if err := s.SendMessage(respMsg); err != nil {
return err
}
} else if errors.Is(err, ErrAcquireFailurePointNotOnChain) {
respMsg := NewMsgFailure(AcquireFailurePointNotOnChain)
if err := s.SendMessage(respMsg); err != nil {
return err
}
} else {
return err
}
}
respMsg := NewMsgAcquired()
if err := s.SendMessage(respMsg); err != nil {
return err
}
return nil
}

func (s *Server) handleQuery(msg protocol.Message) error {
Expand All @@ -134,7 +156,20 @@ func (s *Server) handleQuery(msg protocol.Message) error {
}
msgQuery := msg.(*MsgQuery)
// Call the user callback function
return s.config.QueryFunc(s.callbackContext, msgQuery.Query)
result, err := s.config.QueryFunc(s.callbackContext, msgQuery.Query)
if err != nil {
return err
}
// Encode query result
resultCbor, err := cbor.Encode(&result)
if err != nil {
return err
}
respMsg := NewMsgResult(resultCbor)
if err := s.SendMessage(respMsg); err != nil {
return err
}
return nil
}

func (s *Server) handleRelease() error {
Expand Down Expand Up @@ -162,7 +197,7 @@ func (s *Server) handleReAcquire(msg protocol.Message) error {
"role", "server",
"connection_id", s.callbackContext.ConnectionId.String(),
)
if s.config.ReAcquireFunc == nil {
if s.config.AcquireFunc == nil {
return fmt.Errorf(
"received local-state-query ReAcquire message but no callback function is defined",
)
Expand All @@ -179,7 +214,23 @@ func (s *Server) handleReAcquire(msg protocol.Message) error {
acquireTarget = AcquireImmutableTip{}
}
// Call the user callback function
return s.config.ReAcquireFunc(s.callbackContext, acquireTarget)
err := s.config.AcquireFunc(s.callbackContext, acquireTarget, true)
if err != nil {
if errors.Is(err, ErrAcquireFailurePointTooOld) {
respMsg := NewMsgFailure(AcquireFailurePointTooOld)
if err := s.SendMessage(respMsg); err != nil {
return err
}
} else if errors.Is(err, ErrAcquireFailurePointNotOnChain) {
respMsg := NewMsgFailure(AcquireFailurePointNotOnChain)
if err := s.SendMessage(respMsg); err != nil {
return err
}
} else {
return err
}
}
return nil
}

func (s *Server) handleDone() error {
Expand All @@ -190,11 +241,5 @@ func (s *Server) handleDone() error {
"role", "server",
"connection_id", s.callbackContext.ConnectionId.String(),
)
if s.config.DoneFunc == nil {
return fmt.Errorf(
"received local-state-query Done message but no callback function is defined",
)
}
// Call the user callback function
return s.config.DoneFunc(s.callbackContext)
return nil
}
Loading