@@ -18,14 +18,13 @@ import (
1818 "go/token"
1919 "log"
2020 "os"
21- "path/filepath"
2221 "slices"
2322 "strings"
2423 "text/template"
2524)
2625
2726const (
28- fileSuffix = "_iterators .go"
27+ fileSuffix = "iterators .go"
2928)
3029
3130var (
3433 sourceTmpl = template .Must (template .New ("source" ).Funcs (template.FuncMap {
3534 "hasPrefix" : strings .HasPrefix ,
3635 }).Parse (source ))
36+
37+ testTmpl = template .Must (template .New ("test" ).Parse (test ))
3738)
3839
3940func logf (fmt string , args ... any ) {
@@ -46,17 +47,14 @@ func main() {
4647 flag .Parse ()
4748 fset := token .NewFileSet ()
4849
49- pkgs , err := parser .ParseDir (fset , "github" , sourceFilter , parser .ParseComments )
50+ // Parse the current directory
51+ pkgs , err := parser .ParseDir (fset , "." , sourceFilter , 0 )
5052 if err != nil {
5153 log .Fatal (err )
5254 return
5355 }
5456
5557 for pkgName , pkg := range pkgs {
56- if pkgName != "github" {
57- continue
58- }
59-
6058 t := & templateData {
6159 Package : pkgName ,
6260 Methods : []* method {},
@@ -81,7 +79,7 @@ func main() {
8179}
8280
8381func sourceFilter (fi os.FileInfo ) bool {
84- return ! strings .HasSuffix (fi .Name (), "_test.go" ) && ! strings .HasSuffix (fi .Name (), fileSuffix )
82+ return ! strings .HasSuffix (fi .Name (), "_test.go" ) && ! strings .HasSuffix (fi .Name (), fileSuffix ) && ! strings . HasPrefix ( fi . Name (), "gen-" )
8583}
8684
8785type templateData struct {
@@ -99,17 +97,26 @@ type structDef struct {
9997type method struct {
10098 RecvType string
10199 RecvVar string
100+ ClientField string
102101 MethodName string
103102 IterMethod string
104103 Args string
105104 CallArgs string
105+ ZeroArgs string
106106 ReturnType string
107107 OptsType string
108108 OptsName string
109109 OptsIsPtr bool
110110 UseListOptions bool
111111 UsePage bool
112- MetaOps []string
112+ TestJSON string
113+ }
114+
115+ // customTestJSON maps method names to the JSON response they expect in tests.
116+ // This is needed for methods that internally unmarshal a wrapper struct
117+ // even though they return a slice.
118+ var customTestJSON = map [string ]string {
119+ "ListUserInstallations" : `{"installations": []}` ,
113120}
114121
115122func (t * templateData ) processStructs (f * ast.File ) {
@@ -180,6 +187,21 @@ func (t *templateData) hasIntPage(structName string) bool {
180187 return false
181188}
182189
190+ func getZeroValue (typeStr string ) string {
191+ switch typeStr {
192+ case "int" , "int64" , "int32" :
193+ return "0"
194+ case "string" :
195+ return `""`
196+ case "bool" :
197+ return "false"
198+ case "context.Context" :
199+ return "context.Background()"
200+ default :
201+ return "nil"
202+ }
203+ }
204+
183205func (t * templateData ) processMethods (f * ast.File ) error {
184206 for _ , decl := range f .Decls {
185207 fd , ok := decl .(* ast.FuncDecl )
@@ -219,6 +241,7 @@ func (t *templateData) processMethods(f *ast.File) error {
219241
220242 args := []string {}
221243 callArgs := []string {}
244+ zeroArgs := []string {}
222245 var optsType string
223246 var optsName string
224247 hasOpts := false
@@ -229,6 +252,7 @@ func (t *templateData) processMethods(f *ast.File) error {
229252 for _ , name := range field .Names {
230253 args = append (args , fmt .Sprintf ("%s %s" , name .Name , typeStr ))
231254 callArgs = append (callArgs , name .Name )
255+ zeroArgs = append (zeroArgs , getZeroValue (typeStr ))
232256
233257 if strings .HasSuffix (typeStr , "Options" ) {
234258 optsType = strings .TrimPrefix (typeStr , "*" )
@@ -251,19 +275,36 @@ func (t *templateData) processMethods(f *ast.File) error {
251275 continue
252276 }
253277
278+ recType := strings .TrimPrefix (recvType , "*" )
279+ clientField := strings .TrimSuffix (recType , "Service" )
280+ if clientField == "Migration" {
281+ clientField = "Migrations"
282+ }
283+ if clientField == "s" {
284+ logf ("WARNING: clientField is 's' for %s.%s (recvType=%s)" , recvType , fd .Name .Name , recType )
285+ }
286+
287+ testJSON := "[]"
288+ if val , ok := customTestJSON [fd .Name .Name ]; ok {
289+ testJSON = val
290+ }
291+
254292 m := & method {
255- RecvType : strings . TrimPrefix ( recvType , "*" ) ,
293+ RecvType : recType ,
256294 RecvVar : recvVar ,
295+ ClientField : clientField ,
257296 MethodName : fd .Name .Name ,
258297 IterMethod : fd .Name .Name + "Iter" ,
259298 Args : strings .Join (args , ", " ),
260299 CallArgs : strings .Join (callArgs , ", " ),
300+ ZeroArgs : strings .Join (zeroArgs , ", " ),
261301 ReturnType : eltType ,
262302 OptsType : optsType ,
263303 OptsName : optsName ,
264304 OptsIsPtr : optsIsPtr ,
265305 UseListOptions : useListOptions ,
266306 UsePage : usePage ,
307+ TestJSON : testJSON ,
267308 }
268309 t .Methods = append (t .Methods , m )
269310 }
@@ -299,19 +340,23 @@ func (t *templateData) dump() error {
299340 return strings .Compare (a .MethodName , b .MethodName )
300341 })
301342
302- var buf bytes.Buffer
303- if err := sourceTmpl .Execute (& buf , t ); err != nil {
304- return err
343+ processTemplate := func (tmpl * template.Template , filename string ) error {
344+ var buf bytes.Buffer
345+ if err := tmpl .Execute (& buf , t ); err != nil {
346+ return err
347+ }
348+ clean , err := format .Source (buf .Bytes ())
349+ if err != nil {
350+ return fmt .Errorf ("format.Source: %v\n %s" , err , buf .String ())
351+ }
352+ logf ("Writing %v..." , filename )
353+ return os .WriteFile (filename , clean , 0644 )
305354 }
306355
307- clean , err := format .Source (buf .Bytes ())
308- if err != nil {
309- return fmt .Errorf ("format.Source: %v\n %s" , err , buf .String ())
356+ if err := processTemplate (sourceTmpl , "iterators.go" ); err != nil {
357+ return err
310358 }
311-
312- filename := filepath .Join ("github" , "iterators.go" )
313- logf ("Writing %v..." , filename )
314- return os .WriteFile (filename , clean , 0644 )
359+ return processTemplate (testTmpl , "iterators_gen_test.go" )
315360}
316361
317362const source = `// Copyright 2025 The go-github AUTHORS. All rights reserved.
@@ -370,3 +415,41 @@ func ({{.RecvVar}} *{{.RecvType}}) {{.IterMethod}}({{.Args}}) iter.Seq2[{{.Retur
370415}
371416{{end}}
372417`
418+
419+ const test = `// Copyright 2025 The go-github AUTHORS. All rights reserved.
420+ //
421+ // Use of this source code is governed by a BSD-style
422+ // license that can be found in the LICENSE file.
423+
424+ // Code generated by gen-iterators; DO NOT EDIT.
425+
426+ package {{.Package}}
427+
428+ import (
429+ "context"
430+ "fmt"
431+ "net/http"
432+ "testing"
433+ )
434+
435+ {{range .Methods}}
436+ func Test{{.RecvType}}_{{.IterMethod}}(t *testing.T) {
437+ t.Parallel()
438+ client, mux, _ := setup(t)
439+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
440+ fmt.Fprint(w, ` + "`" + `{{.TestJSON}}` + "`" + `)
441+ })
442+
443+ ctx := context.Background()
444+ _ = ctx // avoid unused
445+
446+ // Call iterator with zero values
447+ iter := client.{{.ClientField}}.{{.IterMethod}}({{.ZeroArgs}})
448+ for _, err := range iter {
449+ if err != nil {
450+ t.Errorf("Unexpected error: %v", err)
451+ }
452+ }
453+ }
454+ {{end}}
455+ `
0 commit comments