Skip to content

Commit 70be436

Browse files
authored
refactor: improve local-state-query callback interface (#812)
* switch custom errors to vars instead of types * merge Acquire/ReAcquire callbacks * remove Done callback * add ability to return result in Query callback * automatically send acquire result response * automatically send query result
1 parent c0455cb commit 70be436

File tree

4 files changed

+65
-46
lines changed

4 files changed

+65
-46
lines changed

protocol/localstatequery/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -908,9 +908,9 @@ func (c *Client) handleFailure(msg protocol.Message) error {
908908
msgFailure := msg.(*MsgFailure)
909909
switch msgFailure.Failure {
910910
case AcquireFailurePointTooOld:
911-
c.acquireResultChan <- AcquireFailurePointTooOldError{}
911+
c.acquireResultChan <- ErrAcquireFailurePointTooOld
912912
case AcquireFailurePointNotOnChain:
913-
c.acquireResultChan <- AcquireFailurePointNotOnChainError{}
913+
c.acquireResultChan <- ErrAcquireFailurePointNotOnChain
914914
default:
915915
return fmt.Errorf("unknown failure type: %d", msgFailure.Failure)
916916
}

protocol/localstatequery/error.go

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,10 @@
1414

1515
package localstatequery
1616

17-
// AcquireFailurePointTooOldError indicates a failure to acquire a point due to it being too old
18-
type AcquireFailurePointTooOldError struct {
19-
}
17+
import "errors"
2018

21-
func (e AcquireFailurePointTooOldError) Error() string {
22-
return "acquire failure: point too old"
23-
}
19+
// ErrAcquireFailurePointTooOld indicates a failure to acquire a point due to it being too old
20+
var ErrAcquireFailurePointTooOld = errors.New("acquire failure: point too old")
2421

25-
// AcquireFailurePointNotOnChainError indicates a failure to acquire a point due to it not being present on the chain
26-
type AcquireFailurePointNotOnChainError struct {
27-
}
28-
29-
func (e AcquireFailurePointNotOnChainError) Error() string {
30-
return "acquire failure: point not on chain"
31-
}
22+
// ErrAcquireFailurePointNotOnChain indicates a failure to acquire a point due to it not being present on the chain
23+
var ErrAcquireFailurePointNotOnChain = errors.New("acquire failure: point not on chain")

protocol/localstatequery/localstatequery.go

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ type Config struct {
123123
AcquireFunc AcquireFunc
124124
QueryFunc QueryFunc
125125
ReleaseFunc ReleaseFunc
126-
ReAcquireFunc ReAcquireFunc
127-
DoneFunc DoneFunc
128126
AcquireTimeout time.Duration
129127
QueryTimeout time.Duration
130128
}
@@ -156,11 +154,9 @@ type CallbackContext struct {
156154
}
157155

158156
// Callback function types
159-
type AcquireFunc func(CallbackContext, AcquireTarget) error
160-
type QueryFunc func(CallbackContext, any) error
157+
type AcquireFunc func(CallbackContext, AcquireTarget, bool) error
158+
type QueryFunc func(CallbackContext, any) (any, error)
161159
type ReleaseFunc func(CallbackContext) error
162-
type ReAcquireFunc func(CallbackContext, AcquireTarget) error
163-
type DoneFunc func(CallbackContext) error
164160

165161
// New returns a new LocalStateQuery object
166162
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *LocalStateQuery {
@@ -208,20 +204,6 @@ func WithReleaseFunc(releaseFunc ReleaseFunc) LocalStateQueryOptionFunc {
208204
}
209205
}
210206

211-
// WithReAcquireFunc specifies the ReAcquire callback function when acting as a server
212-
func WithReAcquireFunc(reAcquireFunc ReAcquireFunc) LocalStateQueryOptionFunc {
213-
return func(c *Config) {
214-
c.ReAcquireFunc = reAcquireFunc
215-
}
216-
}
217-
218-
// WithDoneFunc specifies the Done callback function when acting as a server
219-
func WithDoneFunc(doneFunc DoneFunc) LocalStateQueryOptionFunc {
220-
return func(c *Config) {
221-
c.DoneFunc = doneFunc
222-
}
223-
}
224-
225207
// WithAcquireTimeout specifies the timeout for the Acquire operation when acting as a client
226208
func WithAcquireTimeout(timeout time.Duration) LocalStateQueryOptionFunc {
227209
return func(c *Config) {

protocol/localstatequery/server.go

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
package localstatequery
1616

1717
import (
18+
"errors"
1819
"fmt"
1920

21+
"github.com/blinklabs-io/gouroboros/cbor"
2022
"github.com/blinklabs-io/gouroboros/protocol"
2123
)
2224

@@ -116,7 +118,27 @@ func (s *Server) handleAcquire(msg protocol.Message) error {
116118
acquireTarget = AcquireImmutableTip{}
117119
}
118120
// Call the user callback function
119-
return s.config.AcquireFunc(s.callbackContext, acquireTarget)
121+
err := s.config.AcquireFunc(s.callbackContext, acquireTarget, false)
122+
if err != nil {
123+
if errors.Is(err, ErrAcquireFailurePointTooOld) {
124+
respMsg := NewMsgFailure(AcquireFailurePointTooOld)
125+
if err := s.SendMessage(respMsg); err != nil {
126+
return err
127+
}
128+
} else if errors.Is(err, ErrAcquireFailurePointNotOnChain) {
129+
respMsg := NewMsgFailure(AcquireFailurePointNotOnChain)
130+
if err := s.SendMessage(respMsg); err != nil {
131+
return err
132+
}
133+
} else {
134+
return err
135+
}
136+
}
137+
respMsg := NewMsgAcquired()
138+
if err := s.SendMessage(respMsg); err != nil {
139+
return err
140+
}
141+
return nil
120142
}
121143

122144
func (s *Server) handleQuery(msg protocol.Message) error {
@@ -134,7 +156,20 @@ func (s *Server) handleQuery(msg protocol.Message) error {
134156
}
135157
msgQuery := msg.(*MsgQuery)
136158
// Call the user callback function
137-
return s.config.QueryFunc(s.callbackContext, msgQuery.Query)
159+
result, err := s.config.QueryFunc(s.callbackContext, msgQuery.Query)
160+
if err != nil {
161+
return err
162+
}
163+
// Encode query result
164+
resultCbor, err := cbor.Encode(&result)
165+
if err != nil {
166+
return err
167+
}
168+
respMsg := NewMsgResult(resultCbor)
169+
if err := s.SendMessage(respMsg); err != nil {
170+
return err
171+
}
172+
return nil
138173
}
139174

140175
func (s *Server) handleRelease() error {
@@ -162,7 +197,7 @@ func (s *Server) handleReAcquire(msg protocol.Message) error {
162197
"role", "server",
163198
"connection_id", s.callbackContext.ConnectionId.String(),
164199
)
165-
if s.config.ReAcquireFunc == nil {
200+
if s.config.AcquireFunc == nil {
166201
return fmt.Errorf(
167202
"received local-state-query ReAcquire message but no callback function is defined",
168203
)
@@ -179,7 +214,23 @@ func (s *Server) handleReAcquire(msg protocol.Message) error {
179214
acquireTarget = AcquireImmutableTip{}
180215
}
181216
// Call the user callback function
182-
return s.config.ReAcquireFunc(s.callbackContext, acquireTarget)
217+
err := s.config.AcquireFunc(s.callbackContext, acquireTarget, true)
218+
if err != nil {
219+
if errors.Is(err, ErrAcquireFailurePointTooOld) {
220+
respMsg := NewMsgFailure(AcquireFailurePointTooOld)
221+
if err := s.SendMessage(respMsg); err != nil {
222+
return err
223+
}
224+
} else if errors.Is(err, ErrAcquireFailurePointNotOnChain) {
225+
respMsg := NewMsgFailure(AcquireFailurePointNotOnChain)
226+
if err := s.SendMessage(respMsg); err != nil {
227+
return err
228+
}
229+
} else {
230+
return err
231+
}
232+
}
233+
return nil
183234
}
184235

185236
func (s *Server) handleDone() error {
@@ -190,11 +241,5 @@ func (s *Server) handleDone() error {
190241
"role", "server",
191242
"connection_id", s.callbackContext.ConnectionId.String(),
192243
)
193-
if s.config.DoneFunc == nil {
194-
return fmt.Errorf(
195-
"received local-state-query Done message but no callback function is defined",
196-
)
197-
}
198-
// Call the user callback function
199-
return s.config.DoneFunc(s.callbackContext)
244+
return nil
200245
}

0 commit comments

Comments
 (0)