-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth.go
More file actions
135 lines (118 loc) · 3.3 KB
/
auth.go
File metadata and controls
135 lines (118 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package main
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"log"
"net/http"
"strings"
"time"
)
// --- Middleware ---
// authMiddleware wraps next so that it only runs for requests that carry a
// valid session token.
//
// enableCors is applied to every response first, and OPTIONS requests return
// right after it without any authentication. For all other methods the
// Authorization header must start with "Bearer "; the remainder is looked up
// in the sessions table (joined to users) and must belong to a session whose
// expires_at is still in the future. On success the matching User is stored
// in the request context under UserKey and next is invoked. A missing or
// malformed header, or any error from the lookup, produces a 401.
func authMiddleware(next http.Handler) http.Handler {
	const bearerPrefix = "Bearer "
	const sessionQuery = `SELECT u.id, u.username, u.role, u.email FROM sessions s JOIN users u ON s.user_id = u.id WHERE s.token = ? AND s.expires_at > NOW()`

	guarded := func(w http.ResponseWriter, r *http.Request) {
		enableCors(&w)
		if r.Method == http.MethodOptions {
			return
		}

		header := r.Header.Get("Authorization")
		if !strings.HasPrefix(header, bearerPrefix) {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		// The prefix is known to be present, so slicing it off is safe.
		sessionToken := header[len(bearerPrefix):]

		var current User
		row := db.QueryRow(sessionQuery, sessionToken)
		if scanErr := row.Scan(&current.ID, &current.Username, &current.Role, &current.Email); scanErr != nil {
			http.Error(w, "Invalid Token", http.StatusUnauthorized)
			return
		}

		// Downstream handlers read the user back out of the context.
		authed := r.WithContext(context.WithValue(r.Context(), UserKey, current))
		next.ServeHTTP(w, authed)
	}
	return http.HandlerFunc(guarded)
}
func adminMiddleware(next http.Handler) http.Handler {
return authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value(UserKey).(User)
if user.Role != "admin" {
http.Error(w, "Forbidden", 403)
return
}
next.ServeHTTP(w, r)
}))
}
// --- Handlers ---
// handlePublicKey answers GET requests with a JSON object whose "publicKey"
// field is the string form of PublicKey. enableCors is applied to every
// response; requests using any other method get no body from this handler.
func handlePublicKey(w http.ResponseWriter, r *http.Request) {
	enableCors(&w)
	if r.Method != http.MethodGet {
		return
	}

	body := map[string]string{"publicKey": string(PublicKey)}
	// A failed write to the client is not reported.
	_ = json.NewEncoder(w).Encode(body)
}
// handleLogin authenticates a user and opens a session.
//
// The request must be a POST whose JSON body has a "payload" field holding
// the standard-base64 encoding of an RSA PKCS#1 v1.5 ciphertext. Decrypting
// it with PrivateKey must yield a JSON object with Username and Password
// fields. The password is hashed with SHA-256 and matched, together with the
// username, against the users table. On success a random 32-byte token (hex
// encoded) is inserted into sessions with a 24-hour expiry and returned to
// the client as JSON alongside the user.
//
// Responses: OPTIONS returns right after enableCors; other non-POST methods
// get 405; malformed input gets 400; unknown credentials get 401; a failure
// to generate or store the session token gets 500, so that a client is never
// handed a token that was not persisted.
func handleLogin(w http.ResponseWriter, r *http.Request) {
	// Upper bound on the request body. The payload is a single base64-encoded
	// RSA block, so this is generous while still capping what an
	// unauthenticated caller can make the decoder read.
	const maxLoginBodyBytes = 1 << 20

	enableCors(&w)
	if r.Method == http.MethodOptions {
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		Payload string `json:"payload"`
	}
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodyBytes)
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}

	// 1. Decode Base64
	cipherText, err := base64.StdEncoding.DecodeString(req.Payload)
	if err != nil {
		http.Error(w, "Invalid Encoding", http.StatusBadRequest)
		return
	}

	// 2. Decrypt with Private Key
	// NOTE(review): the crypto/rsa documentation warns that revealing whether
	// DecryptPKCS1v15 returned an error leaks secret information, and this
	// handler reports that failure distinctly. Moving to OAEP would need a
	// matching client change, so the wire format is left untouched here.
	plainText, err := rsa.DecryptPKCS1v15(rand.Reader, PrivateKey, cipherText)
	if err != nil {
		log.Println("Decryption failed:", err)
		http.Error(w, "Decryption Failed", http.StatusBadRequest)
		return
	}

	// 3. Parse Inner JSON
	var creds struct {
		Username string
		Password string
	}
	if err := json.Unmarshal(plainText, &creds); err != nil {
		http.Error(w, "Invalid Credential Format", http.StatusBadRequest)
		return
	}

	// 4. Validate DB
	// NOTE(review): a single unsalted SHA-256 is a fast hash and a weak choice
	// for passwords. The stored password_hash values dictate the format, so
	// it is kept as is; changing it requires a data migration.
	hash := sha256.Sum256([]byte(creds.Password))
	passHash := hex.EncodeToString(hash[:])

	var user User
	err = db.QueryRow("SELECT id, username, role, email FROM users WHERE username=? AND password_hash=?", creds.Username, passHash).Scan(&user.ID, &user.Username, &user.Role, &user.Email)
	if err != nil {
		http.Error(w, "Invalid credentials", http.StatusUnauthorized)
		return
	}

	// 5. Session
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		// Never issue a token built from a partially filled buffer.
		log.Println("Session token generation failed:", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	token := hex.EncodeToString(raw)

	// The token is only useful if the session row exists, so a failed insert
	// must fail the login instead of returning an unusable token.
	if _, err := db.Exec("INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, DATE_ADD(NOW(), INTERVAL 24 HOUR))", token, user.ID); err != nil {
		log.Println("Session insert failed:", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	if err := json.NewEncoder(w).Encode(map[string]interface{}{"token": token, "user": user}); err != nil {
		// The status line has already been sent; all that is left is to log.
		log.Println("Writing login response failed:", err)
	}
}
// handleMe writes the authenticated user as JSON.
//
// The user is read from the request context under UserKey, which is where
// authMiddleware stores it. If no User is present (for example because the
// handler was registered without authMiddleware) the response is a 401
// instead of a panic from a failed type assertion.
func handleMe(w http.ResponseWriter, r *http.Request) {
	user, ok := r.Context().Value(UserKey).(User)
	if !ok {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}
	if err := json.NewEncoder(w).Encode(user); err != nil {
		// The status line has already been sent; all that is left is to log.
		log.Println("Writing user response failed:", err)
	}
}
// cleanExpiredSessions deletes rows from sessions whose expires_at is in the
// past, once immediately and then again after each one-hour pause. It never
// returns; presumably it is started in its own goroutine (confirm at the
// call site).
func cleanExpiredSessions() {
	for {
		if _, err := db.Exec("DELETE FROM sessions WHERE expires_at < NOW()"); err != nil {
			// A failed sweep is logged and simply retried on the next pass;
			// expired rows are already rejected by the session lookup query.
			log.Println("Expired session cleanup failed:", err)
		}
		time.Sleep(1 * time.Hour)
	}
}