11package ast
22
33import (
4+ "bytes"
5+ "fmt"
46 "go/ast"
7+ "go/format"
8+ "go/parser"
9+ "go/token"
510 "io"
611)
712
@@ -11,9 +16,10 @@ type PluginInitializeGorm struct {
1116 Path string // 文件路径
1217 ImportPath string // 导包路径
1318 RelativePath string // 相对路径
19+ Business string // 业务库
1420 StructName string // 结构体名称
1521 PackageName string // 包名
16- IsNew bool // 是否使用new关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
22+ IsNew bool // 是否使用 new 关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
1723}
1824
1925func (a * PluginInitializeGorm ) Parse (filename string , writer io.Writer ) (file * ast.File , err error ) {
@@ -44,7 +50,7 @@ func (a *PluginInitializeGorm) Rollback(file *ast.File) error {
4450 if len (callExpr .Args ) <= 1 {
4551 needRollBackImport = true
4652 }
47- // 删除指定的参数
53+ // 删除指定参数
4854 for i , arg := range callExpr .Args {
4955 compLit , cok := arg .(* ast.CompositeLit )
5056 if ! cok {
@@ -76,30 +82,44 @@ func (a *PluginInitializeGorm) Rollback(file *ast.File) error {
7682
7783func (a * PluginInitializeGorm ) Injection (file * ast.File ) error {
7884 _ = NewImport (a .ImportPath ).Injection (file )
79- var call * ast.CallExpr
85+
86+ var targetCall * ast.CallExpr
8087 ast .Inspect (file , func (n ast.Node ) bool {
8188 callExpr , ok := n .(* ast.CallExpr )
8289 if ! ok {
8390 return true
8491 }
8592
8693 selExpr , ok := callExpr .Fun .(* ast.SelectorExpr )
87- if ok && selExpr .Sel .Name == "AutoMigrate" {
88- call = callExpr
94+ if ! ok || selExpr .Sel .Name != "AutoMigrate" {
95+ return true
96+ }
97+
98+ if a .isTargetAutoMigrateCall (callExpr ) {
99+ targetCall = callExpr
89100 return false
90101 }
91102
92103 return true
93104 })
94105
95- arg := & ast.CompositeLit {
106+ if targetCall == nil {
107+ targetCall = a .appendAutoMigrateBlock (file )
108+ }
109+ if targetCall == nil {
110+ return nil
111+ }
112+
113+ if a .hasModelArg (targetCall ) {
114+ return nil
115+ }
116+
117+ targetCall .Args = append (targetCall .Args , & ast.CompositeLit {
96118 Type : & ast.SelectorExpr {
97119 X : & ast.Ident {Name : a .PackageName },
98120 Sel : & ast.Ident {Name : a .StructName },
99121 },
100- }
101-
102- call .Args = append (call .Args , arg )
122+ })
103123 return nil
104124}
105125
@@ -109,3 +129,95 @@ func (a *PluginInitializeGorm) Format(filename string, writer io.Writer, file *a
109129 }
110130 return a .Base .Format (filename , writer , file )
111131}
132+
133+ func (a * PluginInitializeGorm ) isTargetAutoMigrateCall (callExpr * ast.CallExpr ) bool {
134+ selExpr , ok := callExpr .Fun .(* ast.SelectorExpr )
135+ if ! ok || selExpr .Sel .Name != "AutoMigrate" {
136+ return false
137+ }
138+ return exprString (selExpr .X ) == exprString (a .autoMigrateReceiverExpr ())
139+ }
140+
141+ func (a * PluginInitializeGorm ) appendAutoMigrateBlock (file * ast.File ) * ast.CallExpr {
142+ gormFunc := FindFunction (file , "Gorm" )
143+ if gormFunc == nil || gormFunc .Body == nil {
144+ return nil
145+ }
146+
147+ src := fmt .Sprintf (`package placeholder
148+ func Gorm() {
149+ if err = %s.AutoMigrate(); err != nil {
150+ err = errors.Wrap(err, "注册表失败!")
151+ zap.L().Error(fmt.Sprintf("%%+v", err))
152+ }
153+ }
154+ ` , exprString (a .autoMigrateReceiverExpr ()))
155+
156+ parsed , err := parser .ParseFile (token .NewFileSet (), "" , src , 0 )
157+ if err != nil || len (parsed .Decls ) == 0 {
158+ return nil
159+ }
160+
161+ stmt := parsed .Decls [0 ].(* ast.FuncDecl ).Body .List [0 ].(* ast.IfStmt )
162+ clearPosition (stmt )
163+ gormFunc .Body .List = append (gormFunc .Body .List , stmt )
164+
165+ assignStmt := stmt .Init .(* ast.AssignStmt )
166+ callExpr := assignStmt .Rhs [0 ].(* ast.CallExpr )
167+ return callExpr
168+ }
169+
170+ func (a * PluginInitializeGorm ) autoMigrateReceiverExpr () ast.Expr {
171+ return & ast.CallExpr {
172+ Fun : & ast.SelectorExpr {
173+ X : a .dbExpr (),
174+ Sel : & ast.Ident {Name : "WithContext" },
175+ },
176+ Args : []ast.Expr {& ast.Ident {Name : "ctx" }},
177+ }
178+ }
179+
180+ func (a * PluginInitializeGorm ) dbExpr () ast.Expr {
181+ if a .Business == "" {
182+ return & ast.SelectorExpr {
183+ X : & ast.Ident {Name : "global" },
184+ Sel : & ast.Ident {Name : "GVA_DB" },
185+ }
186+ }
187+ return & ast.CallExpr {
188+ Fun : & ast.SelectorExpr {
189+ X : & ast.Ident {Name : "global" },
190+ Sel : & ast.Ident {Name : "MustGetGlobalDBByDBName" },
191+ },
192+ Args : []ast.Expr {
193+ & ast.BasicLit {
194+ Kind : token .STRING ,
195+ Value : fmt .Sprintf ("\" %s\" " , a .Business ),
196+ },
197+ },
198+ }
199+ }
200+
201+ func (a * PluginInitializeGorm ) hasModelArg (callExpr * ast.CallExpr ) bool {
202+ for _ , arg := range callExpr .Args {
203+ compositeLit , ok := arg .(* ast.CompositeLit )
204+ if ! ok {
205+ continue
206+ }
207+ selectorExpr , ok := compositeLit .Type .(* ast.SelectorExpr )
208+ if ! ok {
209+ continue
210+ }
211+ packageIdent , ok := selectorExpr .X .(* ast.Ident )
212+ if ok && packageIdent .Name == a .PackageName && selectorExpr .Sel .Name == a .StructName {
213+ return true
214+ }
215+ }
216+ return false
217+ }
218+
219+ func exprString (expr ast.Expr ) string {
220+ var buffer bytes.Buffer
221+ _ = format .Node (& buffer , token .NewFileSet (), expr )
222+ return buffer .String ()
223+ }
0 commit comments