Skip to content

Commit f0d9d08

Browse files
authored
feat: update golang writer (#10)
* tmp * feat:(go_ast) support collect test symbols * opt:(patcher) support add new node * feat: support bind package with file * update * fix: unmarshal `NodeType`
1 parent 9aac06d commit f0d9d08

File tree

18 files changed

+369
-167
lines changed

18 files changed

+369
-167
lines changed

src/compress/types/types.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ pub struct Relation {
7878
#[serde(rename = "Kind")]
7979
pub(crate) kind: RelationKind,
8080
#[serde(rename = "Desc")]
81-
pub(crate) desc: String,
81+
pub(crate) desc: Option<String>,
82+
#[serde(rename = "Codes")]
83+
pub(crate) codes: Option<String>,
8284
}
8385

8486
impl Relation {
@@ -91,7 +93,7 @@ impl Relation {
9193
}
9294
}
9395

94-
#[derive(Serialize, Debug, Clone, Default)]
96+
#[derive(Debug, Clone, Default)]
9597
pub enum NodeType {
9698
#[default]
9799
Unknown,
@@ -111,6 +113,15 @@ impl NodeType {
111113
}
112114
}
113115

116+
impl Serialize for NodeType {
117+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
118+
where
119+
S: serde::Serializer,
120+
{
121+
serializer.serialize_str(&self.to_string())
122+
}
123+
}
124+
114125
impl<'de> Deserialize<'de> for NodeType {
115126
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116127
where

src/lang/collect/collect.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type CollectOption struct {
3232
LoadExternalSymbol bool
3333
NeedStdSymbol bool
3434
NoNeedComment bool
35+
NeedTest bool
3536
Language lsp.Language
3637
Excludes []string
3738
}

src/lang/collect/export.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ func (c *Collector) fileLine(loc Location) uniast.FileLine {
4949
}
5050

5151
func newModule(name string, dir string) *uniast.Module {
52-
ret := uniast.NewModule(name, dir)
53-
ret.Language = uniast.Rust
52+
ret := uniast.NewModule(name, dir, uniast.Rust)
5453
return ret
5554
}
5655

src/lang/go.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ func callGoParser(ctx context.Context, repoPath string, opts collect.CollectOpti
3232
if !opts.NoNeedComment {
3333
goopts.CollectComment = true
3434
}
35+
if opts.NeedTest {
36+
goopts.NeedTest = true
37+
}
3538
goopts.Excludes = opts.Excludes
3639
p := parser.NewParser(repoPath, repoPath, goopts)
3740
repo, err := p.ParseRepo()

src/lang/golang/parser/ctx.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"bytes"
1919
"fmt"
2020
"go/ast"
21+
"go/parser"
2122
"go/token"
2223
"go/types"
2324
"os"
@@ -51,9 +52,12 @@ func isExternalID(id *Identity, curmod string) bool {
5152
strings.Contains(id.PkgPath, "/kitex_gen/") || strings.Contains(id.PkgPath, "/hertz_gen/")
5253
}
5354

54-
func newModule(mod string, dir string) *Module {
55-
ret := uniast.NewModule(mod, dir)
56-
ret.Language = Golang
55+
const (
56+
StdLanguage = "go"
57+
)
58+
59+
func newModule(mod string, dir string) (ret *Module) {
60+
ret = uniast.NewModule(mod, dir, Golang)
5761
return ret
5862
}
5963

@@ -82,9 +86,12 @@ func (p *GoParser) referCodes(ctx *fileContext, id *Identity, depth int) (err er
8286
if pkg == nil {
8387
return fmt.Errorf("cannot find package %s", id.PkgPath)
8488
}
85-
for i, fpath := range pkg.GoFiles {
86-
file := pkg.Syntax[i]
89+
for _, fpath := range pkg.GoFiles {
8790
bs := p.getFileBytes(fpath)
91+
file, err := parser.ParseFile(pkg.Fset, fpath, bs, parser.ParseComments)
92+
if err != nil {
93+
return err
94+
}
8895
impts, e := p.parseImports(pkg.Fset, bs, mod, file.Imports)
8996
if e != nil {
9097
err = e

src/lang/golang/parser/file.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ import (
2929
func (p *GoParser) parseFile(ctx *fileContext, f *ast.File) error {
3030
cont := true
3131
ast.Inspect(f, func(node ast.Node) bool {
32-
defer func() {
33-
if r := recover(); r != nil {
34-
fmt.Fprintf(os.Stderr, "panic: %v in %s:%d\n", r, ctx.filePath, ctx.fset.Position(node.Pos()).Line)
35-
cont = false
36-
return
37-
}
38-
}()
32+
// defer func() {
33+
// if r := recover(); r != nil {
34+
// fmt.Fprintf(os.Stderr, "panic: %v in %s:%d\n", r, ctx.filePath, ctx.fset.Position(node.Pos()).Line)
35+
// cont = false
36+
// return
37+
// }
38+
// }()
3939
if funcDecl, ok := node.(*ast.FuncDecl); ok {
4040
// parse funcs
4141
_, ct := p.parseFunc(ctx, funcDecl)

src/lang/golang/parser/option.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type Options struct {
2424
ReferCodeDepth int
2525
Excludes []string
2626
CollectComment bool
27+
NeedTest bool
2728
}
2829

2930
// type Option func(options *Options)

src/lang/golang/parser/pkg.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,22 +152,26 @@ func (p *GoParser) loadPackages(mod *Module, dir string, pkgPath PkgPath) (err e
152152
return nil
153153
}
154154
fmt.Fprintf(os.Stderr, "[loadPackages] mod: %s, dir: %s, pkgPath: %s", mod.Name, dir, pkgPath)
155+
fset := token.NewFileSet()
155156
loadCount++
156157
// slow-path: load packages in the dir, including sub pakcages
157158
opts := packages.NeedFiles | packages.NeedSyntax | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedImports
158-
if p.opts.ReferCodeDepth != 0 {
159-
opts |= packages.NeedDeps
160-
}
161-
fset := token.NewFileSet()
162-
pkgs, err := packages.Load(&packages.Config{
159+
cfg := &packages.Config{
163160
Mode: opts,
164161
Fset: fset,
165162
Dir: dir,
166-
}, pkgPath)
163+
}
164+
if p.opts.ReferCodeDepth != 0 {
165+
opts |= packages.NeedDeps
166+
}
167+
if p.opts.NeedTest {
168+
opts |= packages.NeedForTest
169+
cfg.Tests = true
170+
}
171+
pkgs, err := packages.Load(cfg, pkgPath)
167172
if err != nil {
168173
return fmt.Errorf("load path '%s' failed: %v", dir, err)
169174
}
170-
171175
for _, pkg := range pkgs {
172176
if mm := p.repo.Modules[mod.Name]; mm != nil && (*mm).Packages[pkg.ID] != nil {
173177
continue
@@ -207,6 +211,8 @@ func (p *GoParser) loadPackages(mod *Module, dir string, pkgPath PkgPath) (err e
207211
f = NewFile(relpath)
208212
mod.Files[relpath] = f
209213
}
214+
pkgid := pkg.ID
215+
f.Package = &pkgid
210216
f.Imports = imports.Origins
211217
if err := p.parseFile(ctx, file); err != nil {
212218
return err
@@ -221,7 +227,17 @@ func (p *GoParser) loadPackages(mod *Module, dir string, pkgPath PkgPath) (err e
221227
// obj.Dependencies = append(obj.Dependencies, imp.ID)
222228
// }
223229
obj.PkgPath = pkg.ID
230+
if strings.HasSuffix(obj.PkgPath, ".test]") {
231+
obj.IsTest = true
232+
}
233+
if strings.HasSuffix(obj.PkgPath, ".test") {
234+
delete(mod.Packages, obj.PkgPath)
235+
}
224236
}
225237
}
226238
return
227239
}
240+
241+
func IsTestPackage(pkgPath string) bool {
242+
return strings.HasSuffix(pkgPath, ".test") || strings.HasSuffix(pkgPath, ".test]")
243+
}

src/lang/golang/parser/pkg_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ func Test_goParser_ParseRepo(t *testing.T) {
5050
println(abs)
5151
p := newGoParser(tt.fields.modName, tt.fields.homePageDir, Options{
5252
ReferCodeDepth: 1,
53+
NeedTest: true,
5354
})
5455
r, err := p.ParseRepo()
5556
if err != nil {
@@ -63,15 +64,15 @@ func Test_goParser_ParseRepo(t *testing.T) {
6364
}
6465
_ = pj
6566
_ = os.WriteFile("ast.json", pj, 0644)
66-
n, err := p.getNode(NewIdentity("github.com/cloudwego/localsession", "github.com/cloudwego/localsession/backup", "RecoverCtxOndemands"))
67+
n, err := p.getNode(NewIdentity("github.com/cloudwego/localsession", "github.com/cloudwego/localsession/backup", "RecoverCtxOnDemands"))
6768
if err != nil {
6869
t.Fatal(err)
6970
}
7071
jf, err := json.MarshalIndent(n, "", " ")
7172
if err != nil {
7273
t.Fatalf("json.Marshal() error = %v", err)
7374
}
74-
println(string(jf))
75+
os.WriteFile("node.json", jf, 0644)
7576
})
7677
}
7778
}

src/lang/golang/writer/write.go

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ import (
3434
var _ uniast.Writer = (*Writer)(nil)
3535

3636
type Options struct {
37-
RepoDir string
38-
OutDir string
37+
// RepoDir string
38+
// OutDir string
3939
GoVersion string
4040
}
4141

@@ -61,19 +61,19 @@ func NewWriter(opts Options) *Writer {
6161
}
6262
}
6363

64-
func (w *Writer) WriteRepo(repo *uniast.Repository) error {
64+
func (w *Writer) WriteRepo(repo *uniast.Repository, outDir string) error {
6565
for m, mod := range repo.Modules {
6666
if strings.Contains(m, "@") {
6767
continue
6868
}
69-
if err := w.WriteModule(repo, m); err != nil {
69+
if err := w.WriteModule(repo, m, outDir); err != nil {
7070
return fmt.Errorf("write module %s failed: %v", mod.Name, err)
7171
}
7272
}
7373
return nil
7474
}
7575

76-
func (w *Writer) WriteModule(repo *uniast.Repository, modPath string) error {
76+
func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir string) error {
7777
mod := repo.Modules[modPath]
7878
if mod == nil {
7979
return fmt.Errorf("module %s not found", modPath)
@@ -84,7 +84,7 @@ func (w *Writer) WriteModule(repo *uniast.Repository, modPath string) error {
8484
}
8585
}
8686

87-
outdir := filepath.Join(w.Options.OutDir, mod.Dir)
87+
outdir := filepath.Join(outDir, mod.Dir)
8888
for dir, pkg := range w.visited {
8989
rel := strings.TrimPrefix(dir, mod.Name)
9090
pkgDir := filepath.Join(outdir, rel)
@@ -260,16 +260,12 @@ func (w *Writer) IdToImport(id uniast.Identity) (uniast.Import, error) {
260260
return uniast.Import{Path: strconv.Quote(id.PkgPath)}, nil
261261
}
262262

263-
func (p *Writer) PatchImports(file *uniast.File) ([]byte, error) {
264-
bs, err := os.ReadFile(filepath.Join(p.Options.RepoDir, file.Path))
265-
if err != nil {
266-
return nil, utils.WrapError(err, "fail read file %s", file.Path)
267-
}
263+
func (p *Writer) PatchImports(impts []uniast.Import, file []byte) ([]byte, error) {
268264

269265
fs := token.NewFileSet()
270-
f, err := parser.ParseFile(fs, file.Path, bs, parser.ImportsOnly)
266+
f, err := parser.ParseFile(fs, "default.go", file, parser.ImportsOnly)
271267
if err != nil {
272-
return nil, utils.WrapError(err, "fail parse file %s", file.Path)
268+
return nil, utils.WrapError(err, "fail parse file %s", file)
273269
}
274270

275271
old := make([]uniast.Import, 0, len(f.Imports))
@@ -284,9 +280,9 @@ func (p *Writer) PatchImports(file *uniast.File) ([]byte, error) {
284280
old = append(old, i)
285281
}
286282

287-
impts := mergeImports(old, file.Imports)
283+
impts = mergeImports(old, impts)
288284
if len(impts) == len(old) {
289-
return bs, nil
285+
return file, nil
290286
}
291287

292288
var sb strings.Builder
@@ -295,19 +291,47 @@ func (p *Writer) PatchImports(file *uniast.File) ([]byte, error) {
295291

296292
imptStart := fs.Position(f.Name.End()).Offset + 1
297293
if len(f.Imports) > 0 {
298-
for imptStart < len(bs) && bs[imptStart] != 'i' {
294+
for imptStart < len(file) && file[imptStart] != 'i' {
299295
imptStart++
300296
}
301297
}
302298
imptEnd := imptStart
303299
if len(f.Imports) > 1 {
304300
imptEnd = fs.Position(f.Imports[len(f.Imports)-1].End()).Offset
305-
for len(old) > 1 && imptEnd < len(bs) && (bs[imptEnd] != ')') {
301+
for len(old) > 1 && imptEnd < len(file) && (file[imptEnd] != ')') {
306302
imptEnd++
307303
}
308304
imptEnd += 2 // for `)`
309305
}
310-
r1 := append(bs[:imptStart:imptStart], final...)
311-
ret := append(r1, bs[imptEnd:]...)
306+
r1 := append(file[:imptStart:imptStart], final...)
307+
ret := append(r1, file[imptEnd:]...)
312308
return ret, nil
313309
}
310+
311+
func (p *Writer) CreateFile(fi *uniast.File, mod *uniast.Module) ([]byte, error) {
312+
var sb strings.Builder
313+
sb.WriteString("package ")
314+
pkgName := filepath.Base(filepath.Dir(fi.Path))
315+
if fi.Package != nil {
316+
pkg := mod.Packages[*fi.Package]
317+
if pkg != nil {
318+
if pkg.IsMain {
319+
pkgName = "main"
320+
} else {
321+
pkgName = filepath.Base(pkg.PkgPath)
322+
}
323+
}
324+
}
325+
if pkgName == "" {
326+
return nil, fmt.Errorf("package name is empty")
327+
}
328+
sb.WriteString(pkgName)
329+
sb.WriteString("\n\n")
330+
331+
if len(fi.Imports) > 0 {
332+
writeImport(&sb, fi.Imports)
333+
}
334+
335+
bs := sb.String()
336+
return []byte(bs), nil
337+
}

0 commit comments

Comments
 (0)