|
| 1 | +package v3 |
| 2 | + |
| 3 | +import ( |
| 4 | + "errors" |
| 5 | + "net" |
| 6 | + "net/netip" |
| 7 | + "sync" |
| 8 | + |
| 9 | + "github.com/rs/zerolog" |
| 10 | + |
| 11 | + "github.com/cloudflare/cloudflared/ingress" |
| 12 | +) |
| 13 | + |
| 14 | +var ( |
| 15 | + ErrSessionNotFound = errors.New("session not found") |
| 16 | + ErrSessionBoundToOtherConn = errors.New("session is in use by another connection") |
| 17 | +) |
| 18 | + |
| 19 | +type SessionManager interface { |
| 20 | + // RegisterSession will register a new session if it does not already exist for the request ID. |
| 21 | + // During new session creation, the session will also bind the UDP socket for the origin. |
| 22 | + // If the session exists for a different connection, it will return [ErrSessionBoundToOtherConn]. |
| 23 | + RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramWriter) (Session, error) |
| 24 | + // GetSession returns an active session if available for the provided connection. |
| 25 | + // If the session does not exist, it will return [ErrSessionNotFound]. If the session exists for a different |
| 26 | + // connection, it will return [ErrSessionBoundToOtherConn]. |
| 27 | + GetSession(requestID RequestID) (Session, error) |
| 28 | + // UnregisterSession will remove a session from the current session manager. It will attempt to close the session |
| 29 | + // before removal. |
| 30 | + UnregisterSession(requestID RequestID) |
| 31 | +} |
| 32 | + |
| 33 | +type DialUDP func(dest netip.AddrPort) (*net.UDPConn, error) |
| 34 | + |
| 35 | +type sessionManager struct { |
| 36 | + sessions map[RequestID]Session |
| 37 | + mutex sync.RWMutex |
| 38 | + log *zerolog.Logger |
| 39 | +} |
| 40 | + |
| 41 | +func NewSessionManager(log *zerolog.Logger, originDialer DialUDP) SessionManager { |
| 42 | + return &sessionManager{ |
| 43 | + sessions: make(map[RequestID]Session), |
| 44 | + log: log, |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramWriter) (Session, error) { |
| 49 | + s.mutex.Lock() |
| 50 | + defer s.mutex.Unlock() |
| 51 | + // Check to make sure session doesn't already exist for requestID |
| 52 | + _, exists := s.sessions[request.RequestID] |
| 53 | + if exists { |
| 54 | + return nil, ErrSessionBoundToOtherConn |
| 55 | + } |
| 56 | + // Attempt to bind the UDP socket for the new session |
| 57 | + origin, err := ingress.DialUDPAddrPort(request.Dest) |
| 58 | + if err != nil { |
| 59 | + return nil, err |
| 60 | + } |
| 61 | + // Create and insert the new session in the map |
| 62 | + session := NewSession(request.RequestID, request.IdleDurationHint, origin, conn, s.log) |
| 63 | + s.sessions[request.RequestID] = session |
| 64 | + return session, nil |
| 65 | +} |
| 66 | + |
| 67 | +func (s *sessionManager) GetSession(requestID RequestID) (Session, error) { |
| 68 | + s.mutex.RLock() |
| 69 | + defer s.mutex.RUnlock() |
| 70 | + session, exists := s.sessions[requestID] |
| 71 | + if exists { |
| 72 | + return session, nil |
| 73 | + } |
| 74 | + return nil, ErrSessionNotFound |
| 75 | +} |
| 76 | + |
| 77 | +func (s *sessionManager) UnregisterSession(requestID RequestID) { |
| 78 | + s.mutex.Lock() |
| 79 | + defer s.mutex.Unlock() |
| 80 | + // Get the session and make sure to close it if it isn't already closed |
| 81 | + session, exists := s.sessions[requestID] |
| 82 | + if exists { |
| 83 | + // We ignore any errors when attempting to close the session |
| 84 | + _ = session.Close() |
| 85 | + } |
| 86 | + delete(s.sessions, requestID) |
| 87 | +} |
0 commit comments