@@ -85,6 +85,50 @@ func get(gopath, repodir, repo, rev string) error {
85
85
return err
86
86
}
87
87
88
+ // getModuleDir returns the path of the directory containing a module for the
89
+ // given GOPATH and repository dir values.
90
+ func getModuleDir (gopath , repodir , module string ) (string , error ) {
91
+ cmd := exec .Command ("go" , "list" , "-f" , "{{.Dir}}" , module )
92
+ cmd .Dir = repodir
93
+ cmd .Stderr = os .Stderr
94
+ cmd .Env = append ([]string {
95
+ "GOPATH=" + gopath ,
96
+ }, passthroughEnv ()... )
97
+ out , err := cmd .Output ()
98
+ if err != nil {
99
+ return "" , fmt .Errorf ("go list: args: %v; error: %w" , cmd .Args , err )
100
+ }
101
+ return string (bytes .TrimSpace (out )), nil
102
+ }
103
+
104
+ // getDirectDependencies returns a set of all the direct dependencies of a
105
+ // module for the given GOPATH and repository dir values. It first finds the
106
+ // directory that contains this module, then uses go list in this directory
107
+ // to get its direct dependencies.
108
+ func getDirectDependencies (gopath , repodir , module string ) (map [string ]bool , error ) {
109
+ dir , err := getModuleDir (gopath , repodir , module )
110
+ if err != nil {
111
+ return nil , fmt .Errorf ("get module dir: %w" , err )
112
+ }
113
+ cmd := exec .Command ("go" , "list" , "-m" , "-f" , "{{if not .Indirect}}{{.Path}}{{end}}" , "all" )
114
+ cmd .Dir = dir
115
+ cmd .Stderr = os .Stderr
116
+ cmd .Env = append ([]string {
117
+ "GOPATH=" + gopath ,
118
+ }, passthroughEnv ()... )
119
+ out , err := cmd .Output ()
120
+ if err != nil {
121
+ return nil , fmt .Errorf ("go list: args: %v; error: %w" , cmd .Args , err )
122
+ }
123
+ out = bytes .TrimRight (out , "\n " )
124
+ lines := strings .Split (string (out ), "\n " )
125
+ deps := make (map [string ]bool , len (lines ))
126
+ for _ , line := range lines {
127
+ deps [line ] = true
128
+ }
129
+ return deps , nil
130
+ }
131
+
88
132
func removeVendor (gopath string ) (found bool , _ error ) {
89
133
err := filepath .Walk (gopath , func (path string , info os.FileInfo , err error ) error {
90
134
if err != nil {
@@ -203,6 +247,12 @@ func estimate(importpath, revision string) error {
203
247
return fmt .Errorf ("go mod graph: args: %v; error: %w" , cmd .Args , err )
204
248
}
205
249
250
+ // Get direct dependencies, to filter out indirect ones from go mod graph output
251
+ directDeps , err := getDirectDependencies (gopath , repodir , importpath )
252
+ if err != nil {
253
+ return fmt .Errorf ("get direct dependencies: %w" , err )
254
+ }
255
+
206
256
// Retrieve already-packaged ones
207
257
golangBinaries , err := getGolangBinaries ()
208
258
if err != nil {
@@ -227,14 +277,29 @@ func estimate(importpath, revision string) error {
227
277
// imported it, separated by a single space. The module names
228
278
// can have a version information delimited by the @ character
229
279
src , dep , _ := strings .Cut (line , " " )
280
+ // Get the module names without their version, as we do not use
281
+ // this information.
230
282
// The root module is the only one that does not have a version
231
283
// indication with @ in the output of go mod graph. We use this
232
284
// to filter out the depencencies of the "dummymod" module.
233
- if mod , _ , found := strings .Cut (src , "@" ); ! found {
285
+ dep , _ , _ = strings .Cut (dep , "@" )
286
+ src , _ , found := strings .Cut (src , "@" )
287
+ if ! found {
234
288
continue
235
- } else if mod == importpath || strings .HasPrefix (mod , importpath + "/" ) {
289
+ }
290
+ // Due to importing all packages of the estimated module in a
291
+ // dummy one, some modules can depend on submodules of the
292
+ // estimated one. We do as if they are dependencies of the
293
+ // root one.
294
+ if strings .HasPrefix (src , importpath + "/" ) {
236
295
src = importpath
237
296
}
297
+ // go mod graph also lists indirect dependencies as dependencies
298
+ // of the current module, so we filter them out. They will still
299
+ // appear later.
300
+ if src == importpath && ! directDeps [dep ] {
301
+ continue
302
+ }
238
303
depNode , ok := nodes [dep ]
239
304
if ! ok {
240
305
depNode = & Node {name : dep }
@@ -255,10 +320,7 @@ func estimate(importpath, revision string) error {
255
320
needed := make (map [string ]int )
256
321
var visit func (n * Node , indent int )
257
322
visit = func (n * Node , indent int ) {
258
- // Get the module name without its version, as go mod graph
259
- // can return multiple times the same module with different
260
- // versions.
261
- mod , _ , _ := strings .Cut (n .name , "@" )
323
+ mod := n .name
262
324
count , isNeeded := needed [mod ]
263
325
if isNeeded {
264
326
count ++
0 commit comments