15
15
package localstatequery
16
16
17
17
import (
18
+ "errors"
18
19
"fmt"
19
20
21
+ "github.com/blinklabs-io/gouroboros/cbor"
20
22
"github.com/blinklabs-io/gouroboros/protocol"
21
23
)
22
24
@@ -116,7 +118,27 @@ func (s *Server) handleAcquire(msg protocol.Message) error {
116
118
acquireTarget = AcquireImmutableTip {}
117
119
}
118
120
// Call the user callback function
119
- return s .config .AcquireFunc (s .callbackContext , acquireTarget )
121
+ err := s .config .AcquireFunc (s .callbackContext , acquireTarget , false )
122
+ if err != nil {
123
+ if errors .Is (err , ErrAcquireFailurePointTooOld ) {
124
+ respMsg := NewMsgFailure (AcquireFailurePointTooOld )
125
+ if err := s .SendMessage (respMsg ); err != nil {
126
+ return err
127
+ }
128
+ } else if errors .Is (err , ErrAcquireFailurePointNotOnChain ) {
129
+ respMsg := NewMsgFailure (AcquireFailurePointNotOnChain )
130
+ if err := s .SendMessage (respMsg ); err != nil {
131
+ return err
132
+ }
133
+ } else {
134
+ return err
135
+ }
136
+ }
137
+ respMsg := NewMsgAcquired ()
138
+ if err := s .SendMessage (respMsg ); err != nil {
139
+ return err
140
+ }
141
+ return nil
120
142
}
121
143
122
144
func (s * Server ) handleQuery (msg protocol.Message ) error {
@@ -134,7 +156,20 @@ func (s *Server) handleQuery(msg protocol.Message) error {
134
156
}
135
157
msgQuery := msg .(* MsgQuery )
136
158
// Call the user callback function
137
- return s .config .QueryFunc (s .callbackContext , msgQuery .Query )
159
+ result , err := s .config .QueryFunc (s .callbackContext , msgQuery .Query )
160
+ if err != nil {
161
+ return err
162
+ }
163
+ // Encode query result
164
+ resultCbor , err := cbor .Encode (& result )
165
+ if err != nil {
166
+ return err
167
+ }
168
+ respMsg := NewMsgResult (resultCbor )
169
+ if err := s .SendMessage (respMsg ); err != nil {
170
+ return err
171
+ }
172
+ return nil
138
173
}
139
174
140
175
func (s * Server ) handleRelease () error {
@@ -162,7 +197,7 @@ func (s *Server) handleReAcquire(msg protocol.Message) error {
162
197
"role" , "server" ,
163
198
"connection_id" , s .callbackContext .ConnectionId .String (),
164
199
)
165
- if s .config .ReAcquireFunc == nil {
200
+ if s .config .AcquireFunc == nil {
166
201
return fmt .Errorf (
167
202
"received local-state-query ReAcquire message but no callback function is defined" ,
168
203
)
@@ -179,7 +214,23 @@ func (s *Server) handleReAcquire(msg protocol.Message) error {
179
214
acquireTarget = AcquireImmutableTip {}
180
215
}
181
216
// Call the user callback function
182
- return s .config .ReAcquireFunc (s .callbackContext , acquireTarget )
217
+ err := s .config .AcquireFunc (s .callbackContext , acquireTarget , true )
218
+ if err != nil {
219
+ if errors .Is (err , ErrAcquireFailurePointTooOld ) {
220
+ respMsg := NewMsgFailure (AcquireFailurePointTooOld )
221
+ if err := s .SendMessage (respMsg ); err != nil {
222
+ return err
223
+ }
224
+ } else if errors .Is (err , ErrAcquireFailurePointNotOnChain ) {
225
+ respMsg := NewMsgFailure (AcquireFailurePointNotOnChain )
226
+ if err := s .SendMessage (respMsg ); err != nil {
227
+ return err
228
+ }
229
+ } else {
230
+ return err
231
+ }
232
+ }
233
+ return nil
183
234
}
184
235
185
236
func (s * Server ) handleDone () error {
@@ -190,11 +241,5 @@ func (s *Server) handleDone() error {
190
241
"role" , "server" ,
191
242
"connection_id" , s .callbackContext .ConnectionId .String (),
192
243
)
193
- if s .config .DoneFunc == nil {
194
- return fmt .Errorf (
195
- "received local-state-query Done message but no callback function is defined" ,
196
- )
197
- }
198
- // Call the user callback function
199
- return s .config .DoneFunc (s .callbackContext )
244
+ return nil
200
245
}
0 commit comments