Skip to content

Commit f409e94

Browse files
committed
opt: define Import
1 parent 96dd7b5 commit f409e94

File tree

16 files changed

+462
-122
lines changed

16 files changed

+462
-122
lines changed

src/compress/golang/plugin/parse/ctx.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ type importInfo struct {
267267
SysImports map[string]string
268268
ProjectImports map[string]string
269269
ThirdPartyImports map[string][2]string // 0-mod, 1-import
270-
Origins []string
270+
Origins []Import
271271
}
272272

273273
func (p *GoParser) mockTypes(typ ast.Expr, m map[string]Identity, file []byte, fset *token.FileSet, fpath string, mod string, pkg string, impts *importInfo) (name string, isPointer bool) {

src/compress/golang/plugin/parse/file.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,16 +237,16 @@ func (p *GoParser) parseSelector(ctx *fileContext, expr *ast.SelectorExpr, infos
237237
// // fmt.Fprintf(os.Stderr, "failed to get type id for %s\n", expr.Name)
238238
// return false
239239
// }
240-
*infos.tys = Dedup(*infos.tys, dep)
240+
*infos.tys = InsertDependency(*infos.tys, dep)
241241
// global var
242242
} else if _, ok := v.(*types.Const); ok {
243-
*infos.globalVars = Dedup(*infos.globalVars, dep)
243+
*infos.globalVars = InsertDependency(*infos.globalVars, dep)
244244
// external const
245245
} else if _, ok := v.(*types.Var); ok {
246-
*infos.globalVars = Dedup(*infos.globalVars, dep)
246+
*infos.globalVars = InsertDependency(*infos.globalVars, dep)
247247
// external function
248248
} else if _, ok := v.(*types.Func); ok {
249-
*infos.functionCalls = Dedup(*infos.functionCalls, dep)
249+
*infos.functionCalls = InsertDependency(*infos.functionCalls, dep)
250250
}
251251
return false
252252
}
@@ -289,7 +289,7 @@ func (p *GoParser) parseSelector(ctx *fileContext, expr *ast.SelectorExpr, infos
289289
if err := p.referCodes(ctx, &id, p.opts.ReferCodeDepth); err != nil {
290290
fmt.Fprintf(os.Stderr, "failed to get refer code for %s: %v\n", id.Name, err)
291291
}
292-
*infos.methodCalls = Dedup(*infos.methodCalls, dep)
292+
*infos.methodCalls = InsertDependency(*infos.methodCalls, dep)
293293
return false
294294
}
295295

@@ -380,24 +380,24 @@ func (p *GoParser) parseFunc(ctx *fileContext, funcDecl *ast.FuncDecl) (*Functio
380380
// // fmt.Fprintf(os.Stderr, "failed to get type id for %s\n", expr.Name)
381381
// return false
382382
// }
383-
tys = Dedup(tys, dep)
383+
tys = InsertDependency(tys, dep)
384384
// global var
385385
} else if v, ok := use.(*types.Var); ok {
386386
// NOTICE: the Parent of global scope is nil?
387387
if isPkgScope(v.Parent()) {
388-
globalVars = Dedup(globalVars, dep)
388+
globalVars = InsertDependency(globalVars, dep)
389389
}
390390
// global const
391391
} else if c, ok := use.(*types.Const); ok {
392392
if isPkgScope(c.Parent()) {
393-
globalVars = Dedup(globalVars, dep)
393+
globalVars = InsertDependency(globalVars, dep)
394394
}
395395
return false
396396
// function
397397
} else if f, ok := use.(*types.Func); ok {
398398
// exclude method
399399
if f.Type().(*types.Signature).Recv() == nil {
400-
functionCalls = Dedup(functionCalls, dep)
400+
functionCalls = InsertDependency(functionCalls, dep)
401401
}
402402
}
403403
}

