diff --git a/go.mod b/go.mod index 141256f..91b20e1 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.23.0 require ( github.com/Knetic/govaluate v3.0.0+incompatible - github.com/davecgh/go-spew v1.1.1 github.com/sourcegraph/go-lsp v0.0.0-20240223163137-f80c5dd31dfd github.com/sourcegraph/jsonrpc2 v0.2.0 github.com/stretchr/testify v1.10.0 @@ -13,6 +12,7 @@ require ( ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/sync v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/lang/golang/parser/pkg.go b/lang/golang/parser/pkg.go index 6184590..a669029 100644 --- a/lang/golang/parser/pkg.go +++ b/lang/golang/parser/pkg.go @@ -21,6 +21,7 @@ import ( "go/types" "os" "path/filepath" + "strconv" "strings" . "github.com/cloudwego/abcoder/lang/uniast" @@ -33,15 +34,15 @@ func (p *GoParser) parseImports(fset *token.FileSet, file []byte, mod *Module, i sysImports := make(map[string]string) ret := &importInfo{} for _, imp := range impts { - importPath := imp.Path.Value[1 : len(imp.Path.Value)-1] // remove the quotes + importPath, _ := strconv.Unquote(imp.Path.Value) // remove the quotes importAlias := "" // Check if user has defined an alias for current import if imp.Name != nil { importAlias = imp.Name.Name // update the alias - ret.Origins = append(ret.Origins, Import{Path: imp.Path.Value, Alias: &importAlias}) + ret.Origins = append(ret.Origins, Import{Path: importPath, Alias: &importAlias}) } else { importAlias = getPackageAlias(importPath) - ret.Origins = append(ret.Origins, Import{Path: imp.Path.Value}) + ret.Origins = append(ret.Origins, Import{Path: importPath}) } // Fix: module name may also be like this? @@ -212,7 +213,7 @@ func (p *GoParser) loadPackages(mod *Module, dir string, pkgPath PkgPath) (err e mod.Files[relpath] = f } pkgid := pkg.ID - f.Package = &pkgid + f.Package = []PkgPath{pkgid} f.Imports = imports.Origins if err := p.parseFile(ctx, file); err != nil { return err diff --git a/lang/golang/writer/ast.go b/lang/golang/writer/ast.go index 9dde1e2..e4b4673 100644 --- a/lang/golang/writer/ast.go +++ b/lang/golang/writer/ast.go @@ -17,6 +17,7 @@ package writer import ( + "strconv" "strings" "github.com/cloudwego/abcoder/lang/uniast" @@ -44,7 +45,7 @@ func writeSingleImport(sb *strings.Builder, v uniast.Import) { sb.WriteString(*v.Alias) sb.WriteString(" ") } - sb.WriteString(v.Path) + sb.WriteString(strconv.Quote(v.Path)) sb.WriteString("\n") } diff --git a/lang/golang/writer/write.go b/lang/golang/writer/write.go index 47a0492..295f75e 100644 --- a/lang/golang/writer/write.go +++ b/lang/golang/writer/write.go @@ -17,6 +17,8 @@ package writer import ( + "bytes" + "context" "fmt" "go/ast" "go/parser" @@ -29,7 +31,6 @@ import ( "strconv" "strings" - "github.com/cloudwego/abcoder/lang/log" "github.com/cloudwego/abcoder/lang/uniast" "github.com/cloudwego/abcoder/lang/utils" ) @@ -180,12 +181,6 @@ func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir str return fmt.Errorf("write go.mod failed: %v", err) } - // go mod tidy - cmd := exec.Command(w.Options.CompilerPath, "mod", "tidy") - cmd.Dir = outdir - if err := cmd.Run(); err != nil { - log.Error("go mod tidy failed: %v", err) - } return nil } @@ -263,7 +258,7 @@ func (w *Writer) appendNode(node *uniast.Node, pkg string, isMain bool, file str if v.PkgPath == "" || v.PkgPath == pkg { continue } - fs.impts = append(fs.impts, uniast.Import{Path: strconv.Quote(v.PkgPath)}) + fs.impts = append(fs.impts, uniast.Import{Path: v.PkgPath}) } // 检查是否有imports @@ -283,18 +278,24 @@ func (w *Writer) appendNode(node *uniast.Node, pkg string, isMain bool, file str // receive a piece of golang code, parse it and splits the imports and codes func (w Writer) SplitImportsAndCodes(src string) (codes string, imports []uniast.Import, err error) { + var src2 = src + if !strings.Contains("package ", src) { + src2 = "package main\n\n" + src + } fset := token.NewFileSet() - f, err := parser.ParseFile(fset, "", src, parser.SkipObjectResolution) + f, err := parser.ParseFile(fset, "", src2, parser.SkipObjectResolution) if err != nil { // NOTICE: if parse failed, just return the src return src, nil, nil } for _, imp := range f.Imports { - var alias string + s, _ := strconv.Unquote(imp.Path.Value) + v := uniast.Import{Path: s} if imp.Name != nil { - alias = imp.Name.Name + tmp := imp.Name.Name + v.Alias = &tmp } - imports = append(imports, uniast.Import{Path: imp.Path.Value, Alias: &alias}) + imports = append(imports, v) } start := 0 for _, s := range f.Decls { @@ -304,11 +305,11 @@ func (w Writer) SplitImportsAndCodes(src string) (codes string, imports []uniast start = fset.Position(s.Pos()).Offset break } - return src[start:], imports, nil + return src2[start:], imports, nil } func (w *Writer) IdToImport(id uniast.Identity) (uniast.Import, error) { - return uniast.Import{Path: strconv.Quote(id.PkgPath)}, nil + return uniast.Import{Path: id.PkgPath}, nil } func (p *Writer) PatchImports(impts []uniast.Import, file []byte) ([]byte, error) { @@ -321,8 +322,9 @@ func (p *Writer) PatchImports(impts []uniast.Import, file []byte) ([]byte, error old := make([]uniast.Import, 0, len(f.Imports)) for _, imp := range f.Imports { + v, _ := strconv.Unquote(imp.Path.Value) i := uniast.Import{ - Path: imp.Path.Value, + Path: v, } if imp.Name != nil { tmp := imp.Name.Name @@ -364,7 +366,7 @@ func (p *Writer) CreateFile(fi *uniast.File, mod *uniast.Module) ([]byte, error) sb.WriteString("package ") pkgName := filepath.Base(filepath.Dir(fi.Path)) if fi.Package != nil { - pkg := mod.Packages[*fi.Package] + pkg := mod.Packages[fi.Package[0]] if pkg != nil { if pkg.IsMain { pkgName = "main" @@ -386,3 +388,32 @@ func (p *Writer) CreateFile(fi *uniast.File, mod *uniast.Module) ([]byte, error) bs := sb.String() return []byte(bs), nil } + +func (p *Writer) Format(ctx context.Context, path string) error { + fi, err := os.Stat(path) + if err != nil { + return fmt.Errorf("stat %s failed: %v", path, err) + } + + // call goimports + if err := utils.ExecCmdWithInstall(ctx, "goimports", []string{"-w", path}, p.CompilerPath, []string{"install", "golang.org/x/tools/cmd/goimports@latest"}); err != nil { + return fmt.Errorf("goimports failed: %v", err) + } + // call gofmt + if err := utils.ExecCmdWithInstall(ctx, "gofmt", []string{"-w", path}, p.CompilerPath, []string{"install", "golang.org/x/tools/cmd/gofmt@latest"}); err != nil { + return fmt.Errorf("gofmt failed: %v", err) + } + // call go mod tidy + cmd := exec.CommandContext(ctx, p.CompilerPath, "mod", "tidy") + cmd.Dir = path + if !fi.IsDir() { + cmd.Dir = filepath.Dir(path) + } + buf := bytes.NewBuffer(nil) + cmd.Stderr = buf + cmd.Stdout = buf + if err := cmd.Run(); err != nil { + return fmt.Errorf("go mod tidy failed: %v\n%s", err, buf.String()) + } + return nil +} diff --git a/lang/golang/writer/write_test.go b/lang/golang/writer/write_test.go index 860d593..9263be2 100644 --- a/lang/golang/writer/write_test.go +++ b/lang/golang/writer/write_test.go @@ -142,7 +142,7 @@ import "fmt" file: &uniast.File{ Imports: []uniast.Import{ { - Path: `"runtime"`, + Path: `runtime`, Alias: &alias1, }, }, diff --git a/lang/patch/lib.go b/lang/patch/lib.go index 36a6a5f..6f6f0a5 100644 --- a/lang/patch/lib.go +++ b/lang/patch/lib.go @@ -15,6 +15,7 @@ package patch import ( + "context" "fmt" "math" "os" @@ -124,7 +125,7 @@ next_dep: for _, dep := range patch.AddedDeps { impt, err := w.IdToImport(dep) if err != nil { - return fmt.Errorf("convert identity %s to import failed: %v", dep.Full(), err) + return utils.WrapError(err, "convert identity %s to import failed", dep.Full()) } f.Imports = uniast.InserImport(f.Imports, impt) } @@ -135,7 +136,7 @@ next_dep: File: f, } if err := p.patch(n); err != nil { - return fmt.Errorf("patch file %s failed: %v", f.Path, err) + return utils.WrapError(err, "patch file %s failed", f.Path) } return nil } @@ -168,7 +169,7 @@ func (p *Patcher) Flush() error { fi := mod.GetFile(fpath) data, err = writer.CreateFile(fi, mod) if err != nil { - return fmt.Errorf("create file %s failed: %v", fpath, err) + return utils.WrapError(err, "create file %s failed", fpath) } } @@ -189,7 +190,7 @@ func (p *Patcher) Flush() error { } if err := utils.MustWriteFile(filepath.Join(p.OutDir, fpath), data); err != nil { - return fmt.Errorf("write file %s failed: %v", fpath, err) + return utils.WrapError(err, "write file %s failed", fpath) } // patch imports @@ -199,12 +200,13 @@ func (p *Patcher) Flush() error { if mod == nil { return fmt.Errorf("module %s not found", n.Identity.ModPath) } + n.File.RemoveUnusedImports(p.repo) data, err := writer.PatchImports(n.File.Imports, data) if err != nil { - return fmt.Errorf("patch imports failed: %v", err) + return utils.WrapError(err, "patch imports failed") } if err := utils.MustWriteFile(filepath.Join(p.OutDir, fpath), data); err != nil { - return fmt.Errorf("write file %s failed: %v", fpath, err) + return utils.WrapError(err, "write file %s failed: %v", fpath) } } } @@ -218,13 +220,20 @@ func (p *Patcher) Flush() error { fpath := filepath.Join(p.RepoDir, f.Path) bs, err := os.ReadFile(fpath) if err != nil { - return fmt.Errorf("read file %s failed: %v", fpath, err) + return utils.WrapError(err, "read file %s failed", fpath) } fpath = filepath.Join(p.OutDir, f.Path) if err := utils.MustWriteFile(fpath, bs); err != nil { - return fmt.Errorf("write file %s failed: %v", fpath, err) + return utils.WrapError(err, "write file %s failed", fpath) } } + w := p.getLangWriter(mod.Language) + if w == nil { + return fmt.Errorf("unsupported language %s writer", mod.Language) + } + if err := w.Format(context.Background(), p.OutDir); err != nil { + return utils.WrapError(err, "format file %s failed", p.OutDir) + } } return nil } diff --git a/lang/uniast/ast.go b/lang/uniast/ast.go index 9f93694..9f7246f 100644 --- a/lang/uniast/ast.go +++ b/lang/uniast/ast.go @@ -94,16 +94,37 @@ func NewRepository(name string) Repository { type File struct { Path string - Imports []Import `json:",omitempty"` - Package *PkgPath `json:",omitempty"` + Imports []Import `json:",omitempty"` + Package []PkgPath `json:",omitempty"` // related packages, maybe one (belong to) or many (children) + Nodes []Identity `json:",omitempty"` +} + +func (f *File) RemoveUnusedImports(repo *Repository) { + marked := make(map[string]bool, len(f.Imports)) + for _, id := range f.Nodes { + node := repo.GetNode(id) + if node == nil { + continue + } + for _, dep := range node.Dependencies { + marked[dep.Identity.PkgPath] = true + } + } + final := make([]Import, 0, len(f.Imports)) + for i := len(f.Imports) - 1; i >= 0; i-- { + if marked[f.Imports[i].Path] { + final = InserImport(final, f.Imports[i]) + } + } + f.Imports = final } type Import struct { Alias *string `json:",omitempty"` - Path string + Path PkgPath } -func NewImport(alias *string, path string) Import { +func NewImport(alias *string, path PkgPath) Import { return Import{ Alias: alias, Path: path, diff --git a/lang/uniast/node.go b/lang/uniast/node.go index 1112b9e..d5b2127 100644 --- a/lang/uniast/node.go +++ b/lang/uniast/node.go @@ -90,11 +90,12 @@ func calOffset(ref, dep FileLine) int { func (r *Repository) AddRelation(node *Node, dep Identity, depFl FileLine) { line := calOffset(node.FileLine(), depFl) - node.Dependencies = append(node.Dependencies, Relation{ + node.Dependencies = InsertRelation(node.Dependencies, Relation{ Identity: dep, Kind: DEPENDENCY, Line: line, }) + // TODO: add Dependency to entity in Modules key := dep.Full() nd, ok := r.Graph[key] if !ok { @@ -104,7 +105,7 @@ func (r *Repository) AddRelation(node *Node, dep Identity, depFl FileLine) { } r.Graph[key] = nd } - nd.References = append(nd.References, Relation{ + nd.References = InsertRelation(nd.References, Relation{ Identity: node.Identity, Kind: REFERENCE, Line: line, @@ -133,6 +134,7 @@ func (r *Repository) BuildGraph() error { if mod.IsExternal() { continue } + fileNodes := make(map[string][]Identity) for _, pkg := range mod.Packages { for _, f := range pkg.Functions { n := r.SetNode(f.Identity, FUNC) @@ -148,6 +150,8 @@ func (r *Repository) BuildGraph() error { for _, dep := range f.GlobalVars { r.AddRelation(n, dep.Identity, dep.FileLine) } + fi := n.FileLine() + fileNodes[fi.File] = InsertIdentity(fileNodes[fi.File], f.Identity) } for _, t := range pkg.Types { @@ -158,6 +162,8 @@ func (r *Repository) BuildGraph() error { for _, dep := range t.InlineStruct { r.AddRelation(n, dep.Identity, dep.FileLine) } + fi := n.FileLine() + fileNodes[fi.File] = InsertIdentity(fileNodes[fi.File], t.Identity) } for _, v := range pkg.Vars { @@ -165,9 +171,19 @@ func (r *Repository) BuildGraph() error { if v.Type != nil { r.AddRelation(n, *v.Type, v.FileLine) } + fi := n.FileLine() + fileNodes[fi.File] = InsertIdentity(fileNodes[fi.File], v.Identity) } } + for _, f := range mod.Files { + nodes, ok := fileNodes[f.Path] + if !ok { + continue + } + f.Nodes = nodes + } } + return nil } @@ -428,10 +444,19 @@ func (n Node) FileLine() FileLine { } } -func (n Node) SetFileLine(file FileLine) { +func (n *Node) SetFileLine(file FileLine) { if n.Repo == nil { return } + m := n.Module() + if m == nil { + panic("module not found") + } + fi := m.GetFile(file.File) + if fi == nil { + fi = NewFile(file.File) + m.SetFile(file.File, fi) + } switch n.Type { case FUNC: if f := n.Repo.GetFunction(n.Identity); f != nil { diff --git a/lang/uniast/utils.go b/lang/uniast/utils.go index 0870024..0d6c4d2 100644 --- a/lang/uniast/utils.go +++ b/lang/uniast/utils.go @@ -21,6 +21,15 @@ import ( "os" ) +func InsertIdentity(ids []Identity, id Identity) []Identity { + for _, i := range ids { + if i == id { + return ids + } + } + return append(ids, id) +} + func InsertDependency(ids []Dependency, id Dependency) []Dependency { for _, i := range ids { if i.Identity == id.Identity { @@ -30,6 +39,15 @@ func InsertDependency(ids []Dependency, id Dependency) []Dependency { return append(ids, id) } +func InsertRelation(ids []Relation, id Relation) []Relation { + for _, i := range ids { + if i.Identity == id.Identity { + return ids + } + } + return append(ids, id) +} + func InserImport(ids []Import, id Import) []Import { for _, i := range ids { if i.Path == id.Path { diff --git a/lang/uniast/writer.go b/lang/uniast/writer.go index cc4eb58..b6c29ca 100644 --- a/lang/uniast/writer.go +++ b/lang/uniast/writer.go @@ -16,6 +16,8 @@ package uniast +import "context" + type Writer interface { // write a module onto Options.OutDir. WriteModule(repo *Repository, modPath string, outDir string) error @@ -32,4 +34,7 @@ type Writer interface { // PatchImports patches the imports into file content PatchImports(impts []Import, file []byte) ([]byte, error) + + // Format formats the file(s) with the given path. + Format(ctx context.Context, dir string) error } diff --git a/lang/utils/cmd.go b/lang/utils/cmd.go new file mode 100644 index 0000000..8b838c7 --- /dev/null +++ b/lang/utils/cmd.go @@ -0,0 +1,52 @@ +/** + * Copyright 2025 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "bytes" + "context" + "os/exec" + + "github.com/cloudwego/abcoder/lang/log" +) + +func ExecCmdWithInstall(ctx context.Context, cmd string, args []string, installCmd string, installArgs []string) error { + _, err := exec.LookPath(cmd) + if err != nil { + if installCmd == "" { + return err + } + log.Info("install %s", installCmd) + cmd := exec.CommandContext(ctx, installCmd, installArgs...) + buf := bytes.NewBuffer(nil) + cmd.Stdout = buf + cmd.Stderr = buf + if err = cmd.Run(); err != nil { + log.Info("install %s failed, %s", installCmd, buf.String()) + return err + } + } + exe := exec.CommandContext(ctx, cmd, args...) + buf := bytes.NewBuffer(nil) + exe.Stdout = buf + exe.Stderr = buf + if err = exe.Run(); err != nil { + log.Info("exec %s failed, %s", cmd, buf.String()) + return err + } + return nil +} diff --git a/lang/write.go b/lang/write.go index a1eb281..324c0d3 100644 --- a/lang/write.go +++ b/lang/write.go @@ -48,6 +48,9 @@ func Write(ctx context.Context, repo *uniast.Repository, args WriteOptions) erro if err := w.WriteModule(repo, mpath, args.OutputDir); err != nil { return err } + if err := w.Format(ctx, args.OutputDir); err != nil { + return err + } } return nil }