@@ -13,18 +13,25 @@ import (
1313 "strings"
1414
1515 "github.com/mpyw/sqlc-restruct/pkg/actions/separate-interface/internal/astutil"
16+ "golang.org/x/exp/slices"
1617)
1718
1819type runner struct {
19- input ActionInput
20- fset * token.FileSet
20+ input ActionInput
21+ fset * token.FileSet
22+ exportedSymbolsInModels []string
2123}
2224
2325func (r * runner ) Run () error {
2426 pkg , err := build .Import ("." , r .input .ImplDir , build .IgnoreVendor )
2527 if err != nil {
2628 return fmt .Errorf ("runner.Run() failed: %w" , err )
2729 }
30+ f , err := parser .ParseFile (r .fset , path .Join (r .input .ImplDir , r .input .ModelsFileName ), nil , parser .ParseComments )
31+ if err != nil {
32+ return fmt .Errorf ("runner.Run() failed: %w" , err )
33+ }
34+ r .exportedSymbolsInModels = astutil .SymbolNameFromTypeOrValueDecls (astutil .ExportedIndividualTypeOrValueDecls (f .Decls ... )... )
2835
2936 var newModelsContent []byte
3037 var newQuerierContent []byte
@@ -58,8 +65,8 @@ func (r *runner) Run() error {
5865 }
5966
6067 if newModelsContent != nil {
61- _ = os .Remove (path .Join (r .input .IfaceDir , r .input .ModelsFileName ))
62- if err := os .WriteFile (path .Join (r .input .IfaceDir , r .input .ModelsFileName ), newModelsContent , 0644 ); err != nil {
68+ _ = os .Remove (path .Join (r .input .ModelsDir , r .input .ModelsFileName ))
69+ if err := os .WriteFile (path .Join (r .input .ModelsDir , r .input .ModelsFileName ), newModelsContent , 0644 ); err != nil {
6370 return fmt .Errorf ("runner.Run() failed: %w" , err )
6471 }
6572 _ = os .Remove (path .Join (r .input .ImplDir , r .input .ModelsFileName ))
@@ -96,7 +103,7 @@ func (r *runner) newModelsContent() ([]byte, error) {
96103 }
97104
98105 // Change package name of "models" file
99- f .Name = ast .NewIdent (r .input .IfacePkgName )
106+ f .Name = ast .NewIdent (r .input .ModelsPkgName )
100107
101108 byt , err := r .intoBytes (f )
102109 if err != nil {
@@ -114,6 +121,19 @@ func (r *runner) newQuerierContent() ([]byte, error) {
114121 // Change package name of "querier" file
115122 f .Name = ast .NewIdent (r .input .IfacePkgName )
116123
124+ // Prepend import statement of ModelsPkgURL
125+ if r .input .ModelsPkgURL != r .input .IfacePkgURL {
126+ f .Decls = append (append (([]ast.Decl )(nil ), & ast.GenDecl {
127+ Tok : token .IMPORT ,
128+ Specs : []ast.Spec {& ast.ImportSpec {
129+ Path : & ast.BasicLit {
130+ Kind : token .STRING ,
131+ Value : fmt .Sprintf ("%#v" , r .input .ModelsPkgURL ),
132+ },
133+ }},
134+ }), f .Decls ... )
135+ }
136+
117137 // Remove top level constraint: var _ Querier = (*Querier)(nil)
118138 for i , decl := range f .Decls {
119139 if decl , ok := decl .(* ast.GenDecl ); ok && decl .Tok == token .VAR {
@@ -126,6 +146,22 @@ func (r *runner) newQuerierContent() ([]byte, error) {
126146 }
127147 }
128148
149+ // Qualify exported references
150+ if r .input .ModelsPkgURL != r .input .IfacePkgURL {
151+ ast .Walk (
152+ astutil .NewExportedExprIdentUpdater (func (ident * ast.Ident ) ast.Expr {
153+ if slices .Contains (r .exportedSymbolsInModels , ident .Name ) {
154+ return & ast.SelectorExpr {
155+ X : ast .NewIdent (r .input .ModelsPkgName ),
156+ Sel : ident ,
157+ }
158+ }
159+ return nil
160+ }),
161+ f ,
162+ )
163+ }
164+
129165 dirEntries , err := os .ReadDir (r .input .ImplDir )
130166 if err != nil {
131167 return nil , fmt .Errorf ("runner.newQuerierContent() failed: %w" , err )
@@ -172,11 +208,28 @@ func (r *runner) newQueriesContent(filename string) ([]byte, error) {
172208 }},
173209 }), f .Decls ... )
174210
211+ // Prepend import statement of ModelsPkgURL
212+ if r .input .ModelsPkgURL != r .input .IfacePkgURL {
213+ f .Decls = append (append (([]ast.Decl )(nil ), & ast.GenDecl {
214+ Tok : token .IMPORT ,
215+ Specs : []ast.Spec {& ast.ImportSpec {
216+ Path : & ast.BasicLit {
217+ Kind : token .STRING ,
218+ Value : fmt .Sprintf ("%#v" , r .input .ModelsPkgURL ),
219+ },
220+ }},
221+ }), f .Decls ... )
222+ }
223+
175224 // Qualify exported references
176225 ast .Walk (
177226 astutil .NewExportedExprIdentUpdater (func (ident * ast.Ident ) ast.Expr {
227+ pkgName := r .input .IfacePkgName
228+ if slices .Contains (r .exportedSymbolsInModels , ident .Name ) {
229+ pkgName = r .input .ModelsPkgName
230+ }
178231 return & ast.SelectorExpr {
179- X : ast .NewIdent (r . input . IfacePkgName ),
232+ X : ast .NewIdent (pkgName ),
180233 Sel : ident ,
181234 }
182235 }),
0 commit comments