@@ -4,9 +4,12 @@ import (
44 "context"
55 "crypto/rand"
66 "encoding/json"
7+ "errors"
78 "fmt"
9+ "math"
810 "math/big"
911 "net/http"
12+ "strconv"
1013 "sync"
1114 "time"
1215
@@ -29,6 +32,9 @@ const (
2932 codeDigits = 6
3033 codeMinValue = 100000
3134 codeRange = 900000
35+ maxCodeValue = codeMinValue + codeRange - 1
36+ codeMinValueUint32 = uint32 (codeMinValue )
37+ zeroCodeValue = uint32 (0 )
3238 mfaTokenPrefix = "mfa-verified"
3339 readHeaderTimeout = 10 * time .Second
3440 readTimeout = 30 * time .Second
@@ -42,6 +48,10 @@ const (
4248 fieldMessage = "message"
4349 fieldToken = "token"
4450 fieldAttempts = "attempts"
51+ fieldSessionID = "session_id"
52+ fieldDeviceID = "device_id"
53+ fieldUserEmail = "user_email"
54+ fieldExpiresAt = "expires_at"
4555 // Status values
4656 statusOK = "ok"
4757 statusVerified = "verified"
@@ -51,30 +61,39 @@ const (
5161 paramSessionID = "sessionID"
5262)
5363
64+ var errMFACodeOverflow = errors .New ("generated MFA code overflow" )
65+
5466// Server implements a basic MFA service for step-up authentication
5567type Server struct {
5668 sessions map [string ]* session
57- mu sync.RWMutex
5869 cfg Config
70+ mu sync.RWMutex
5971}
6072
6173// Config holds MFA service configuration
6274type Config struct {
6375 Addr string
64- SessionTimeout time.Duration
6576 CodeLength int
77+ SessionTimeout time.Duration
6678}
6779
6880// session represents an active MFA challenge session
6981type session struct {
70- SessionID string `json:"session_id"`
71- DeviceID string `json:"device_id"`
72- UserEmail string `json:"user_email"`
73- Challenge string `json:"challenge"`
74- Code string `json:"-"`
75- ExpiresAt time.Time `json:"expires_at"`
76- Attempts int `json:"attempts"`
77- MaxAttempts int `json:"max_attempts"`
82+ SessionID string `json:"session_id"`
83+ DeviceID string `json:"device_id"`
84+ UserEmail string `json:"user_email"`
85+ Code uint32 `json:"-"`
86+ Attempts uint8 `json:"attempts"`
87+ MaxAttempts uint8 `json:"max_attempts"`
88+ ExpiresAt int64 `json:"expires_at"`
89+ }
90+
91+ func (* session ) challengeMessage () string {
92+ return fmt .Sprintf (challengeTemplate , codeDigits )
93+ }
94+
95+ func (s * session ) expiresAtTime () time.Time {
96+ return time .Unix (s .ExpiresAt , 0 ).UTC ()
7897}
7998
8099// ChallengeRequest represents MFA challenge request
@@ -164,11 +183,10 @@ func (s *Server) challengeHandler(w http.ResponseWriter, r *http.Request) {
164183 SessionID : req .SessionID ,
165184 DeviceID : req .DeviceID ,
166185 UserEmail : req .UserEmail ,
167- Challenge : fmt .Sprintf (challengeTemplate , codeDigits ),
168186 Code : code ,
169- ExpiresAt : time . Now (). Add ( s . cfg . SessionTimeout ),
170- Attempts : initialAttemptCount ,
171- MaxAttempts : defaultMaxAttempts ,
187+ Attempts : uint8 ( initialAttemptCount ),
188+ MaxAttempts : uint8 ( defaultMaxAttempts ) ,
189+ ExpiresAt : time . Now (). Add ( s . cfg . SessionTimeout ). Unix () ,
172190 }
173191
174192 s .mu .Lock ()
@@ -180,15 +198,15 @@ func (s *Server) challengeHandler(w http.ResponseWriter, r *http.Request) {
180198 Str ("session_id" , req .SessionID ).
181199 Str ("device_id" , req .DeviceID ).
182200 Str ("user_email" , req .UserEmail ).
183- Str ("mfa_code" , code ).
201+ Str ("mfa_code" , formatMFACode ( code ) ).
184202 Msg ("MFA challenge created" )
185203
186204 w .Header ().Set (headerContentType , contentTypeJSON )
187205 response := map [string ]interface {}{
188- "session_id" : session .SessionID ,
189- fieldChallenge : session .Challenge ,
190- "expires_at" : session .ExpiresAt ,
191- fieldCode : code , // Only for PoC testing
206+ fieldSessionID : session .SessionID ,
207+ fieldChallenge : session .challengeMessage () ,
208+ fieldExpiresAt : session .expiresAtTime () ,
209+ fieldCode : formatMFACode ( code ) , // Only for PoC testing
192210 }
193211 if err := json .NewEncoder (w ).Encode (response ); err != nil {
194212 log .Error ().Err (err ).Msg ("failed to encode MFA challenge response" )
@@ -212,7 +230,7 @@ func (s *Server) verifyHandler(w http.ResponseWriter, r *http.Request) {
212230 }
213231
214232 // Check expiration
215- if time .Now ().After ( session .ExpiresAt ) {
233+ if time .Now ().Unix () > session .ExpiresAt {
216234 delete (s .sessions , req .SessionID )
217235 s .mu .Unlock ()
218236 http .Error (w , "session expired" , http .StatusUnauthorized )
@@ -229,12 +247,12 @@ func (s *Server) verifyHandler(w http.ResponseWriter, r *http.Request) {
229247
230248 session .Attempts ++
231249
232- // Verify code
233- if session .Code != req . Code {
250+ codeValue , err := parseMFACode ( req . Code )
251+ if err != nil || session .Code != codeValue {
234252 s .mu .Unlock ()
235253 log .Warn ().
236254 Str ("session_id" , req .SessionID ).
237- Int ("attempts" , session .Attempts ).
255+ Int ("attempts" , int ( session .Attempts ) ).
238256 Msg ("Invalid MFA code attempt" )
239257 http .Error (w , "invalid code" , http .StatusUnauthorized )
240258 return
@@ -274,13 +292,13 @@ func (s *Server) statusHandler(w http.ResponseWriter, r *http.Request) {
274292 return
275293 }
276294
277- if time .Now ().After ( session .ExpiresAt ) {
295+ if time .Now ().Unix () > session .ExpiresAt {
278296 http .Error (w , "session expired" , http .StatusUnauthorized )
279297 return
280298 }
281299
282300 w .Header ().Set (headerContentType , contentTypeJSON )
283- if err := json .NewEncoder (w ).Encode (session ); err != nil {
301+ if err := json .NewEncoder (w ).Encode (serializeSession ( session ) ); err != nil {
284302 log .Error ().Err (err ).Msg ("failed to encode MFA status response" )
285303 }
286304}
@@ -302,15 +320,47 @@ func (s *Server) healthHandler(w http.ResponseWriter, _ *http.Request) {
302320 }
303321}
304322
323+ func serializeSession (session * session ) map [string ]interface {} {
324+ return map [string ]interface {}{
325+ fieldSessionID : session .SessionID ,
326+ fieldDeviceID : session .DeviceID ,
327+ fieldUserEmail : session .UserEmail ,
328+ fieldChallenge : session .challengeMessage (),
329+ fieldExpiresAt : session .expiresAtTime (),
330+ fieldAttempts : session .Attempts ,
331+ "max_attempts" : session .MaxAttempts ,
332+ }
333+ }
334+
335+ func formatMFACode (code uint32 ) string {
336+ return fmt .Sprintf ("%0*d" , codeDigits , code )
337+ }
338+
339+ func parseMFACode (raw string ) (uint32 , error ) {
340+ code , err := strconv .Atoi (raw )
341+ if err != nil {
342+ return zeroCodeValue , fmt .Errorf ("invalid MFA code format: %w" , err )
343+ }
344+ if code < codeMinValue || code > maxCodeValue {
345+ return zeroCodeValue , fmt .Errorf ("MFA code out of range" )
346+ }
347+ return uint32 (code ), nil
348+ }
349+
305350// generateMFACode generates a random numeric code
306- func generateMFACode () (string , error ) {
351+ func generateMFACode () (uint32 , error ) {
307352 max := big .NewInt (codeRange )
308353 n , err := rand .Int (rand .Reader , max )
309354 if err != nil {
310- return "" , fmt .Errorf ("failed to generate random MFA code: %w" , err )
355+ return zeroCodeValue , fmt .Errorf ("failed to generate random MFA code: %w" , err )
356+ }
357+
358+ value := n .Uint64 ()
359+ if value > uint64 (math .MaxUint32 )- uint64 (codeMinValueUint32 ) {
360+ return zeroCodeValue , errMFACodeOverflow
311361 }
312362
313- return fmt . Sprintf ( "%0*d" , codeDigits , n . Int64 () + codeMinValue ) , nil
363+ return uint32 ( value ) + codeMinValueUint32 , nil
314364}
315365
316366// generateMFAToken creates a short-lived verification token
@@ -327,9 +377,9 @@ func (s *Server) cleanupExpiredSessions(ctx context.Context) {
327377 select {
328378 case <- ticker .C :
329379 s .mu .Lock ()
330- now := time .Now ()
380+ now := time .Now (). Unix ()
331381 for sessionID , session := range s .sessions {
332- if now . After ( session .ExpiresAt ) {
382+ if now > session .ExpiresAt {
333383 delete (s .sessions , sessionID )
334384 }
335385 }
0 commit comments