Skip to content

Commit e4eddc0

Browse files
committed
feat: v2重定向链路(按Host解析域名、修复多域名code冲突、缓存与访问日志)
1 parent efc9a05 commit e4eddc0

File tree

7 files changed

+360
-5
lines changed

7 files changed

+360
-5
lines changed

internal/app/app.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,6 @@ func Run() error {
100100
})
101101
})
102102

103-
// 重定向路由(不需要认证)
104-
router.GET("/:code", linkHandler.RedirectLink)
105-
106103
// Web UI 路由
107104
router.GET("/login", func(c *gin.Context) {
108105
c.HTML(200, "login.html", gin.H{"title": "登录 - 短链接管理系统"})
@@ -157,11 +154,19 @@ func Run() error {
157154
}
158155

159156
// 挂载重写版 v2 路由(增量迁移,不影响 v1)
157+
v2Enabled := false
160158
if v2, err := httpv2.New(); err != nil {
161159
utils.LogWarn("v2模块初始化失败(已忽略,不影响v1): %v", err)
162160
} else {
163161
defer v2.Close()
164162
httpv2.RegisterRoutes(router, v2)
163+
v2Enabled = true
164+
}
165+
166+
// 重定向路由(不需要认证)
167+
// 优先使用重写版(已修复多域名 code 冲突风险),否则回退 legacy
168+
if !v2Enabled {
169+
router.GET("/:code", linkHandler.RedirectLink)
165170
}
166171

167172
// 启动服务器
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/**
2+
* v2 Redirect Handler(重写版)
3+
* - GET /:code
4+
* 使用 pgxpool 解析 code(按 Host 匹配 domain),并写入点击/访问日志
5+
*/
6+
package handlers
7+
8+
import (
9+
"context"
10+
"net/http"
11+
"time"
12+
13+
"short-link/internal/repo"
14+
"short-link/internal/service"
15+
16+
"github.com/gin-gonic/gin"
17+
)
18+
19+
// RedirectHandler v2 重定向处理器
20+
type RedirectHandler struct {
21+
linkService *service.LinkService
22+
}
23+
24+
// NewRedirectHandler 创建 RedirectHandler
25+
func NewRedirectHandler(linkService *service.LinkService) *RedirectHandler {
26+
return &RedirectHandler{linkService: linkService}
27+
}
28+
29+
// Redirect 执行 302 跳转
30+
func (h *RedirectHandler) Redirect(c *gin.Context) {
31+
code := c.Param("code")
32+
33+
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
34+
defer cancel()
35+
36+
url, err := h.linkService.RedirectLink(
37+
ctx,
38+
c.Request.Host,
39+
code,
40+
c.ClientIP(),
41+
c.GetHeader("User-Agent"),
42+
c.GetHeader("Referer"),
43+
)
44+
if err != nil {
45+
if err == repo.ErrNotFound {
46+
c.JSON(http.StatusNotFound, gin.H{"error": "链接不存在"})
47+
return
48+
}
49+
c.JSON(http.StatusInternalServerError, gin.H{"error": "重定向失败: " + err.Error()})
50+
return
51+
}
52+
c.Redirect(http.StatusFound, url)
53+
}
54+
55+

internal/httpv2/router.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ type Module struct {
2929
DomainRepo *repo.DomainRepo
3030
SettingsRepo *repo.SettingsRepo
3131
LinkRepo *repo.LinkRepo
32+
AccessLogRepo *repo.AccessLogRepo
3233
UserService *service.UserService
3334
LinkService *service.LinkService
3435
AuthHandler *handlers.AuthHandler
3536
LinkHandler *handlers.LinkHandler
37+
RedirectHandler *handlers.RedirectHandler
3638
}
3739

3840
// New 创建 v2 模块
@@ -53,12 +55,14 @@ func New() (*Module, error) {
5355
domainRepo := repo.NewDomainRepo(pool)
5456
settingsRepo := repo.NewSettingsRepo(pool)
5557
linkRepo := repo.NewLinkRepo(pool)
58+
accessLogRepo := repo.NewAccessLogRepo(pool)
5659

5760
userService := service.NewUserService(userRepo)
58-
linkService := service.NewLinkService(cfg.BaseURL, cfg.MinCodeLength, cfg.MaxCodeLength, linkRepo, domainRepo, settingsRepo, userRepo)
61+
linkService := service.NewLinkService(cfg.BaseURL, cfg.MinCodeLength, cfg.MaxCodeLength, linkRepo, domainRepo, settingsRepo, userRepo, accessLogRepo)
5962

6063
authHandler := handlers.NewAuthHandler(cfg, userService)
6164
linkHandler := handlers.NewLinkHandler(cfg, linkService, linkRepo, domainRepo)
65+
redirectHandler := handlers.NewRedirectHandler(linkService)
6266

6367
return &Module{
6468
Cfg: cfg,
@@ -67,10 +71,12 @@ func New() (*Module, error) {
6771
DomainRepo: domainRepo,
6872
SettingsRepo: settingsRepo,
6973
LinkRepo: linkRepo,
74+
AccessLogRepo: accessLogRepo,
7075
UserService: userService,
7176
LinkService: linkService,
7277
AuthHandler: authHandler,
7378
LinkHandler: linkHandler,
79+
RedirectHandler: redirectHandler,
7480
}, nil
7581
}
7682

@@ -85,6 +91,9 @@ func (m *Module) Close() {
8591
func RegisterRoutes(router *gin.Engine, m *Module) {
8692
utils.LogInfo("挂载重写版路由:/api/v2")
8793

94+
// 重写版 redirect(替换 legacy 的任意域名查询,修复多域名 code 冲突风险)
95+
router.GET("/:code", m.RedirectHandler.Redirect)
96+
8897
api := router.Group("/api/v2")
8998
{
9099
authGroup := api.Group("/auth")

internal/repo/access_log_repo.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/**
2+
* AccessLog Repo(重写版)
3+
* - 负责 access_logs 表写入(pgxpool)
4+
*/
5+
package repo
6+
7+
import (
8+
"context"
9+
"fmt"
10+
"short-link/internal/db"
11+
"short-link/models"
12+
)
13+
14+
// AccessLogRepo 访问日志仓储
15+
type AccessLogRepo struct {
16+
pool *db.Pool
17+
}
18+
19+
// NewAccessLogRepo 创建 AccessLogRepo
20+
func NewAccessLogRepo(pool *db.Pool) *AccessLogRepo {
21+
return &AccessLogRepo{pool: pool}
22+
}
23+
24+
// CreateAccessLog 写入访问日志
25+
func (r *AccessLogRepo) CreateAccessLog(ctx context.Context, log *models.AccessLog) error {
26+
query := `INSERT INTO access_logs (link_id, ip, user_agent, referer, created_at) VALUES ($1, $2, $3, $4, $5) RETURNING id`
27+
if err := r.pool.QueryRow(ctx, query, log.LinkID, log.IP, log.UserAgent, log.Referer, log.CreatedAt).Scan(&log.ID); err != nil {
28+
return fmt.Errorf("create access log failed: %w", err)
29+
}
30+
return nil
31+
}
32+
33+

internal/repo/domain_repo.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,26 @@ func NewDomainRepo(pool *db.Pool) *DomainRepo {
2424
return &DomainRepo{pool: pool}
2525
}
2626

27+
// FindActiveDomainsByName 按 domain 字段查找启用的域名(可能返回多条:代表配置冲突)
28+
func (r *DomainRepo) FindActiveDomainsByName(ctx context.Context, name string) ([]models.Domain, error) {
29+
query := `SELECT id, user_id, domain, is_default, is_active, created_at, updated_at FROM domains WHERE domain = $1 AND is_active = true`
30+
rows, err := r.pool.Query(ctx, query, name)
31+
if err != nil {
32+
return nil, fmt.Errorf("find domains failed: %w", err)
33+
}
34+
defer rows.Close()
35+
36+
var out []models.Domain
37+
for rows.Next() {
38+
var d models.Domain
39+
if err := rows.Scan(&d.ID, &d.UserID, &d.Domain, &d.IsDefault, &d.IsActive, &d.CreatedAt, &d.UpdatedAt); err != nil {
40+
return nil, fmt.Errorf("scan domain failed: %w", err)
41+
}
42+
out = append(out, d)
43+
}
44+
return out, nil
45+
}
46+
2747
// GetDomainByID 根据ID获取域名
2848
func (r *DomainRepo) GetDomainByID(ctx context.Context, domainID int64) (*models.Domain, error) {
2949
d := &models.Domain{}

internal/repo/link_repo.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,77 @@ func (r *LinkRepo) CreateLink(ctx context.Context, link *models.Link) error {
6363
return nil
6464
}
6565

66+
// GetLinkByCode 根据 code + domain_id 获取链接
67+
func (r *LinkRepo) GetLinkByCode(ctx context.Context, code string, domainID int64) (*models.Link, error) {
68+
l := &models.Link{}
69+
query := `
70+
SELECT id, user_id, domain_id, code, original_url, title, hash, qr_code, click_count, created_at, updated_at
71+
FROM links
72+
WHERE code = $1 AND domain_id = $2
73+
LIMIT 1
74+
`
75+
err := r.pool.QueryRow(ctx, query, code, domainID).Scan(
76+
&l.ID,
77+
&l.UserID,
78+
&l.DomainID,
79+
&l.Code,
80+
&l.OriginalURL,
81+
&l.Title,
82+
&l.Hash,
83+
&l.QRCode,
84+
&l.ClickCount,
85+
&l.CreatedAt,
86+
&l.UpdatedAt,
87+
)
88+
if errors.Is(err, pgx.ErrNoRows) {
89+
return nil, ErrNotFound
90+
}
91+
if err != nil {
92+
return nil, fmt.Errorf("get link by code failed: %w", err)
93+
}
94+
return l, nil
95+
}
96+
97+
// GetLinkByCodeAnyDomain 兼容:按 code 查询任意域名(最多返回 limit 条,用于歧义判断)
98+
func (r *LinkRepo) GetLinkByCodeAnyDomain(ctx context.Context, code string, limit int) ([]models.Link, error) {
99+
if limit <= 0 {
100+
limit = 2
101+
}
102+
query := `
103+
SELECT id, user_id, domain_id, code, original_url, title, hash, qr_code, click_count, created_at, updated_at
104+
FROM links
105+
WHERE code = $1
106+
LIMIT $2
107+
`
108+
rows, err := r.pool.Query(ctx, query, code, limit)
109+
if err != nil {
110+
return nil, fmt.Errorf("get link by code any domain failed: %w", err)
111+
}
112+
defer rows.Close()
113+
114+
var out []models.Link
115+
for rows.Next() {
116+
var l models.Link
117+
if err := rows.Scan(
118+
&l.ID,
119+
&l.UserID,
120+
&l.DomainID,
121+
&l.Code,
122+
&l.OriginalURL,
123+
&l.Title,
124+
&l.Hash,
125+
&l.QRCode,
126+
&l.ClickCount,
127+
&l.CreatedAt,
128+
&l.UpdatedAt,
129+
); err != nil {
130+
return nil, fmt.Errorf("scan link failed: %w", err)
131+
}
132+
out = append(out, l)
133+
}
134+
return out, nil
135+
}
136+
66137
// GetLinkByHashUserDomain 幂等检查:按 (hash, user_id, domain_id)
67138
func (r *LinkRepo) GetLinkByHashUserDomain(ctx context.Context, hash string, userID int64, domainID int64) (*models.Link, error) {
68139
l := &models.Link{}

0 commit comments

Comments
 (0)