Skip to content

Commit 76da284

Browse files
committed
Delegate session resumption to SessionResumer
Extract full challenge-response and transport restoration into a SessionResumer and delegate handling from Dialer and Server, removing duplicated verification/send logic. Refresh handshake UpdatedAt on load so expiry restarts, switch SessionManager index loading to sync.Once, and add error handling for GetChatHistory queries.
1 parent dbe86c8 commit 76da284

File tree

4 files changed

+243
-343
lines changed

4 files changed

+243
-343
lines changed

dial.go

Lines changed: 4 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
package kamune
22

33
import (
4-
"crypto/subtle"
54
"fmt"
65
"log/slog"
76
"net"
87
"runtime/debug"
98
"time"
109

1110
"github.com/xtaci/kcp-go/v5"
12-
"google.golang.org/protobuf/proto"
1311

14-
"github.com/kamune-org/kamune/internal/box/pb"
1512
"github.com/kamune-org/kamune/pkg/attest"
1613
"github.com/kamune-org/kamune/pkg/fingerprint"
1714
)
@@ -239,146 +236,22 @@ func (d *Dialer) attemptResumption(state *SessionState) (*Transport, error) {
239236
d.conn = c
240237
}
241238

242-
// Generate a challenge for the server to prove it has the shared secret
243-
challenge := randomBytes(resumeChallengeSize)
244-
245-
// Create and send reconnect request
246-
req := &pb.ReconnectRequest{
247-
SessionId: state.SessionID,
248-
LastPhase: state.Phase.ToProto(),
249-
LastSendSequence: state.SendSequence,
250-
LastRecvSequence: state.RecvSequence,
251-
RemotePublicKey: d.attester.PublicKey().Marshal(),
252-
ResumeChallenge: challenge,
253-
}
254-
if err := d.sendSignedMessage(req, RouteReconnect); err != nil {
255-
return nil, fmt.Errorf("sending reconnect request: %w", err)
256-
}
257-
258-
// Receive response
259-
respPayload, err := d.conn.ReadBytes()
260-
if err != nil {
261-
return nil, fmt.Errorf("reading reconnect response: %w", err)
262-
}
263-
264-
var respST pb.SignedTransport
265-
if err := proto.Unmarshal(respPayload, &respST); err != nil {
266-
return nil, fmt.Errorf("unmarshaling response transport: %w", err)
267-
}
268-
269-
// Verify signature from the server
270-
remoteKey, err := d.storage.algorithm.Identitfier().ParsePublicKey(
271-
state.RemotePublicKey,
272-
)
273-
if err != nil {
274-
return nil, fmt.Errorf("parsing remote public key: %w", err)
275-
}
276-
277-
if !d.storage.algorithm.Identitfier().Verify(
278-
remoteKey, respST.Data, respST.Signature,
279-
) {
280-
return nil, ErrInvalidSignature
281-
}
282-
283-
var resp pb.ReconnectResponse
284-
if err := proto.Unmarshal(respST.Data, &resp); err != nil {
285-
return nil, fmt.Errorf("unmarshaling response: %w", err)
286-
}
287-
288-
if !resp.Accepted {
289-
return nil, fmt.Errorf("%w: %s", ErrResumptionFailed, resp.ErrorMessage)
290-
}
291-
292-
// Verify the server's challenge response
239+
// Delegate the entire challenge-response protocol to SessionResumer.
293240
resumer := NewSessionResumer(
294241
d.storage,
295242
d.sessionManager,
296243
d.attester,
297244
d.resumptionConfig.MaxSessionAge,
298245
)
299-
expectedResponse, err := resumer.computeChallengeResponse(
300-
challenge, state.SharedSecret,
301-
)
302-
if err != nil {
303-
return nil, fmt.Errorf("compute expected challenge response: %w", err)
304-
}
305-
if subtle.ConstantTimeCompare(resp.ChallengeResponse, expectedResponse) != 1 {
306-
return nil, ErrChallengeVerifyFailed
307-
}
308-
309-
// Compute our response to the server's challenge
310-
clientChallengeResponse, err := resumer.computeChallengeResponse(
311-
resp.ServerChallenge, state.SharedSecret,
312-
)
313-
if err != nil {
314-
return nil, fmt.Errorf("compute client challenge response: %w", err)
315-
}
316-
317-
// Determine the sequence numbers to use
318-
resumeSendSeq, resumeRecvSeq := resumer.reconcileSequences(
319-
state.SendSequence,
320-
state.RecvSequence,
321-
resp.ServerRecvSequence,
322-
resp.ServerSendSequence,
323-
)
324246

