-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth.go
More file actions
94 lines (84 loc) · 2.16 KB
/
auth.go
File metadata and controls
94 lines (84 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package main
import (
"embed"
"errors"
"fmt"
"net"
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
"gopkg.in/yaml.v3"
)
// AutoConfig mirrors the top-level keys of the auth.yaml configuration file.
type AutoConfig struct {
	// AuthorizedIPs lists client addresses or CIDR ranges that are admitted
	// without checking the Authorization header.
	AuthorizedIPs []string `yaml:"AuthorizedIPs"`

	// AuthorizationHeader is the Authorization header value accepted from
	// clients whose address is not covered by AuthorizedIPs.
	AuthorizationHeader string `yaml:"AuthorizationHeader"`
}
// AuthIPRanges holds the client networks built from the AuthorizedIPs list by
// LoadConfig; requests whose client IP falls inside one of them skip the
// Authorization header check.
// NOTE(review): written by LoadConfig and read by the middleware without any
// locking — only safe if LoadConfig completes before requests are served.
var AuthIPRanges []*net.IPNet
// AuthorizationHeader is the Authorization header value, loaded by LoadConfig,
// that the middleware compares against for clients outside AuthIPRanges.
var AuthorizationHeader string
// defaultAuthFile embeds the default auth.yaml; releaseDefaultConfig writes it
// to disk when the configured file does not exist.
//go:embed auth.yaml
var defaultAuthFile embed.FS
// releaseDefaultConfig writes the embedded default auth.yaml to filename.
//
// The config holds the AuthorizationHeader value, which the middleware treats
// as a shared secret, so the file is created readable and writable by its
// owner only (0600) instead of world-readable. It returns an error if the
// embedded file cannot be read or the file cannot be written.
func releaseDefaultConfig(filename string) error {
	data, err := defaultAuthFile.ReadFile("auth.yaml")
	if err != nil {
		return fmt.Errorf("failed to read embedded file: %w", err)
	}
	return os.WriteFile(filename, data, 0o600)
}
func LoadConfig(filename string) error {
var config AutoConfig
// 检查文件是否存在
if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) {
fmt.Printf("File %s does not exist, releasing embedded auth.yaml file...\n", filename)
if err := releaseDefaultConfig(filename); err != nil {
return fmt.Errorf("failed to release embedded file: %w", err)
}
}
// 读取文件内容
data, err := os.ReadFile(filename)
if err != nil {
return err
}
// 解析YAML内容到结构体中
err = yaml.Unmarshal(data, &config)
if err != nil {
return err
}
AuthIPRanges = make([]*net.IPNet, len(config.AuthorizedIPs))
for i, ipOrCIDR := range config.AuthorizedIPs {
if !strings.Contains(ipOrCIDR, "/") {
// 如果它是一个单独的IP地址,自动添加/32
ipOrCIDR = ipOrCIDR + "/32"
}
_, ipNet, err := net.ParseCIDR(ipOrCIDR)
if err != nil {
return err
}
AuthIPRanges[i] = ipNet
}
AuthorizationHeader = config.AuthorizationHeader
return nil
}
// IPAndAuthorizationMiddleware returns a gin middleware that lets a request
// through in either of two cases:
//   - its client IP falls inside one of AuthIPRanges, or
//   - its Authorization header equals AuthorizationHeader.
//
// Any other request is aborted with 403 Forbidden and {"status": "unauthorized"}.
//
// An empty AuthorizationHeader disables header-based access. Without that
// guard, a request that simply omits the Authorization header (GetHeader
// yields "") would match the empty value and bypass the IP allow-list.
func IPAndAuthorizationMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// The header is consulted only when the IP check fails.
		if clientIPInAuthRanges(c.ClientIP()) || authorizationHeaderMatches(c.GetHeader("Authorization")) {
			c.Next()
			return
		}
		c.JSON(http.StatusForbidden, gin.H{"status": "unauthorized"})
		c.Abort()
	}
}

// clientIPInAuthRanges reports whether rawIP parses as an IP address lying
// inside one of AuthIPRanges. Unparseable addresses are never authorized.
func clientIPInAuthRanges(rawIP string) bool {
	ip := net.ParseIP(rawIP)
	if ip == nil {
		return false
	}
	for _, ipRange := range AuthIPRanges {
		// Skip nil entries defensively in case the slice was only partly filled.
		if ipRange != nil && ipRange.Contains(ip) {
			return true
		}
	}
	return false
}

// authorizationHeaderMatches reports whether got equals the configured
// AuthorizationHeader. It always fails when no header value is configured.
func authorizationHeaderMatches(got string) bool {
	return AuthorizationHeader != "" && got == AuthorizationHeader
}