src/compress/golang/plugin/parse/pkg.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,15 @@ func (p *GoParser) parseImports(fset *token.FileSet, file []byte, mod *Module, i
3333
sysImports := make(map[string]string)
3434
ret := &importInfo{}
3535
for _, imp := range impts {
36-
ret.Origins = append(ret.Origins, string(GetRawContent(fset, file, imp, p.opts.CollectComment)))
3736
importPath := imp.Path.Value[1 : len(imp.Path.Value)-1] // remove the quotes
3837
importAlias := ""
3938
// Check if user has defined an alias for current import
4039
if imp.Name != nil {
4140
importAlias = imp.Name.Name // update the alias
41+
ret.Origins = append(ret.Origins, Import{Path: imp.Path.Value, Alias: &importAlias})
4242
} else {
4343
importAlias = getPackageAlias(importPath)
44+
ret.Origins = append(ret.Origins, Import{Path: imp.Path.Value})
4445
}
4546

4647
// Fix: module name may also be like this?

src/lang/collect/export.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func (c *Collector) exportSymbol(repo *uniast.Repository, symbol *DocumentSymbol
178178
continue
179179
}
180180
dep := uniast.NewDependency(*tyid, c.fileLine(input.Location))
181-
obj.Types = uniast.Dedup(obj.Types, dep)
181+
obj.Types = uniast.InsertDependency(obj.Types, dep)
182182
}
183183
}
184184
if info.Inputs != nil {
@@ -190,7 +190,7 @@ func (c *Collector) exportSymbol(repo *uniast.Repository, symbol *DocumentSymbol
190190
continue
191191
}
192192
dep := uniast.NewDependency(*tyid, c.fileLine(input.Location))
193-
obj.Params = uniast.Dedup(obj.Params, dep)
193+
obj.Params = uniast.InsertDependency(obj.Params, dep)
194194
}
195195
}
196196
if info.Outputs != nil {
@@ -202,7 +202,7 @@ func (c *Collector) exportSymbol(repo *uniast.Repository, symbol *DocumentSymbol
202202
continue
203203
}
204204
dep := uniast.NewDependency(*tyid, c.fileLine(output.Location))
205-
obj.Results = uniast.Dedup(obj.Results, dep)
205+
obj.Results = uniast.InsertDependency(obj.Results, dep)
206206
}
207207
}
208208
if info.Method != nil && info.Method.Receiver.Symbol != nil {
@@ -253,23 +253,23 @@ func (c *Collector) exportSymbol(repo *uniast.Repository, symbol *DocumentSymbol
253253
pdep := uniast.NewDependency(*depid, c.fileLine(dep.Location))
254254
switch dep.Symbol.Kind {
255255
case lsp.SKFunction:
256-
obj.FunctionCalls = uniast.Dedup(obj.FunctionCalls, pdep)
256+
obj.FunctionCalls = uniast.InsertDependency(obj.FunctionCalls, pdep)
257257
case lsp.SKMethod:
258258
if obj.MethodCalls == nil {
259259
obj.MethodCalls = make([]uniast.Dependency, 0, len(deps))
260260
}
261261
// NOTICE: use loc token as key here, to make it more readable
262-
obj.MethodCalls = uniast.Dedup(obj.MethodCalls, pdep)
262+
obj.MethodCalls = uniast.InsertDependency(obj.MethodCalls, pdep)
263263
case lsp.SKVariable, lsp.SKConstant:
264264
if obj.GlobalVars == nil {
265265
obj.GlobalVars = make([]uniast.Dependency, 0, len(deps))
266266
}
267-
obj.GlobalVars = uniast.Dedup(obj.GlobalVars, pdep)
267+
obj.GlobalVars = uniast.InsertDependency(obj.GlobalVars, pdep)
268268
case lsp.SKStruct, lsp.SKTypeParameter, lsp.SKInterface, lsp.SKEnum:
269269
if obj.Types == nil {
270270
obj.Types = make([]uniast.Dependency, 0, len(deps))
271271
}
272-
obj.Types = uniast.Dedup(obj.Types, pdep)
272+
obj.Types = uniast.InsertDependency(obj.Types, pdep)
273273
default:
274274
log.Error("dep symbol %s not collected for %v\n", dep.Symbol, id)
275275
}

src/lang/golang/writer/ast.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/**
2+
* Copyright 2025 ByteDance Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package writer
18+
19+
import (
20+
"strings"
21+
22+
"github.com/cloudwego/abcoder/src/uniast"
23+
)
24+
25+
func writeImport(sb *strings.Builder, impts []uniast.Import) {
26+
if len(impts) == 0 {
27+
return
28+
}
29+
sb.WriteString("import ")
30+
if len(impts) == 1 {
31+
writeSingleImport(sb, impts[0])
32+
return
33+
}
34+
sb.WriteString("(\n")
35+
for i := 0; i < len(impts); i++ {
36+
sb.WriteString("\t")
37+
writeSingleImport(sb, impts[i])
38+
}
39+
sb.WriteString(")\n")
40+
}
41+
42+
func writeSingleImport(sb *strings.Builder, v uniast.Import) {
43+
if v.Alias != nil {
44+
sb.WriteString(*v.Alias)
45+
sb.WriteString(" ")
46+
}
47+
sb.WriteString(v.Path)
48+
sb.WriteString("\n")
49+
}
50+
51+
// merge the imports of file and nodes, and return the merged imports
52+
// file is in priority (because it contains alias)
53+
func mergeImports(priors []uniast.Import, subs []uniast.Import) (ret []uniast.Import) {
54+
visited := make(map[string]bool, len(priors)+len(subs))
55+
ret = make([]uniast.Import, 0, len(priors)+len(subs))
56+
for _, v := range priors {
57+
58+
if visited[v.Path] {
59+
continue
60+
} else {
61+
visited[v.Path] = true
62+
ret = append(ret, v)
63+
}
64+
}
65+
for _, v := range subs {
66+
if visited[v.Path] {
67+
continue
68+
} else {
69+
visited[v.Path] = true
70+
ret = append(ret, v)
71+
}
72+
}
73+
return
74+
}

src/lang/golang/writer/write.go

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ import (
2727
"strconv"
2828
"strings"
2929

30+
"github.com/cloudwego/abcoder/src/lang/utils"
3031
"github.com/cloudwego/abcoder/src/uniast"
3132
)
3233

3334
var _ uniast.Writer = (*Writer)(nil)
3435

3536
type Options struct {
37+
RepoDir string
3638
OutDir string
3739
GoVersion string
3840
}
@@ -44,7 +46,7 @@ type Writer struct {
4446

4547
type fileNode struct {
4648
chunks []chunk
47-
impts []string
49+
impts []uniast.Import
4850
}
4951

5052
type chunk struct {
@@ -101,19 +103,13 @@ func (w *Writer) WriteModule(repo *uniast.Repository, modPath string) error {
101103
}
102104
sb.WriteString("\n\n")
103105

104-
var fimpts []string
106+
var fimpts []uniast.Import
105107
if fi, ok := mod.Files[filepath.Join(mod.Dir, rel, fpath)]; ok && fi.Imports != nil {
106108
fimpts = fi.Imports
107109
}
108-
impts := w.mergeImports(fimpts, f.impts)
110+
impts := mergeImports(fimpts, f.impts)
109111
if len(impts) > 0 {
110-
sb.WriteString("import (\n")
111-
for _, v := range impts {
112-
sb.WriteString("\t")
113-
sb.WriteString(v)
114-
sb.WriteString("\n")
115-
}
116-
sb.WriteString(")\n\n")
112+
writeImport(&sb, impts)
117113
}
118114

119115
sort.SliceStable(f.chunks, func(i, j int) bool {
@@ -208,15 +204,15 @@ func (w *Writer) appendNode(node *uniast.Node, pkg string, isMain bool, file str
208204
if fs == nil {
209205
fs = &fileNode{
210206
chunks: make([]chunk, 0, len(node.Dependencies)),
211-
impts: make([]string, 0, len(node.Dependencies)),
207+
impts: make([]uniast.Import, 0, len(node.Dependencies)),
212208
}
213209
p[fpath] = fs
214210
}
215211
for _, v := range node.Dependencies {
216212
if v.Target.PkgPath == "" || v.Target.PkgPath == pkg {
217213
continue
218214
}
219-
fs.impts = append(fs.impts, strconv.Quote(v.Target.PkgPath))
215+
fs.impts = append(fs.impts, uniast.Import{Path: strconv.Quote(v.Target.PkgPath)})
220216
}
221217

222218
// 检查是否有imports
@@ -235,19 +231,19 @@ func (w *Writer) appendNode(node *uniast.Node, pkg string, isMain bool, file str
235231
}
236232

237233
// receive a piece of golang code, parse it and splits the imports and codes
238-
func (w Writer) SplitImportsAndCodes(src string) (codes string, imports []string, err error) {
234+
func (w Writer) SplitImportsAndCodes(src string) (codes string, imports []uniast.Import, err error) {
239235
fset := token.NewFileSet()
240236
f, err := parser.ParseFile(fset, "", src, parser.SkipObjectResolution)
241237
if err != nil {
242238
// NOTICE: if parse failed, just return the src
243239
return src, nil, nil
244240
}
245241
for _, imp := range f.Imports {
246-
var impt = imp.Path.Value
242+
var alias string
247243
if imp.Name != nil {
248-
impt = fmt.Sprintf("%s %s", imp.Name.Name, impt)
244+
alias = imp.Name.Name
249245
}
250-
imports = append(imports, impt)
246+
imports = append(imports, uniast.Import{Path: imp.Path.Value, Alias: &alias})
251247
}
252248
start := 0
253249
for _, s := range f.Decls {
@@ -260,42 +256,58 @@ func (w Writer) SplitImportsAndCodes(src string) (codes string, imports []string
260256
return src[start:], imports, nil
261257
}
262258

263-
func (w *Writer) IdToImport(id uniast.Identity) (string, error) {
264-
return strconv.Quote(id.PkgPath), nil
259+
func (w *Writer) IdToImport(id uniast.Identity) (uniast.Import, error) {
260+
return uniast.Import{Path: strconv.Quote(id.PkgPath)}, nil
265261
}
266262

267-
// merge the imports of file and nodes, and return the merged imports
268-
// file is in priority (because it contains alias)
269-
func (w *Writer) mergeImports(priors []string, subs []string) (ret []string) {
270-
visited := make(map[string]bool, len(priors)+len(subs))
271-
ret = make([]string, 0, len(priors)+len(subs))
272-
for _, v := range priors {
273-
sp := strings.Split(v, " ")
274-
var impt = sp[0]
275-
if len(sp) >= 2 {
276-
impt = sp[1]
263+
func (p *Writer) PatchImports(file *uniast.File) ([]byte, error) {
264+
bs, err := os.ReadFile(filepath.Join(p.Options.RepoDir, file.Path))
265+
if err != nil {
266+
return nil, utils.WrapError(err, "fail read file %s", file.Path)
267+
}
268+
269+
fs := token.NewFileSet()
270+
f, err := parser.ParseFile(fs, file.Path, bs, parser.ImportsOnly)
271+
if err != nil {
272+
return nil, utils.WrapError(err, "fail parse file %s", file.Path)
273+
}
274+
275+
old := make([]uniast.Import, 0, len(f.Imports))
276+
for _, imp := range f.Imports {
277+
i := uniast.Import{
278+
Path: imp.Path.Value,
277279
}
278-
key, _ := strconv.Unquote(impt)
279-
if visited[key] {
280-
continue
281-
} else {
282-
visited[key] = true
283-
ret = append(ret, v)
280+
if imp.Name != nil {
281+
tmp := imp.Name.Name
282+
i.Alias = &tmp
284283
}
284+
old = append(old, i)
285+
}
286+
287+
impts := mergeImports(old, file.Imports)
288+
if len(impts) == len(old) {
289+
return bs, nil
285290
}
286-
for _, v := range subs {
287-
sp := strings.Split(v, " ")
288-
var impt = sp[0]
289-
if len(sp) >= 2 {
290-
impt = sp[1]
291+
292+
var sb strings.Builder
293+
writeImport(&sb, impts)
294+
final := sb.String()
295+
296+
imptStart := fs.Position(f.Name.End()).Offset + 1
297+
if len(f.Imports) > 0 {
298+
for imptStart < len(bs) && bs[imptStart] != 'i' {
299+
imptStart++
291300
}
292-
key, _ := strconv.Unquote(impt)
293-
if visited[key] {
294-
continue
295-
} else {
296-
visited[key] = true
297-
ret = append(ret, v)
301+
}
302+
imptEnd := imptStart
303+
if len(f.Imports) > 1 {
304+
imptEnd = fs.Position(f.Imports[len(f.Imports)-1].End()).Offset
305+
for len(old) > 1 && imptEnd < len(bs) && (bs[imptEnd] != ')') {
306+
imptEnd++
298307
}
308+
imptEnd += 2 // for `)`
299309
}
300-
return
310+
r1 := append(bs[:imptStart:imptStart], final...)
311+
ret := append(r1, bs[imptEnd:]...)
312+
return ret, nil
301313
}

0 commit comments

Comments
 (0)