325-
// Send verification
326-
verify := &pb.ReconnectVerify{
327-
ChallengeResponse: clientChallengeResponse,
328-
Verified: true,
329-
}
330-
if err := d.sendSignedMessage(verify, RouteReconnect); err != nil {
331-
return nil, fmt.Errorf("sending verification: %w", err)
332-
}
333-
334-
// Receive completion
335-
completePayload, err := d.conn.ReadBytes()
247+
transport, err := resumer.InitiateResumption(d.conn, state)
336248
if err != nil {
337-
return nil, fmt.Errorf("reading completion: %w", err)
338-
}
339-
340-
var completeST pb.SignedTransport
341-
if err := proto.Unmarshal(completePayload, &completeST); err != nil {
342-
return nil, fmt.Errorf("unmarshaling completion transport: %w", err)
343-
}
344-
345-
if !d.storage.algorithm.Identitfier().Verify(
346-
remoteKey, completeST.Data, completeST.Signature,
347-
) {
348-
return nil, ErrInvalidSignature
349-
}
350-
351-
var complete pb.ReconnectComplete
352-
if err := proto.Unmarshal(completeST.Data, &complete); err != nil {
353-
return nil, fmt.Errorf("unmarshaling completion: %w", err)
354-
}
355-
356-
if !complete.Success {
357-
return nil, fmt.Errorf(
358-
"%w: %s", ErrResumptionFailed, complete.ErrorMessage,
359-
)
360-
}
361-
362-
// Use the agreed-upon sequence numbers
363-
if complete.ResumeSendSequence > 0 {
364-
resumeSendSeq = complete.ResumeSendSequence
365-
}
366-
if complete.ResumeRecvSequence > 0 {
367-
resumeRecvSeq = complete.ResumeRecvSequence
368-
}
369-
370-
// Restore the transport
371-
transport, err := resumer.restoreTransport(
372-
d.conn, state, resumeSendSeq, resumeRecvSeq,
373-
)
374-
if err != nil {
375-
return nil, fmt.Errorf("restoring transport: %w", err)
249+
return nil, fmt.Errorf("initiating resumption: %w", err)
376250
}
377251

