Skip to content

Commit 15fdb73

Browse files
committed
feat: 添加BusinessDB字段到PluginInitializeGorm结构体,并增加相关测试用例
1 parent 8ce61a1 commit 15fdb73

File tree

4 files changed

+270
-9
lines changed

4 files changed

+270
-9
lines changed

server/service/system/auto_code_package.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ func (s *autoCodePackage) templates(ctx context.Context, entity model.SysAutoCod
571571
Type: ast.TypePluginInitializeGorm,
572572
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext)),
573573
ImportPath: fmt.Sprintf(`"%s/plugin/%s/model"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName),
574+
Business: info.BusinessDB,
574575
StructName: info.StructName,
575576
PackageName: "model",
576577
IsNew: true,
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package system
2+
3+
import (
4+
"context"
5+
"path/filepath"
6+
"reflect"
7+
"runtime"
8+
"testing"
9+
10+
"github.com/flipped-aurora/gin-vue-admin/server/global"
11+
model "github.com/flipped-aurora/gin-vue-admin/server/model/system"
12+
"github.com/flipped-aurora/gin-vue-admin/server/model/system/request"
13+
utilsAst "github.com/flipped-aurora/gin-vue-admin/server/utils/ast"
14+
)
15+
16+
func TestPluginInitializeGormInjectionCarriesBusinessDB(t *testing.T) {
17+
_, currentFile, _, ok := runtime.Caller(0)
18+
if !ok {
19+
t.Fatal("runtime.Caller() failed")
20+
}
21+
repoRoot := filepath.Clean(filepath.Join(filepath.Dir(currentFile), "..", "..", ".."))
22+
23+
oldRoot := global.GVA_CONFIG.AutoCode.Root
24+
oldServer := global.GVA_CONFIG.AutoCode.Server
25+
oldModule := global.GVA_CONFIG.AutoCode.Module
26+
global.GVA_CONFIG.AutoCode.Root = repoRoot
27+
global.GVA_CONFIG.AutoCode.Server = "server"
28+
global.GVA_CONFIG.AutoCode.Module = "github.com/flipped-aurora/gin-vue-admin/server"
29+
defer func() {
30+
global.GVA_CONFIG.AutoCode.Root = oldRoot
31+
global.GVA_CONFIG.AutoCode.Server = oldServer
32+
global.GVA_CONFIG.AutoCode.Module = oldModule
33+
}()
34+
35+
info := request.AutoCode{
36+
Package: "demoPlugin",
37+
PackageName: "demo",
38+
HumpPackageName: "demo",
39+
StructName: "Demo",
40+
Abbreviation: "demo",
41+
BusinessDB: "bizdb",
42+
GenerateServer: true,
43+
}
44+
entity := model.SysAutoCodePackage{
45+
Template: "plugin",
46+
PackageName: info.Package,
47+
}
48+
49+
_, asts, _, err := AutoCodePackage.templates(context.Background(), entity, info, false)
50+
if err != nil {
51+
t.Fatalf("templates() error = %v", err)
52+
}
53+
54+
var pluginInitializeGorm *utilsAst.PluginInitializeGorm
55+
for _, injection := range asts {
56+
if candidate, ok := injection.(*utilsAst.PluginInitializeGorm); ok {
57+
pluginInitializeGorm = candidate
58+
break
59+
}
60+
}
61+
if pluginInitializeGorm == nil {
62+
t.Fatal("expected plugin initialize gorm injection")
63+
}
64+
65+
businessField := reflect.ValueOf(pluginInitializeGorm).Elem().FieldByName("Business")
66+
if !businessField.IsValid() || businessField.String() != info.BusinessDB {
67+
t.Fatalf("expected PluginInitializeGorm.Business = %q, got %v", info.BusinessDB, businessField)
68+
}
69+
}

server/utils/ast/plugin_initialize_gorm.go

Lines changed: 121 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
package ast
22

33
import (
4+
"bytes"
5+
"fmt"
46
"go/ast"
7+
"go/format"
8+
"go/parser"
9+
"go/token"
510
"io"
611
)
712

@@ -11,9 +16,10 @@ type PluginInitializeGorm struct {
1116
Path string // 文件路径
1217
ImportPath string // 导包路径
1318
RelativePath string // 相对路径
19+
Business string // 业务库
1420
StructName string // 结构体名称
1521
PackageName string // 包名
16-
IsNew bool // 是否使用new关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
22+
IsNew bool // 是否使用 new 关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
1723
}
1824

1925
func (a *PluginInitializeGorm) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
@@ -44,7 +50,7 @@ func (a *PluginInitializeGorm) Rollback(file *ast.File) error {
4450
if len(callExpr.Args) <= 1 {
4551
needRollBackImport = true
4652
}
47-
// 删除指定的参数
53+
// 删除指定参数
4854
for i, arg := range callExpr.Args {
4955
compLit, cok := arg.(*ast.CompositeLit)
5056
if !cok {
@@ -76,30 +82,44 @@ func (a *PluginInitializeGorm) Rollback(file *ast.File) error {
7682

7783
func (a *PluginInitializeGorm) Injection(file *ast.File) error {
7884
_ = NewImport(a.ImportPath).Injection(file)
79-
var call *ast.CallExpr
85+
86+
var targetCall *ast.CallExpr
8087
ast.Inspect(file, func(n ast.Node) bool {
8188
callExpr, ok := n.(*ast.CallExpr)
8289
if !ok {
8390
return true
8491
}
8592

8693
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
87-
if ok && selExpr.Sel.Name == "AutoMigrate" {
88-
call = callExpr
94+
if !ok || selExpr.Sel.Name != "AutoMigrate" {
95+
return true
96+
}
97+
98+
if a.isTargetAutoMigrateCall(callExpr) {
99+
targetCall = callExpr
89100
return false
90101
}
91102

92103
return true
93104
})
94105

95-
arg := &ast.CompositeLit{
106+
if targetCall == nil {
107+
targetCall = a.appendAutoMigrateBlock(file)
108+
}
109+
if targetCall == nil {
110+
return nil
111+
}
112+
113+
if a.hasModelArg(targetCall) {
114+
return nil
115+
}
116+
117+
targetCall.Args = append(targetCall.Args, &ast.CompositeLit{
96118
Type: &ast.SelectorExpr{
97119
X: &ast.Ident{Name: a.PackageName},
98120
Sel: &ast.Ident{Name: a.StructName},
99121
},
100-
}
101-
102-
call.Args = append(call.Args, arg)
122+
})
103123
return nil
104124
}
105125

@@ -109,3 +129,95 @@ func (a *PluginInitializeGorm) Format(filename string, writer io.Writer, file *a
109129
}
110130
return a.Base.Format(filename, writer, file)
111131
}
132+
133+
func (a *PluginInitializeGorm) isTargetAutoMigrateCall(callExpr *ast.CallExpr) bool {
134+
selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
135+
if !ok || selExpr.Sel.Name != "AutoMigrate" {
136+
return false
137+
}
138+
return exprString(selExpr.X) == exprString(a.autoMigrateReceiverExpr())
139+
}
140+
141+
func (a *PluginInitializeGorm) appendAutoMigrateBlock(file *ast.File) *ast.CallExpr {
142+
gormFunc := FindFunction(file, "Gorm")
143+
if gormFunc == nil || gormFunc.Body == nil {
144+
return nil
145+
}
146+
147+
src := fmt.Sprintf(`package placeholder
148+
func Gorm() {
149+
if err = %s.AutoMigrate(); err != nil {
150+
err = errors.Wrap(err, "注册表失败!")
151+
zap.L().Error(fmt.Sprintf("%%+v", err))
152+
}
153+
}
154+
`, exprString(a.autoMigrateReceiverExpr()))
155+
156+
parsed, err := parser.ParseFile(token.NewFileSet(), "", src, 0)
157+
if err != nil || len(parsed.Decls) == 0 {
158+
return nil
159+
}
160+
161+
stmt := parsed.Decls[0].(*ast.FuncDecl).Body.List[0].(*ast.IfStmt)
162+
clearPosition(stmt)
163+
gormFunc.Body.List = append(gormFunc.Body.List, stmt)
164+
165+
assignStmt := stmt.Init.(*ast.AssignStmt)
166+
callExpr := assignStmt.Rhs[0].(*ast.CallExpr)
167+
return callExpr
168+
}
169+
170+
func (a *PluginInitializeGorm) autoMigrateReceiverExpr() ast.Expr {
171+
return &ast.CallExpr{
172+
Fun: &ast.SelectorExpr{
173+
X: a.dbExpr(),
174+
Sel: &ast.Ident{Name: "WithContext"},
175+
},
176+
Args: []ast.Expr{&ast.Ident{Name: "ctx"}},
177+
}
178+
}
179+
180+
func (a *PluginInitializeGorm) dbExpr() ast.Expr {
181+
if a.Business == "" {
182+
return &ast.SelectorExpr{
183+
X: &ast.Ident{Name: "global"},
184+
Sel: &ast.Ident{Name: "GVA_DB"},
185+
}
186+
}
187+
return &ast.CallExpr{
188+
Fun: &ast.SelectorExpr{
189+
X: &ast.Ident{Name: "global"},
190+
Sel: &ast.Ident{Name: "MustGetGlobalDBByDBName"},
191+
},
192+
Args: []ast.Expr{
193+
&ast.BasicLit{
194+
Kind: token.STRING,
195+
Value: fmt.Sprintf("\"%s\"", a.Business),
196+
},
197+
},
198+
}
199+
}
200+
201+
func (a *PluginInitializeGorm) hasModelArg(callExpr *ast.CallExpr) bool {
202+
for _, arg := range callExpr.Args {
203+
compositeLit, ok := arg.(*ast.CompositeLit)
204+
if !ok {
205+
continue
206+
}
207+
selectorExpr, ok := compositeLit.Type.(*ast.SelectorExpr)
208+
if !ok {
209+
continue
210+
}
211+
packageIdent, ok := selectorExpr.X.(*ast.Ident)
212+
if ok && packageIdent.Name == a.PackageName && selectorExpr.Sel.Name == a.StructName {
213+
return true
214+
}
215+
}
216+
return false
217+
}
218+
219+
func exprString(expr ast.Expr) string {
220+
var buffer bytes.Buffer
221+
_ = format.Node(&buffer, token.NewFileSet(), expr)
222+
return buffer.String()
223+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package ast
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"reflect"
7+
"strings"
8+
"testing"
9+
)
10+
11+
func TestPluginInitializeGormInjectionUsesBusinessDB(t *testing.T) {
12+
const source = `package initialize
13+
14+
import (
15+
"context"
16+
"fmt"
17+
18+
"github.com/flipped-aurora/gin-vue-admin/server/global"
19+
"github.com/pkg/errors"
20+
"go.uber.org/zap"
21+
)
22+
23+
func Gorm(ctx context.Context) {
24+
err := global.GVA_DB.WithContext(ctx).AutoMigrate()
25+
if err != nil {
26+
err = errors.Wrap(err, "注册表失败!")
27+
zap.L().Error(fmt.Sprintf("%+v", err))
28+
}
29+
}
30+
`
31+
32+
dir := t.TempDir()
33+
path := filepath.Join(dir, "gorm.go")
34+
if err := os.WriteFile(path, []byte(source), 0o666); err != nil {
35+
t.Fatalf("WriteFile() error = %v", err)
36+
}
37+
38+
injection := &PluginInitializeGorm{
39+
Type: TypePluginInitializeGorm,
40+
Path: path,
41+
ImportPath: `"github.com/flipped-aurora/gin-vue-admin/server/plugin/demo/model"`,
42+
StructName: "Demo",
43+
PackageName: "model",
44+
IsNew: true,
45+
}
46+
47+
businessField := reflect.ValueOf(injection).Elem().FieldByName("Business")
48+
if !businessField.IsValid() {
49+
t.Fatal("expected PluginInitializeGorm.Business field")
50+
}
51+
businessField.SetString("bizdb")
52+
53+
file, err := injection.Parse(path, nil)
54+
if err != nil {
55+
t.Fatalf("Parse() error = %v", err)
56+
}
57+
if err := injection.Injection(file); err != nil {
58+
t.Fatalf("Injection() error = %v", err)
59+
}
60+
if err := injection.Format(path, nil, file); err != nil {
61+
t.Fatalf("Format() error = %v", err)
62+
}
63+
64+
content, err := os.ReadFile(path)
65+
if err != nil {
66+
t.Fatalf("ReadFile() error = %v", err)
67+
}
68+
69+
got := string(content)
70+
if !strings.Contains(got, "global.GVA_DB.WithContext(ctx).AutoMigrate()") {
71+
t.Fatalf("expected default gorm block to remain, got:\n%s", got)
72+
}
73+
if !strings.Contains(got, `global.MustGetGlobalDBByDBName("bizdb").WithContext(ctx).AutoMigrate(`) {
74+
t.Fatalf("expected gorm injection to use business db, got:\n%s", got)
75+
}
76+
if !strings.Contains(got, "model.Demo{}") {
77+
t.Fatalf("expected model injection, got:\n%s", got)
78+
}
79+
}

0 commit comments

Comments
 (0)