378252
// Update session state
379253
if d.resumptionConfig.PersistSessions {
380-
err := SaveSessionForResumption(transport, d.sessionManager)
381-
if err != nil {
254+
if err := SaveSessionForResumption(transport, d.sessionManager); err != nil {
382255
slog.Warn(
383256
"failed to update session after resumption",
384257
slog.Any("error", err),
@@ -389,32 +262,6 @@ func (d *Dialer) attemptResumption(state *SessionState) (*Transport, error) {
389262
return transport, nil
390263
}
391264

392-
func (d *Dialer) sendSignedMessage(msg Transferable, route Route) error {
393-
data, err := proto.Marshal(msg)
394-
if err != nil {
395-
return fmt.Errorf("marshaling message: %w", err)
396-
}
397-
398-
sig, err := d.attester.Sign(data)
399-
if err != nil {
400-
return fmt.Errorf("signing message: %w", err)
401-
}
402-
403-
st := &pb.SignedTransport{
404-
Data: data,
405-
Signature: sig,
406-
Padding: padding(maxPadding),
407-
Route: route.ToProto(),
408-
}
409-
410-
payload, err := proto.Marshal(st)
411-
if err != nil {
412-
return fmt.Errorf("marshaling transport: %w", err)
413-
}
414-
415-
return d.conn.WriteBytes(payload)
416-
}
417-
418265
// PublicKey returns the dialer's public key.
419266
func (d *Dialer) PublicKey() PublicKey {
420267
return d.attester.PublicKey()

server.go

Lines changed: 7 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package kamune
22

33
import (
4-
"crypto/subtle"
54
"errors"
65
"fmt"
76
"log/slog"
@@ -213,13 +212,14 @@ func (s *Server) handleReconnection(cn Conn, st *pb.SignedTransport) error {
213212
return s.sendReconnectReject(cn, "session resumption is disabled")
214213
}
215214

216-
// Parse the reconnect request from the signed transport
215+
// Parse the reconnect request from the signed transport.
217216
var req pb.ReconnectRequest
218217
if err := proto.Unmarshal(st.Data, &req); err != nil {
219218
return fmt.Errorf("unmarshaling reconnect request: %w", err)
220219
}
221220

222-
// Verify the signature using the claimed public key
221+
// Verify the signature using the claimed public key before handing off
222+
// to the resumer, which trusts the request is authentic.
223223
remoteKey, err := s.algorithm.Identitfier().ParsePublicKey(
224224
req.RemotePublicKey,
225225
)
@@ -235,134 +235,24 @@ func (s *Server) handleReconnection(cn Conn, st *pb.SignedTransport) error {
235235
slog.String("session_id", req.SessionId),
236236
)
237237

238-
// Create session resumer and handle the resumption
238+
// Delegate the entire challenge-response protocol to SessionResumer.
239239
resumer := NewSessionResumer(
240240
s.storage,
241241
s.sessionManager,
242242
s.attester,
243243
s.resumptionConfig.MaxSessionAge,
244244
)
245245

246-
// Look up the session
247-
state, err := s.sessionManager.LoadSessionByPublicKey(req.RemotePublicKey)
246+
t, err := resumer.HandleResumption(cn, &req)
248247
if err != nil {
249-
slog.Warn("session not found for reconnection",
250-
slog.Any("error", err),
251-
)
252-
return s.sendReconnectReject(cn, "session not found")
253-
}
254-
255-
// Verify session ID matches
256-
if state.SessionID != req.SessionId {
257-
return s.sendReconnectReject(cn, "session ID mismatch")
258-
}
259-
260-
// Verify the session is in a resumable state
261-
if state.Phase != PhaseEstablished {
262-
return s.sendReconnectReject(cn, "session not established")
263-
}
264-
265-
// Verify the session is not too old
266-
// Note: We'd need to track creation time in SessionState for this
267-
// For now, just check that we have a valid shared secret
268-
if len(state.SharedSecret) == 0 {
269-
return s.sendReconnectReject(cn, "session state invalid")
270-
}
271-
272-
// Generate server challenge
273-
serverChallenge := randomBytes(resumeChallengeSize)
274-
275-
// Compute response to client's challenge using HMAC
276-
challengeResponse, err := resumer.computeChallengeResponse(
277-
req.ResumeChallenge, state.SharedSecret,
278-
)
279-
if err != nil {
280-
return fmt.Errorf("compute challenge response: %w", err)
281-
}
282-
283-
// Send accept response
284-
resp := &pb.ReconnectResponse{
285-
Accepted: true,
286-
ResumeFromPhase: state.Phase.ToProto(),
287-
ChallengeResponse: challengeResponse,
288-
ServerChallenge: serverChallenge,
289-
ServerSendSequence: state.SendSequence,
290-
ServerRecvSequence: state.RecvSequence,
291-
}
292-
293-
if err := s.sendSignedMessage(cn, resp, RouteReconnect); err != nil {
294-
return fmt.Errorf("sending accept response: %w", err)
295-
}
296-
297-
// Receive client verification
298-
verifyPayload, err := cn.ReadBytes()
299-
if err != nil {
300-
return fmt.Errorf("reading verification: %w", err)
301-
}
302-
303-
var verifyST pb.SignedTransport
304-
if err := proto.Unmarshal(verifyPayload, &verifyST); err != nil {
305-
return fmt.Errorf("unmarshaling verification transport: %w", err)
306-
}
307-
308-
// Verify signature
309-
if !s.algorithm.Identitfier().Verify(
310-
remoteKey, verifyST.Data, verifyST.Signature,
311-
) {
312-
return fmt.Errorf("verification signature invalid")
313-
}
314-
315-
var verify pb.ReconnectVerify
316-
if err := proto.Unmarshal(verifyST.Data, &verify); err != nil {
317-
return fmt.Errorf("unmarshaling verification: %w", err)
318-
}
319-
320-
// Verify client's response to our challenge
321-
expectedClientResponse, err := resumer.computeChallengeResponse(
322-
serverChallenge, state.SharedSecret,
323-
)
324-
if err != nil {
325-
return fmt.Errorf("compute client expected response: %w", err)
326-
}
327-
if len(verify.ChallengeResponse) == 0 ||
328-
subtle.ConstantTimeCompare(verify.ChallengeResponse, expectedClientResponse) != 1 {
329-
complete := &pb.ReconnectComplete{
330-
Success: false,
331-
ErrorMessage: "challenge verification failed",
332-
}
333-
_ = s.sendSignedMessage(cn, complete, RouteReconnect)
334-
return ErrChallengeVerifyFailed
335-
}
336-
337-
// Determine sequence numbers to resume from
338-
resumeSendSeq, resumeRecvSeq := resumer.reconcileSequences(
339-
state.SendSequence, state.RecvSequence,
340-
req.LastRecvSequence, req.LastSendSequence,
341-
)
342-
343-
// Send completion
344-
complete := &pb.ReconnectComplete{
345-
Success: true,
346-
ResumeSendSequence: resumeSendSeq,
347-
ResumeRecvSequence: resumeRecvSeq,
348-
}
349-
if err := s.sendSignedMessage(cn, complete, RouteReconnect); err != nil {
350-
return fmt.Errorf("sending completion: %w", err)
351-
}
352-
353-
// Restore the transport
354-
t, err := resumer.restoreTransport(cn, state, resumeSendSeq, resumeRecvSeq)
355-
if err != nil {
356-
return fmt.Errorf("restoring transport: %w", err)
248+
return fmt.Errorf("handling resumption: %w", err)
357249
}
358250

359251
slog.Info("session resumed",
360252
slog.String("session_id", t.SessionID()),
361-
slog.Uint64("send_seq", resumeSendSeq),
362-
slog.Uint64("recv_seq", resumeRecvSeq),
363253
)
364254

365-
// Update session state
255+
// Update session state.
366256
if s.resumptionConfig.PersistSessions {
367257
if err := SaveSessionForResumption(t, s.sessionManager); err != nil {
368258
slog.Warn("failed to update session after resumption",

0 commit comments

Comments
 (0)